"""PostgreSQL implementation of CollectionRepository. This module provides the PostgreSQL-specific implementation of the CollectionRepository protocol using SQLAlchemy async sessions. The implementation uses PostgreSQL's ON CONFLICT for efficient upserts. Example: async with get_db_session() as db: repo = PostgresCollectionRepository(db) entries = await repo.get_all(user_id) """ from datetime import UTC, datetime from uuid import UUID from sqlalchemy import delete, select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession from app.db.models.collection import CardSource, Collection from app.repositories.protocols import CollectionEntry def _to_dto(model: Collection) -> CollectionEntry: """Convert ORM model to DTO.""" return CollectionEntry( id=model.id, user_id=model.user_id, card_definition_id=model.card_definition_id, quantity=model.quantity, source=model.source, obtained_at=model.obtained_at, created_at=model.created_at, updated_at=model.updated_at, ) class PostgresCollectionRepository: """PostgreSQL implementation of CollectionRepository. Uses SQLAlchemy async sessions for database access. All operations commit immediately for simplicity - transaction management should be handled at the service layer if needed. Attributes: _db: The async database session. """ def __init__(self, db: AsyncSession) -> None: """Initialize with database session. Args: db: SQLAlchemy async session. """ self._db = db async def get_all(self, user_id: UUID) -> list[CollectionEntry]: """Get all collection entries for a user. Args: user_id: The user's UUID. Returns: List of all collection entries, ordered by card_definition_id. """ result = await self._db.execute( select(Collection) .where(Collection.user_id == user_id) .order_by(Collection.card_definition_id) ) return [_to_dto(model) for model in result.scalars().all()] async def get_by_card(self, user_id: UUID, card_definition_id: str) -> CollectionEntry | None: """Get a specific collection entry. Args: user_id: The user's UUID. card_definition_id: The card ID to look up. Returns: CollectionEntry if exists, None otherwise. """ result = await self._db.execute( select(Collection).where( Collection.user_id == user_id, Collection.card_definition_id == card_definition_id, ) ) model = result.scalar_one_or_none() return _to_dto(model) if model else None async def get_quantity(self, user_id: UUID, card_definition_id: str) -> int: """Get quantity of a specific card owned by user. Args: user_id: The user's UUID. card_definition_id: Card ID to check. Returns: Number of copies owned (0 if not owned). """ result = await self._db.execute( select(Collection.quantity).where( Collection.user_id == user_id, Collection.card_definition_id == card_definition_id, ) ) quantity = result.scalar_one_or_none() return quantity if quantity is not None else 0 async def upsert( self, user_id: UUID, card_definition_id: str, quantity: int, source: CardSource, ) -> CollectionEntry: """Add or update a collection entry using PostgreSQL ON CONFLICT. If entry exists, increments quantity. Otherwise creates new entry. Args: user_id: The user's UUID. card_definition_id: Card ID to add. quantity: Number of copies to add. source: How the cards were obtained. Returns: The created or updated CollectionEntry. """ now = datetime.now(UTC) stmt = pg_insert(Collection).values( user_id=user_id, card_definition_id=card_definition_id, quantity=quantity, source=source, obtained_at=now, ) stmt = stmt.on_conflict_do_update( constraint="uq_collection_user_card", set_={ "quantity": Collection.quantity + quantity, "updated_at": now, }, ) await self._db.execute(stmt) await self._db.commit() # Fetch and return the updated entry entry = await self.get_by_card(user_id, card_definition_id) return entry # type: ignore[return-value] async def decrement( self, user_id: UUID, card_definition_id: str, quantity: int, ) -> CollectionEntry | None: """Decrement quantity of a collection entry. If quantity reaches 0 or below, deletes the entry. Args: user_id: The user's UUID. card_definition_id: Card ID to decrement. quantity: Number of copies to remove. Returns: Updated entry, or None if entry was deleted or didn't exist. """ # Get current entry result = await self._db.execute( select(Collection).where( Collection.user_id == user_id, Collection.card_definition_id == card_definition_id, ) ) model = result.scalar_one_or_none() if model is None: return None new_quantity = model.quantity - quantity if new_quantity <= 0: # Delete the entry await self._db.execute( delete(Collection).where( Collection.user_id == user_id, Collection.card_definition_id == card_definition_id, ) ) await self._db.commit() return None # Update quantity model.quantity = new_quantity await self._db.commit() await self._db.refresh(model) return _to_dto(model) async def exists_with_source(self, user_id: UUID, source: CardSource) -> bool: """Check if user has any entries with the given source. Args: user_id: The user's UUID. source: The CardSource to check for. Returns: True if any entries exist with that source. """ result = await self._db.execute( select(Collection.id) .where( Collection.user_id == user_id, Collection.source == source, ) .limit(1) ) return result.scalar_one_or_none() is not None