"""PostgreSQL implementation of UserRepository.

This module provides the PostgreSQL-specific implementation of the
UserRepository protocol using SQLAlchemy async sessions.

Example:
    async with get_db_session() as db:
        repo = PostgresUserRepository(db)
        user = await repo.get_by_id(user_id)
"""

from datetime import UTC, datetime
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models.user import User
from app.repositories.protocols import UNSET, UserEntry


def _to_dto(model: User) -> UserEntry:
    """Convert a ``User`` ORM model into a ``UserEntry`` DTO.

    Args:
        model: A loaded ``User`` ORM instance.

    Returns:
        A ``UserEntry`` carrying a copy of the model's column values.
    """
    return UserEntry(
        id=model.id,
        email=model.email,
        display_name=model.display_name,
        avatar_url=model.avatar_url,
        oauth_provider=model.oauth_provider,
        oauth_id=model.oauth_id,
        is_premium=model.is_premium,
        premium_until=model.premium_until,
        last_login=model.last_login,
        created_at=model.created_at,
        updated_at=model.updated_at,
    )


class PostgresUserRepository:
    """PostgreSQL implementation of UserRepository.

    Uses SQLAlchemy async sessions for database access. Every mutating
    method (create/update*/delete) commits the session it was given.

    Attributes:
        _db: The async database session.
    """

    def __init__(self, db: AsyncSession) -> None:
        """Initialize with database session.

        Args:
            db: SQLAlchemy async session.
        """
        self._db = db

    async def _get_model(self, user_id: UUID) -> User | None:
        """Load the ``User`` ORM row whose primary key is ``user_id``.

        Args:
            user_id: The user's UUID.

        Returns:
            The ORM instance, or None if no row matches.
        """
        result = await self._db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def _commit_and_load(self, user: User) -> UserEntry:
        """Commit pending changes, reload ``user`` and return it as a DTO.

        Args:
            user: The ORM instance that was added or modified.

        Returns:
            A ``UserEntry`` reflecting the committed database state.
        """
        await self._db.commit()
        # Commit expires loaded attributes unless the session was created
        # with expire_on_commit=False. Refreshing reloads them through an
        # awaited query, so _to_dto does not trigger an implicit (sync) lazy
        # load. It also picks up any values generated by the database.
        await self._db.refresh(user)
        return _to_dto(user)

    async def get_by_id(self, user_id: UUID) -> UserEntry | None:
        """Get a user by their ID.

        Args:
            user_id: The user's UUID.

        Returns:
            UserEntry if found, None otherwise.
        """
        model = await self._get_model(user_id)
        return _to_dto(model) if model is not None else None

    async def get_by_email(self, email: str) -> UserEntry | None:
        """Get a user by their email address.

        The comparison is a plain SQL equality on ``User.email``; no case
        folding is applied here.

        Args:
            email: The user's email address.

        Returns:
            UserEntry if found, None otherwise.
        """
        result = await self._db.execute(select(User).where(User.email == email))
        model = result.scalar_one_or_none()
        return _to_dto(model) if model is not None else None

    async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
        """Get a user by their OAuth provider and ID.

        Args:
            provider: OAuth provider name (google, discord).
            oauth_id: Unique ID from the OAuth provider.

        Returns:
            UserEntry if found, None otherwise.
        """
        result = await self._db.execute(
            select(User).where(
                User.oauth_provider == provider,
                User.oauth_id == oauth_id,
            )
        )
        model = result.scalar_one_or_none()
        return _to_dto(model) if model is not None else None

    async def create(
        self,
        email: str,
        display_name: str,
        oauth_provider: str,
        oauth_id: str,
        avatar_url: str | None = None,
    ) -> UserEntry:
        """Create a new user and commit it.

        Args:
            email: User's email address.
            display_name: Public display name.
            oauth_provider: OAuth provider name.
            oauth_id: Unique ID from the OAuth provider.
            avatar_url: Optional avatar URL.

        Returns:
            The created UserEntry.
        """
        user = User(
            email=email,
            display_name=display_name,
            oauth_provider=oauth_provider,
            oauth_id=oauth_id,
            avatar_url=avatar_url,
        )
        self._db.add(user)
        return await self._commit_and_load(user)

    async def update(
        self,
        user_id: UUID,
        display_name: str | None = None,
        avatar_url: str | None = UNSET,  # type: ignore[assignment]
        oauth_provider: str | None = None,
        oauth_id: str | None = None,
    ) -> UserEntry | None:
        """Update user profile fields.

        Only provided (non-None/non-UNSET) fields are updated.

        Args:
            user_id: The user's UUID.
            display_name: New display name (None keeps existing).
            avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
            oauth_provider: New OAuth provider (None keeps existing).
            oauth_id: New OAuth ID (None keeps existing).

        Returns:
            Updated UserEntry, or None if user not found.
        """
        user = await self._get_model(user_id)
        if user is None:
            return None

        if display_name is not None:
            user.display_name = display_name
        # avatar_url uses the UNSET sentinel because None is a meaningful
        # value here (it clears the avatar).
        if avatar_url is not UNSET:
            user.avatar_url = avatar_url
        if oauth_provider is not None:
            user.oauth_provider = oauth_provider
        if oauth_id is not None:
            user.oauth_id = oauth_id

        return await self._commit_and_load(user)

    async def update_last_login(self, user_id: UUID) -> UserEntry | None:
        """Update the user's last login timestamp to now (timezone-aware UTC).

        Args:
            user_id: The user's UUID.

        Returns:
            Updated UserEntry, or None if user not found.
        """
        user = await self._get_model(user_id)
        if user is None:
            return None

        user.last_login = datetime.now(UTC)
        return await self._commit_and_load(user)

    async def update_premium(
        self,
        user_id: UUID,
        is_premium: bool,
        premium_until: datetime | None,
    ) -> UserEntry | None:
        """Update user's premium subscription status.

        Both fields are overwritten unconditionally.

        Args:
            user_id: The user's UUID.
            is_premium: Whether user has premium.
            premium_until: When premium expires, or None if not premium.

        Returns:
            Updated UserEntry, or None if user not found.
        """
        user = await self._get_model(user_id)
        if user is None:
            return None

        user.is_premium = is_premium
        user.premium_until = premium_until
        return await self._commit_and_load(user)

    async def delete(self, user_id: UUID) -> bool:
        """Delete a user account and commit.

        Whether related data (e.g. decks, collection) is removed depends on
        the cascade rules configured on the ORM relationships and/or the
        database foreign keys. Those rules are defined outside this module.

        Args:
            user_id: The user's UUID.

        Returns:
            True if deleted, False if not found.
        """
        user = await self._get_model(user_id)
        if user is None:
            return False

        await self._db.delete(user)
        await self._db.commit()
        return True