from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.orm import DeclarativeBase from sqlalchemy import text from config import settings class Base(DeclarativeBase): """SQLAlchemy ORM 基类,所有数据库模型均继承此类。""" pass # 异步数据库引擎,连接池大小 20,最大溢出 40,启用连接健康检查 async_engine = create_async_engine( settings.DATABASE_URL, pool_size=20, max_overflow=40, pool_pre_ping=True, pool_recycle=3600, echo=False, ) # 异步数据库会话工厂,用于创建数据库会话实例 AsyncSessionLocal = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False, ) async def init_db(): """初始化数据库:创建所有表并执行增量迁移。""" async with async_engine.begin() as conn: from models import Base as MBase await conn.run_sync(MBase.metadata.create_all) await _run_migrations() async def _run_migrations(): """执行数据库增量迁移,在已有表上安全添加新字段。""" async with async_engine.begin() as conn: await conn.execute(text( "ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS published_version_id UUID REFERENCES flow_versions(id)" )) await conn.execute(text( "ALTER TABLE flow_definitions ADD COLUMN IF NOT EXISTS draft_definition_json JSONB" )) async def get_db(): """FastAPI 依赖注入函数,提供数据库会话,自动提交或回滚事务。""" async with AsyncSessionLocal() as session: try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close()