"""Database session management for async SQLAlchemy.

This module provides async database engine creation and session management
with proper connection pooling and cleanup.

Usage:
    # As async context manager (recommended)
    async with get_session() as session:
        result = await session.execute(select(User))
        users = result.scalars().all()

    # For FastAPI dependency injection use get_session_dependency.
    # get_session is an async context manager factory and cannot be passed
    # to Depends() directly.
    async def get_user(
        user_id: str,
        session: AsyncSession = Depends(get_session_dependency),
    ):
        return await session.get(User, user_id)

A session obtained from get_session() is committed automatically when the
``async with`` block exits without an exception, and rolled back if an
exception propagates out of it. get_session_dependency wraps get_session, so
the same rules apply when FastAPI tears the dependency down.

Lifecycle:
    1. Call init_db() on application startup
    2. Use get_session() for database operations
    3. Call close_db() on application shutdown
"""

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING

from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from app.config import settings

if TYPE_CHECKING:
    # Placeholder for type-checking-only imports; currently none are needed.
    pass

# Module-level singletons: populated by init_db(), cleared by close_db().
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None


def get_engine() -> AsyncEngine:
    """Get the async database engine.

    Returns:
        The initialized AsyncEngine instance.

    Raises:
        RuntimeError: If init_db() has not been called (or close_db() has
            been called since).
    """
    if _engine is None:
        raise RuntimeError("Database engine not initialized. Call init_db() first.")
    return _engine


async def init_db() -> AsyncEngine:
    """Initialize the database engine and session factory.

    Creates an async engine whose URL, pool size, overflow limit and echo
    flag come from application settings. Should be called once during
    application startup; repeated calls return the existing engine without
    creating a new one.

    Returns:
        The initialized AsyncEngine instance.

    Example:
        @app.on_event("startup")
        async def startup():
            await init_db()
    """
    global _engine, _session_factory

    if _engine is not None:
        # Already initialized: keep the existing engine and pool.
        return _engine

    _engine = create_async_engine(
        str(settings.database_url),
        pool_size=settings.database_pool_size,
        max_overflow=settings.database_max_overflow,
        echo=settings.database_echo,
        pool_pre_ping=True,  # Verify connections before using
    )

    _session_factory = async_sessionmaker(
        bind=_engine,
        class_=AsyncSession,
        # Keep ORM objects usable after commit without an implicit refresh.
        expire_on_commit=False,
        autoflush=False,
    )

    return _engine


async def close_db() -> None:
    """Close database connections and dispose of the engine.

    Should be called during application shutdown to properly release
    database connections. Calling it when the engine is not initialized
    is a no-op.

    Example:
        @app.on_event("shutdown")
        async def shutdown():
            await close_db()
    """
    global _engine, _session_factory

    # Detach the globals *before* awaiting dispose(). This stops other tasks
    # from obtaining an engine that is being torn down (via get_engine() or an
    # init_db() short-circuit). It also leaves the module in a clean,
    # re-initializable state even if dispose() raises.
    engine = _engine
    _engine = None
    _session_factory = None

    if engine is not None:
        await engine.dispose()


@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Get an async database session.

    Provides a session as an async context manager:

    - If the block exits without an exception, the session is committed.
    - If an exception propagates out of the block, the session is rolled
      back and the exception is re-raised.
    - In every case the session is closed afterwards.

    Yields:
        An AsyncSession instance.

    Raises:
        RuntimeError: If init_db() has not been called. The error is raised
            when the context is entered.

    Example:
        async with get_session() as session:
            user = User(email="test@example.com")
            session.add(user)
        # Committed automatically here; an explicit
        # ``await session.commit()`` inside the block is optional.
    """
    if _session_factory is None:
        raise RuntimeError("Session factory not initialized. Call init_db() first.")

    session = _session_factory()
    try:
        yield session
        # Reached only when the caller's block completed without raising.
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()


async def get_session_dependency() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency for database sessions.

    Use this with FastAPI's Depends() in route handlers. It delegates to
    get_session(), so the session is committed on success, rolled back on
    error, and always closed.

    Yields:
        An AsyncSession instance.

    Example:
        @app.get("/users/{user_id}")
        async def get_user(
            user_id: str,
            session: AsyncSession = Depends(get_session_dependency),
        ):
            return await session.get(User, user_id)
    """
    async with get_session() as session:
        yield session


async def create_all_tables() -> None:
    """Create all database tables registered on ``Base.metadata``.

    Note:
        This is primarily for testing. Production should use Alembic
        migrations for schema management.

    Raises:
        RuntimeError: If init_db() has not been called.
    """
    # Imported lazily; presumably to avoid an import cycle with the model
    # modules — NOTE(review): confirm.
    from app.db.base import Base

    engine = get_engine()
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


async def drop_all_tables() -> None:
    """Drop all database tables registered on ``Base.metadata``.

    WARNING: This will delete all data! Only use in testing.

    Raises:
        RuntimeError: If init_db() has not been called.
    """
    from app.db.base import Base

    engine = get_engine()
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)