"""公共工具 — 消除各模块间的重复代码。""" from __future__ import annotations import json from datetime import date, datetime, timedelta, timezone from pathlib import Path from typing import Any from zoneinfo import ZoneInfo import bleach import httpx from fastapi.templating import Jinja2Templates from app.config import settings # ── 路径常量 ────────────────────────────────────────────────────────── DATA_DIR = Path("data") PAPERS_DIR = DATA_DIR / "papers" TMP_DIR = DATA_DIR / "tmp" # ── 模板单例 ────────────────────────────────────────────────────────── class _Templates(Jinja2Templates): """自动注入 is_admin 到模板上下文的 Jinja2Templates 子类。""" def TemplateResponse(self, request, name, context=None, **kwargs): context = context or {} context.setdefault("is_admin", request.session.get("is_admin", False)) return super().TemplateResponse(request, name, context, **kwargs) templates = _Templates(directory="app/templates") # ── 时区工具 ────────────────────────────────────────────────────────── def utc_now() -> datetime: """当前 UTC 时间(替代 datetime.now(timezone.utc) 的简写)。""" return datetime.now(timezone.utc) def today_str() -> str: """当前日期字符串(按 APP_TIMEZONE)。""" tz = ZoneInfo(settings.APP_TIMEZONE) return datetime.now(tz).strftime("%Y-%m-%d") def yesterday_str() -> str: """昨天日期字符串(按 APP_TIMEZONE)。""" tz = ZoneInfo(settings.APP_TIMEZONE) yesterday = datetime.now(tz).date() - timedelta(days=1) return yesterday.isoformat() def recent_date_strs(n: int) -> list[str]: """最近 N 天的日期字符串列表(含今天,按 APP_TIMEZONE)。""" tz = ZoneInfo(settings.APP_TIMEZONE) today = datetime.now(tz).date() return [(today - timedelta(days=i)).isoformat() for i in range(n)] def latest_paper_date(db) -> str: """查询数据库中最新的 paper_date,无数据时回退到 today_str()。""" from sqlalchemy import func, select from app.models import Paper result = db.scalar(select(func.max(Paper.paper_date))) if result is not None: return result.isoformat() if isinstance(result, date) else str(result) return today_str() # ── 锁释放 ──────────────────────────────────────────────────────────── def release_lock(db, lock) -> None: """释放 TaskLock。""" try: lock.status = "finished" lock.released_at = utc_now() db.commit() except Exception: db.rollback() # ── HTTP 客户端工厂 ─────────────────────────────────────────────────── def make_http_client( *, sync: bool = False, follow_redirects: bool = False, **kwargs ) -> httpx.AsyncClient | httpx.Client: """创建带 proxy 和默认配置的 httpx 客户端。 Args: sync: True 返回同步 Client,False 返回 AsyncClient follow_redirects: 是否跟随重定向 **kwargs: 覆盖默认参数 """ defaults: dict = { "timeout": settings.HTTP_TIMEOUT_SECONDS, "headers": {"User-Agent": settings.HTTP_USER_AGENT}, "follow_redirects": follow_redirects, } if settings.http_proxy: defaults["transport"] = ( httpx.HTTPTransport(proxy=settings.http_proxy) if sync else httpx.AsyncHTTPTransport(proxy=settings.http_proxy) ) defaults.update(kwargs) if sync: return httpx.Client(**defaults) return httpx.AsyncClient(**defaults) # ── JSON 安全解析 ────────────────────────────────────────────────────── def safe_json_loads(text: str | None, default: Any = None) -> Any: """安全解析 JSON 字符串,解析失败返回 default 值(不会抛异常)。""" if not text: return default try: return json.loads(text) except (json.JSONDecodeError, TypeError, ValueError): return default # ── HTML 清洗 ────────────────────────────────────────────────────────── # AI 生成内容中允许的 HTML 标签和属性 _ALLOWED_TAGS = { "p", "br", "strong", "b", "em", "i", "u", "s", "del", "h3", "h4", "h5", "h6", "ul", "ol", "li", "a", "code", "pre", "blockquote", "table", "thead", "tbody", "tr", "th", "td", "sup", "sub", "span", } _ALLOWED_ATTRS = { "a": {"href", "title"}, "th": {"colspan", "rowspan"}, "td": {"colspan", "rowspan"}, "span": {"class"}, } def sanitize_html(text: str | None) -> str: """清洗 AI 生成的 HTML,移除危险标签但保留安全的富文本。 - 移除: