"""Async model factories for Mantimon TCG database tests.

This module provides factory classes for creating test data with sensible
defaults. Unlike factory_boy, these factories are designed for async
SQLAlchemy sessions.

Usage:
    @pytest.mark.asyncio
    async def test_something(db_session):
        # Create with defaults
        user = await UserFactory.create(db_session)

        # Create with overrides
        premium_user = await UserFactory.create(
            db_session,
            is_premium=True,
            display_name="VIP Player",
        )

        # Create batch
        users = await UserFactory.create_batch(db_session, count=5)

        # Create with relationships (a valid 40-card deck owned by `user`)
        deck = await DeckFactory.create_valid_deck(db_session, user)

Design Principles:
    - Each factory has sensible defaults for all required fields
    - Unique fields use counters or UUIDs to avoid conflicts
    - Factories return persisted objects (flushed, with IDs)
    - Relationship helpers (create_for_user, create_valid_deck, ...) link a
      new object to an existing owner passed in by the caller
    - Keyword overrides always win over defaults and over a helper's preset
      field values
"""

from datetime import UTC, datetime, timedelta
from typing import Any, Generic, TypeVar
from uuid import uuid4

from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import (
    ActiveGame,
    CampaignProgress,
    CardSource,
    Collection,
    Deck,
    EndReason,
    GameHistory,
    GameType,
    User,
)

T = TypeVar("T")


class AsyncFactory(Generic[T]):  # noqa: UP046 - Using Generic for Python 3.11 compat
    """Base class for async model factories.

    Provides common methods for creating and persisting model instances.
    Subclasses must define `model` and implement `get_defaults()`.
    """

    model: type[T]
    # The first `cls._counter += 1` on a subclass creates an attribute on that
    # subclass, so each concrete factory keeps its own independent counter.
    _counter: int = 0

    @classmethod
    def _next_counter(cls) -> int:
        """Get the next unique counter value for this factory."""
        cls._counter += 1
        return cls._counter

    @classmethod
    async def create(
        cls,
        session: AsyncSession,
        **overrides: Any,
    ) -> T:
        """Create and persist a model instance.

        Args:
            session: Async SQLAlchemy session.
            **overrides: Field values to override defaults.

        Returns:
            Persisted model instance with generated ID.
        """
        # get_defaults() runs on every call, so the counter advances even when
        # the caller overrides every field.
        fields = cls.get_defaults()
        fields.update(overrides)
        instance = cls.model(**fields)
        session.add(instance)
        # flush() emits the INSERT without committing; refresh() then reloads
        # the row's attributes (including generated IDs) from the database.
        await session.flush()
        await session.refresh(instance)
        return instance

    @classmethod
    async def create_batch(
        cls,
        session: AsyncSession,
        count: int,
        **overrides: Any,
    ) -> list[T]:
        """Create multiple model instances.

        Args:
            session: Async SQLAlchemy session.
            count: Number of instances to create.
            **overrides: Field values to override defaults (same for all).

        Returns:
            List of persisted model instances.
        """
        return [await cls.create(session, **overrides) for _ in range(count)]

    @classmethod
    def get_defaults(cls) -> dict[str, Any]:
        """Get default field values for the model.

        Override in subclasses to provide model-specific defaults.

        Returns:
            Dictionary of field name -> default value.
        """
        raise NotImplementedError("Subclasses must implement get_defaults()")


class UserFactory(AsyncFactory[User]):
    """Factory for creating test User instances.

    Defaults:
        - Unique email based on counter
        - Google OAuth provider
        - Non-premium account
        - No last login

    Example:
        user = await UserFactory.create(db_session)
        premium = await UserFactory.create(db_session, is_premium=True)
    """

    model = User

    @classmethod
    def get_defaults(cls) -> dict[str, Any]:
        counter = cls._next_counter()
        return {
            "email": f"testuser{counter}@example.com",
            "display_name": f"Test User {counter}",
            "avatar_url": f"https://example.com/avatars/{counter}.png",
            "oauth_provider": "google",
            "oauth_id": f"google_{uuid4().hex}",
            "is_premium": False,
            "premium_until": None,
            "last_login": None,
        }

    @classmethod
    async def create_premium(
        cls,
        session: AsyncSession,
        days_remaining: int = 30,
        **overrides: Any,
    ) -> User:
        """Create a premium user with active subscription.

        Args:
            session: Async SQLAlchemy session.
            days_remaining: Days until premium expires.
            **overrides: Additional field overrides (take precedence over the
                premium presets).

        Returns:
            Premium user instance.
        """
        premium_until = datetime.now(UTC) + timedelta(days=days_remaining)
        # Later keys win, so caller overrides replace the presets instead of
        # colliding with them as duplicate keyword arguments.
        return await cls.create(
            session,
            **{
                "is_premium": True,
                "premium_until": premium_until,
                **overrides,
            },
        )


class CollectionFactory(AsyncFactory[Collection]):
    """Factory for creating test Collection instances.

    Defaults:
        - Card from booster pack
        - Quantity of 1
        - Current timestamp for obtained_at

    Note:
        Requires a user_id to be provided or use create_for_user().

    Example:
        user = await UserFactory.create(db_session)
        card = await CollectionFactory.create(db_session, user_id=user.id)
    """

    model = Collection

    @classmethod
    def get_defaults(cls) -> dict[str, Any]:
        counter = cls._next_counter()
        return {
            "user_id": None,  # Must be provided
            "card_definition_id": f"test_card_{counter:04d}",
            "quantity": 1,
            "source": CardSource.BOOSTER,
            "obtained_at": datetime.now(UTC),
        }

    @classmethod
    async def create_for_user(
        cls,
        session: AsyncSession,
        user: User,
        card_count: int = 1,
        source: CardSource = CardSource.BOOSTER,
        **overrides: Any,
    ) -> list[Collection]:
        """Create collection entries for a user.

        Args:
            session: Async SQLAlchemy session.
            user: User to own the cards.
            card_count: Number of different cards to add.
            source: How cards were obtained.
            **overrides: Additional field overrides, applied to every entry.
                Overriding card_definition_id gives every entry the same card
                ID (may conflict with a uniqueness constraint - check model).

        Returns:
            List of Collection entries.
        """
        entries: list[Collection] = []
        for _ in range(card_count):
            fields = {
                "user_id": user.id,
                # Random suffix keeps card IDs distinct between entries.
                "card_definition_id": f"card_{uuid4().hex[:8]}",
                "source": source,
                **overrides,
            }
            entries.append(await cls.create(session, **fields))
        return entries


class DeckFactory(AsyncFactory[Deck]):
    """Factory for creating test Deck instances.

    Defaults:
        - Named "Test Deck N"
        - Empty cards and energy_cards JSONB
        - Invalid deck (is_valid=False)
        - Not a starter deck

    Note:
        Requires a user_id to be provided or use create_for_user().

    Example:
        user = await UserFactory.create(db_session)
        deck = await DeckFactory.create(db_session, user_id=user.id)
    """

    model = Deck

    @classmethod
    def get_defaults(cls) -> dict[str, Any]:
        counter = cls._next_counter()
        return {
            "user_id": None,  # Must be provided
            "name": f"Test Deck {counter}",
            "cards": {},  # JSONB - empty by default
            "energy_cards": {},  # JSONB - empty by default
            "is_valid": False,
            "validation_errors": None,
            "is_starter": False,
            "starter_type": None,
            "description": None,
        }

    @classmethod
    async def create_for_user(
        cls,
        session: AsyncSession,
        user: User,
        **overrides: Any,
    ) -> Deck:
        """Create a deck owned by a user.

        Args:
            session: Async SQLAlchemy session.
            user: User to own the deck.
            **overrides: Additional field overrides.

        Returns:
            Deck instance.
        """
        return await cls.create(session, **{"user_id": user.id, **overrides})

    @classmethod
    async def create_valid_deck(
        cls,
        session: AsyncSession,
        user: User,
        **overrides: Any,
    ) -> Deck:
        """Create a valid deck with sample cards.

        Creates a 40-card deck marked valid with:
            - 20 Pokemon cards
            - 10 Trainer cards
            - 10 Energy cards

        Args:
            session: Async SQLAlchemy session.
            user: User to own the deck.
            **overrides: Additional field overrides (take precedence over the
                sample card lists and validity flags).

        Returns:
            Valid deck instance.
        """
        cards = {
            # Pokemon: 20
            "pikachu_base_001": 4,
            "raichu_base_001": 2,
            "charmander_base_001": 4,
            "charmeleon_base_001": 2,
            "charizard_base_001": 2,
            "bulbasaur_base_001": 4,
            "ivysaur_base_001": 2,
            # Trainers: 10
            "potion_001": 4,
            "professor_oak_001": 4,
            "pokeball_001": 2,
        }
        energy_cards = {
            "lightning": 4,
            "fire": 4,
            "grass": 2,
        }
        return await cls.create(
            session,
            **{
                "user_id": user.id,
                "cards": cards,
                "energy_cards": energy_cards,
                "is_valid": True,
                "validation_errors": None,
                **overrides,
            },
        )

    @classmethod
    async def create_starter_deck(
        cls,
        session: AsyncSession,
        user: User,
        starter_type: str = "fire",
        **overrides: Any,
    ) -> Deck:
        """Create a starter deck.

        Note:
            cards/energy_cards keep the empty defaults even though the deck is
            marked valid; pass them via overrides for a populated deck.

        Args:
            session: Async SQLAlchemy session.
            user: User to own the deck.
            starter_type: Type of starter (fire, water, grass, etc.).
            **overrides: Additional field overrides (e.g. a custom name).

        Returns:
            Starter deck instance.
        """
        return await cls.create(
            session,
            **{
                "user_id": user.id,
                "name": f"{starter_type.title()} Starter Deck",
                "is_starter": True,
                "starter_type": starter_type,
                "is_valid": True,
                **overrides,
            },
        )


class CampaignProgressFactory(AsyncFactory[CampaignProgress]):
    """Factory for creating test CampaignProgress instances.

    Defaults:
        - At grass_club
        - No medals or defeated NPCs
        - Zero wins/losses
        - Zero booster packs and mantibucks

    Note:
        Requires a user_id to be provided or use create_for_user().
        Each user can only have one CampaignProgress (one-to-one).

    Example:
        user = await UserFactory.create(db_session)
        progress = await CampaignProgressFactory.create(db_session, user_id=user.id)
    """

    model = CampaignProgress

    # Medals in award order; create_advanced() takes a prefix of this tuple.
    _MEDAL_TYPES: tuple[str, ...] = (
        "grass_medal",
        "fire_medal",
        "water_medal",
        "lightning_medal",
        "psychic_medal",
        "fighting_medal",
        "science_medal",
        "rock_medal",
    )

    @classmethod
    def get_defaults(cls) -> dict[str, Any]:
        return {
            "user_id": None,  # Must be provided (one-to-one)
            "current_club": "grass_club",
            "medals": [],  # JSONB
            "defeated_npcs": [],  # JSONB
            "total_wins": 0,
            "total_losses": 0,
            "booster_packs": 0,
            "mantibucks": 0,
        }

    @classmethod
    async def create_for_user(
        cls,
        session: AsyncSession,
        user: User,
        **overrides: Any,
    ) -> CampaignProgress:
        """Create campaign progress for a user.

        Args:
            session: Async SQLAlchemy session.
            user: User to track progress for.
            **overrides: Additional field overrides.

        Returns:
            CampaignProgress instance.
        """
        return await cls.create(session, **{"user_id": user.id, **overrides})

    @classmethod
    async def create_advanced(
        cls,
        session: AsyncSession,
        user: User,
        medals_count: int = 4,
        **overrides: Any,
    ) -> CampaignProgress:
        """Create campaign progress partway through the game.

        Stats scale with medals_count: 3 defeated NPCs, 5 wins, 2 losses,
        1 booster pack and 100 mantibucks per medal.

        Args:
            session: Async SQLAlchemy session.
            user: User to track progress for.
            medals_count: Number of medals earned (0-8).
            **overrides: Additional field overrides (take precedence over the
                derived progress values).

        Returns:
            CampaignProgress instance with progress.

        Raises:
            ValueError: If medals_count is outside 0-8. (Without this check a
                negative count would slice from the end of the medal list.)
        """
        max_medals = len(cls._MEDAL_TYPES)
        if not 0 <= medals_count <= max_medals:
            raise ValueError(
                f"medals_count must be between 0 and {max_medals}, got {medals_count}"
            )
        medals = list(cls._MEDAL_TYPES[:medals_count])
        npcs = [f"trainer_{i}" for i in range(medals_count * 3)]
        return await cls.create(
            session,
            **{
                "user_id": user.id,
                "medals": medals,
                "defeated_npcs": npcs,
                "total_wins": medals_count * 5,
                "total_losses": medals_count * 2,
                "booster_packs": medals_count,
                "mantibucks": medals_count * 100,
                **overrides,
            },
        )


class ActiveGameFactory(AsyncFactory[ActiveGame]):
    """Factory for creating test ActiveGame instances.

    Defaults:
        - Campaign game type against NPC "grass_trainer_1"
        - Turn 1
        - Placeholder rules config (4 prizes, 40-card decks) and game state
          (turn 1, main phase)
        - started_at == last_action_at == creation time, no turn deadline

    Note:
        Requires player1_id to be provided.

    Example:
        user = await UserFactory.create(db_session)
        game = await ActiveGameFactory.create(db_session, player1_id=user.id)
    """

    model = ActiveGame

    @classmethod
    def get_defaults(cls) -> dict[str, Any]:
        # One timestamp so a fresh game has identical start/last-action times.
        now = datetime.now(UTC)
        return {
            "game_type": GameType.CAMPAIGN,
            "player1_id": None,  # Must be provided
            "player2_id": None,  # Optional for PvP
            "npc_id": "grass_trainer_1",  # Default NPC opponent
            "rules_config": {"prize_count": 4, "deck_size": 40},  # JSONB
            "game_state": {"turn": 1, "phase": "main"},  # JSONB placeholder
            "turn_number": 1,
            "started_at": now,
            "last_action_at": now,
            "turn_deadline": None,
        }

    @classmethod
    async def create_campaign_game(
        cls,
        session: AsyncSession,
        player: User,
        npc_id: str = "grass_trainer_1",
        **overrides: Any,
    ) -> ActiveGame:
        """Create a campaign game against an NPC.

        Args:
            session: Async SQLAlchemy session.
            player: Player in the campaign.
            npc_id: ID of the NPC opponent.
            **overrides: Additional field overrides.

        Returns:
            ActiveGame instance.
        """
        return await cls.create(
            session,
            **{
                "game_type": GameType.CAMPAIGN,
                "player1_id": player.id,
                "player2_id": None,
                "npc_id": npc_id,
                **overrides,
            },
        )

    @classmethod
    async def create_pvp_game(
        cls,
        session: AsyncSession,
        player1: User,
        player2: User,
        game_type: GameType = GameType.FREEPLAY,
        **overrides: Any,
    ) -> ActiveGame:
        """Create a PvP game between two players.

        Args:
            session: Async SQLAlchemy session.
            player1: First player.
            player2: Second player.
            game_type: FREEPLAY or RANKED.
            **overrides: Additional field overrides.

        Returns:
            ActiveGame instance.
        """
        return await cls.create(
            session,
            **{
                "game_type": game_type,
                "player1_id": player1.id,
                "player2_id": player2.id,
                "npc_id": None,
                **overrides,
            },
        )


class GameHistoryFactory(AsyncFactory[GameHistory]):
    """Factory for creating test GameHistory instances.

    Defaults:
        - Campaign game against NPC "grass_trainer_1"
        - No winner recorded (winner_id=None, winner_is_npc=False); use
          create_player_win() / create_player_loss() for a decided outcome
        - Ended by taking all prizes
        - 10 turns, 300 seconds duration

    Note:
        Requires player1_id to be provided.

    Example:
        user = await UserFactory.create(db_session)
        history = await GameHistoryFactory.create(db_session, player1_id=user.id)
    """

    model = GameHistory

    @classmethod
    def get_defaults(cls) -> dict[str, Any]:
        return {
            "game_type": GameType.CAMPAIGN,
            "player1_id": None,  # Must be provided
            "player2_id": None,
            "npc_id": "grass_trainer_1",
            "winner_id": None,  # Set to player1_id for player win
            "winner_is_npc": False,
            "end_reason": EndReason.PRIZES_TAKEN,
            "turn_count": 10,
            "duration_seconds": 300,
            "replay_data": None,  # JSONB
            "played_at": datetime.now(UTC),
        }

    @classmethod
    async def create_player_win(
        cls,
        session: AsyncSession,
        player: User,
        npc_id: str = "grass_trainer_1",
        **overrides: Any,
    ) -> GameHistory:
        """Create a game history where the player won.

        Args:
            session: Async SQLAlchemy session.
            player: Winning player.
            npc_id: NPC opponent ID.
            **overrides: Additional field overrides.

        Returns:
            GameHistory instance.
        """
        return await cls.create(
            session,
            **{
                "game_type": GameType.CAMPAIGN,
                "player1_id": player.id,
                "winner_id": player.id,
                "winner_is_npc": False,
                "npc_id": npc_id,
                **overrides,
            },
        )

    @classmethod
    async def create_player_loss(
        cls,
        session: AsyncSession,
        player: User,
        npc_id: str = "grass_trainer_1",
        end_reason: EndReason = EndReason.NO_POKEMON,
        **overrides: Any,
    ) -> GameHistory:
        """Create a game history where the player lost to NPC.

        Args:
            session: Async SQLAlchemy session.
            player: Losing player.
            npc_id: NPC opponent ID.
            end_reason: How the game ended.
            **overrides: Additional field overrides.

        Returns:
            GameHistory instance.
        """
        return await cls.create(
            session,
            **{
                "game_type": GameType.CAMPAIGN,
                "player1_id": player.id,
                "winner_id": None,
                "winner_is_npc": True,
                "npc_id": npc_id,
                "end_reason": end_reason,
                **overrides,
            },
        )

    @classmethod
    async def create_pvp_game(
        cls,
        session: AsyncSession,
        player1: User,
        player2: User,
        winner: User | None = None,
        game_type: GameType = GameType.FREEPLAY,
        **overrides: Any,
    ) -> GameHistory:
        """Create a PvP game history.

        Args:
            session: Async SQLAlchemy session.
            player1: First player.
            player2: Second player.
            winner: Winning player (None for draw).
            game_type: FREEPLAY or RANKED.
            **overrides: Additional field overrides.

        Returns:
            GameHistory instance. A draw (winner=None) is recorded with
            EndReason.DRAW; otherwise EndReason.PRIZES_TAKEN.
        """
        is_draw = winner is None
        return await cls.create(
            session,
            **{
                "game_type": game_type,
                "player1_id": player1.id,
                "player2_id": player2.id,
                "npc_id": None,
                "winner_id": None if is_draw else winner.id,
                "winner_is_npc": False,
                "end_reason": EndReason.DRAW if is_draw else EndReason.PRIZES_TAKEN,
                **overrides,
            },
        )