Files
Rain-Bus 21f16e6756 feat: refactor summarizer and PDF extraction pipeline
- Split summarizer into summary_generator and summary_persister modules
- Refactor pdf_image_extractor to two-phase pipeline with PicoDet layout detection
- Add layout_detector service for PicoDet-S_layout_3cls integration
- Add exceptions module with ConflictError and NotFoundError
- Improve admin dashboard with better statistics and task management
- Add design review document with system optimization suggestions
- Add new tests for crawler, pdf_downloader, pipeline, and summary_utils
- Update dependencies and configuration
- Clean up dead code and improve error handling
2026-06-13 13:16:47 +08:00

206 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""公共工具 — 消除各模块间的重复代码。"""
from __future__ import annotations
import json
from datetime import date, datetime, timedelta, timezone
from pathlib import Path
from typing import Any
from zoneinfo import ZoneInfo
import bleach
import httpx
from fastapi.templating import Jinja2Templates
from app.config import settings
# ── 路径常量 ──────────────────────────────────────────────────────────
# Root data directory; a relative path, so it resolves against the process
# working directory at the time it is used.
DATA_DIR = Path("data")
# "papers" subdirectory of DATA_DIR (presumably downloaded paper files — confirm).
PAPERS_DIR = DATA_DIR / "papers"
# "tmp" subdirectory of DATA_DIR (presumably scratch space — confirm).
TMP_DIR = DATA_DIR / "tmp"
# ── 模板单例 ──────────────────────────────────────────────────────────
class _Templates(Jinja2Templates):
    """Jinja2Templates subclass that adds ``is_admin`` to every template context."""

    def TemplateResponse(self, request, name, context=None, **kwargs):
        """Render template *name* with ``is_admin`` defaulted from the session.

        A falsy *context* (``None`` or an empty dict) is replaced by a fresh
        dict. An ``is_admin`` key already present in *context* is kept;
        otherwise it is filled from ``request.session`` (``False`` if unset).
        """
        if not context:
            context = {}
        # Read the session eagerly, even when the caller already supplied
        # "is_admin", so a missing session surfaces the same way either way.
        session_flag = request.session.get("is_admin", False)
        context.setdefault("is_admin", session_flag)
        return super().TemplateResponse(request, name, context, **kwargs)


# Shared module-level instance rooted at the application's template directory.
templates = _Templates(directory="app/templates")
# ── 时区工具 ──────────────────────────────────────────────────────────
def utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``.

    Shorthand for ``datetime.now(timezone.utc)``.
    """
    return datetime.now(tz=timezone.utc)
def today_str() -> str:
    """Return today's date as ``YYYY-MM-DD`` in the ``settings.APP_TIMEZONE`` zone."""
    local_now = datetime.now(ZoneInfo(settings.APP_TIMEZONE))
    return local_now.strftime("%Y-%m-%d")
def yesterday_str() -> str:
    """Return yesterday's date (ISO ``YYYY-MM-DD``) in the ``settings.APP_TIMEZONE`` zone."""
    zone = ZoneInfo(settings.APP_TIMEZONE)
    local_today = datetime.now(zone).date()
    one_day = timedelta(days=1)
    return (local_today - one_day).isoformat()
def recent_date_strs(n: int) -> list[str]:
    """Return the last *n* dates as ISO strings, newest first, starting with today.

    "Today" is evaluated in the ``settings.APP_TIMEZONE`` zone. A non-positive
    *n* yields an empty list.
    """
    zone = ZoneInfo(settings.APP_TIMEZONE)
    newest = datetime.now(zone).date()
    offsets = (timedelta(days=back) for back in range(n))
    return [(newest - delta).isoformat() for delta in offsets]
def latest_paper_date(db) -> str:
    """Return the largest ``Paper.paper_date`` in the database as a string.

    Falls back to ``today_str()`` when the query yields ``None`` (no rows).
    A ``date`` result is rendered with ``isoformat()``; any other type is
    passed through ``str()``.

    Args:
        db: Object exposing ``scalar()`` — presumably a SQLAlchemy session;
            confirm against callers.
    """
    # Function-scope imports, kept as in the original (presumably to avoid an
    # import cycle with app.models — confirm).
    from sqlalchemy import func, select

    from app.models import Paper

    newest = db.scalar(select(func.max(Paper.paper_date)))
    if newest is None:
        return today_str()
    if isinstance(newest, date):
        return newest.isoformat()
    return str(newest)
# ── 锁释放 ────────────────────────────────────────────────────────────
def release_lock(db, lock) -> None:
    """Mark a TaskLock as finished and commit, best-effort.

    Sets ``lock.status`` to ``"finished"`` and ``lock.released_at`` to the
    current UTC time, then commits. If any of that fails, the error is logged
    with its traceback and the transaction is rolled back; the original
    exception is not re-raised, so callers can use this in cleanup paths.

    Args:
        db: Object exposing ``commit()`` / ``rollback()`` — presumably a
            SQLAlchemy session; confirm against callers.
        lock: The lock record to release.
    """
    import logging

    try:
        lock.status = "finished"
        lock.released_at = utc_now()
        db.commit()
    except Exception:
        # A lock that fails to release would otherwise stay held with no
        # trace; record the failure before discarding the transaction.
        logging.getLogger(__name__).exception(
            "Failed to release task lock; rolling back"
        )
        db.rollback()
# ── HTTP 客户端工厂 ───────────────────────────────────────────────────
def make_http_client(
    *, sync: bool = False, follow_redirects: bool = False, **kwargs
) -> httpx.AsyncClient | httpx.Client:
    """Create an httpx client preconfigured with proxy and default options.

    Defaults are the timeout and User-Agent header from ``settings``; when
    ``settings.http_proxy`` is set, a proxying transport is attached as well.

    Args:
        sync: ``True`` returns a synchronous ``httpx.Client``; ``False``
            returns an ``httpx.AsyncClient``.
        follow_redirects: Whether the client follows redirects.
        **kwargs: Extra client options. These replace the defaults key by
            key, so passing ``headers`` replaces the default headers dict
            wholesale rather than merging into it.
    """
    if sync:
        client_cls, transport_cls = httpx.Client, httpx.HTTPTransport
    else:
        client_cls, transport_cls = httpx.AsyncClient, httpx.AsyncHTTPTransport

    options: dict = {
        "timeout": settings.HTTP_TIMEOUT_SECONDS,
        "headers": {"User-Agent": settings.HTTP_USER_AGENT},
        "follow_redirects": follow_redirects,
    }

    proxy_url = settings.http_proxy
    if proxy_url:
        options["transport"] = transport_cls(proxy=proxy_url)

    # Caller-supplied options win over every default set above.
    options.update(kwargs)
    return client_cls(**options)
# ── JSON 安全解析 ──────────────────────────────────────────────────────
def safe_json_loads(text: str | None, default: Any = None) -> Any:
    """Parse *text* as JSON, returning *default* instead of raising.

    Args:
        text: The JSON document. ``None`` or an empty string yields *default*.
        default: Value returned when *text* is falsy or cannot be decoded.

    Returns:
        The decoded object, or *default* on any decoding failure.
    """
    if not text:
        return default
    try:
        return json.loads(text)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers json.JSONDecodeError (its subclass); TypeError
        # covers input that is not str/bytes; RecursionError covers
        # pathologically nested documents, which the decoder rejects by
        # raising rather than returning.
        return default
# ── HTML 清洗 ──────────────────────────────────────────────────────────
# HTML tags and attributes permitted in AI-generated content.
# _ALLOWED_TAGS is passed to bleach.clean() as ``tags`` by sanitize_html().
_ALLOWED_TAGS = {
"p",
"br",
"strong",
"b",
"em",
"i",
"u",
"s",
"del",
"h3",
"h4",
"h5",
"h6",
"ul",
"ol",
"li",
"a",
"code",
"pre",
"blockquote",
"table",
"thead",
"tbody",
"tr",
"th",
"td",
"sup",
"sub",
"span",
}
# Per-tag attribute allow-list, passed to bleach.clean() as ``attributes``
# by sanitize_html(). Keys are tag names; values are the attribute names
# kept on that tag.
_ALLOWED_ATTRS = {
"a": {"href", "title"},
"th": {"colspan", "rowspan"},
"td": {"colspan", "rowspan"},
"span": {"class"},
}
def sanitize_html(text: str | None) -> str:
    """Sanitize AI-generated HTML down to the module's rich-text allow-list.

    Runs ``bleach.clean`` with ``_ALLOWED_TAGS`` / ``_ALLOWED_ATTRS`` and
    ``strip=True``, so markup outside the allow-list (for example
    ``<script>``, ``<iframe>``, ``on*`` handlers) is stripped rather than
    escaped, while paragraphs, emphasis, lists, tables and links are kept.
    NOTE(review): filtering of ``javascript:`` hrefs relies on bleach's
    default protocol handling — confirm against the bleach version in use.

    Args:
        text: Raw HTML; ``None`` or an empty string yields ``""``.

    Returns:
        The sanitized HTML string.
    """
    if text:
        return bleach.clean(
            text,
            tags=_ALLOWED_TAGS,
            attributes=_ALLOWED_ATTRS,
            strip=True,
        )
    return ""
# ── 错误消息截断 ────────────────────────────────────────────────────────
# Default maximum number of characters kept from an error message.
_ERROR_TRUNCATE_LIMIT = 500


def truncate_error(exc: Exception | str, limit: int = _ERROR_TRUNCATE_LIMIT) -> str:
    """Render *exc* as text, shortened to *limit* characters when longer.

    Messages within the limit are returned unchanged. Longer ones keep their
    first *limit* characters followed by a suffix that reports the original
    length, so the result is longer than *limit* by the size of that suffix.

    Args:
        exc: An exception or an already-formatted message.
        limit: Number of leading characters to keep.
    """
    message = str(exc)
    total = len(message)
    if total > limit:
        return f"{message[:limit]}... ({total} chars total)"
    return message