109 lines
3.0 KiB
Python
109 lines
3.0 KiB
Python
"""数据库引擎、会话工厂、初始化。"""
|
|
|
|
from sqlalchemy import event, create_engine, text
|
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
|
|
|
from app.config import settings
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base class for the application's ORM models."""
|
|
|
|
|
|
# ── FTS5 and index DDL (managed separately from the ORM models) ─────────────
# These raw statements are executed by init_db() after Base.metadata.create_all.

# Standalone FTS5 full-text index over paper text fields.
# NOTE(review): no triggers are defined here to keep papers_fts in sync with
# the source tables; presumably rows are written elsewhere — verify.
# NOTE(review): the unicode61 tokenizer only splits on whitespace/punctuation,
# so an unbroken run of Chinese characters in title_zh is indexed as a single
# token and substring queries will not match. Consider tokenize='trigram'
# (SQLite >= 3.34) if CJK search matters; because of IF NOT EXISTS, an
# existing table would have to be dropped and rebuilt for that to take effect.
FTS5_CREATE_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS papers_fts USING fts5(
    title_en,
    title_zh,
    abstract,
    authors,
    tags,
    summary_text,
    tokenize='unicode61'
);
"""

# Partial unique index: at most one row with status = 'running' per
# (task, lock_key) in task_locks; rows with any other status are unconstrained.
TASK_LOCKS_INDEX_SQL = """
-- partial index for task_locks running
CREATE UNIQUE INDEX IF NOT EXISTS uq_task_locks_running
ON task_locks(task, lock_key) WHERE status = 'running';
"""

# Backward-compatible alias. Despite its name this is not an FTS5 trigger; it
# is the task_locks partial index above. Kept because init_db() (and possibly
# other callers) reference this name.
FTS5_TRIGGER_INDEX = TASK_LOCKS_INDEX_SQL
|
|
|
|
|
|
def _make_engine():
    """Build the application's SQLite engine.

    Every new DBAPI connection is configured with ``PRAGMA foreign_keys=ON``
    and ``PRAGMA journal_mode=WAL`` before it is handed out.

    Returns:
        The configured SQLAlchemy ``Engine``.
    """
    db_engine = create_engine(
        settings.DATABASE_URL,
        echo=settings.APP_DEBUG,
        # Allow a pooled sqlite3 connection to be used from a thread other than
        # the one that created it (presumably needed for FastAPI's threadpool).
        connect_args={"check_same_thread": False},
    )

    def _apply_pragmas(raw_conn, _record):
        # Runs once per new DBAPI connection, in this order.
        cur = raw_conn.cursor()
        for pragma in ("PRAGMA foreign_keys=ON", "PRAGMA journal_mode=WAL"):
            cur.execute(pragma)
        cur.close()

    event.listen(db_engine, "connect", _apply_pragmas)
    return db_engine
|
|
|
|
|
|
# Module-level singletons: one engine for the process, plus a factory that
# produces independent Session objects bound to it.
engine = _make_engine()
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,  # pending changes are flushed on commit()/flush(), not before queries
    autocommit=False,
)
|
|
|
|
|
|
def get_db():
    """FastAPI dependency that provides a database session.

    Yields:
        A fresh ``Session`` from ``SessionLocal``. The session is closed when
        the generator is resumed or closed, including when an exception
        propagates through it.
    """
    with SessionLocal() as session:
        yield session
|
|
|
|
|
|
def _migrate(engine) -> None:
    """Add missing columns to existing tables via SQLite ``ALTER TABLE ADD COLUMN``.

    Only the columns listed in ``_MIGRATIONS`` below are checked. A column is
    added when ``PRAGMA table_info`` does not report it for its table.

    A table that does not exist is skipped with a warning instead of crashing.
    ``PRAGMA table_info`` returns no rows for a missing table, which would make
    every column look missing and make ``ALTER TABLE`` fail with
    "no such table".

    Args:
        engine: SQLAlchemy engine bound to the SQLite database.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Columns that must exist: {table name: [(column name, column type SQL), ...]}
    _MIGRATIONS: dict[str, list[tuple[str, str]]] = {
        "paper_summaries": [
            ("figures_json", "TEXT"),
        ],
        "crawl_logs": [
            ("details_json", "TEXT"),
        ],
    }

    # engine.begin() commits when the block completes and rolls back if any
    # statement raises — same outcome as connect() + a final commit().
    with engine.begin() as conn:
        for table, columns in _MIGRATIONS.items():
            # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
            existing = {
                row[1]
                for row in conn.execute(text(f"PRAGMA table_info({table})"))
            }
            if not existing:
                logger.warning("Migration skipped: table %s does not exist", table)
                continue
            for col_name, col_type in columns:
                if col_name in existing:
                    continue
                conn.execute(
                    text(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
                )
                logger.info("Migrated: %s.%s added", table, col_name)
|
|
|
|
|
|
def init_db(engine):
    """Create all ORM tables, then the raw FTS5/index DDL, then run column migrations.

    Args:
        engine: SQLAlchemy engine for the application database.
    """
    # Deferred import to avoid a circular import; importing app.models is also
    # presumably what registers every model on Base.metadata before create_all.
    from app.models import Base  # noqa: F811

    Base.metadata.create_all(engine)

    # DDL the ORM metadata does not describe. engine.begin() commits once
    # all statements succeed (and rolls back if one raises).
    with engine.begin() as conn:
        for ddl in (FTS5_CREATE_SQL, FTS5_TRIGGER_INDEX):
            conn.execute(text(ddl))

    _migrate(engine)
|