diff --git a/backend/app/core/models/game_state.py b/backend/app/core/models/game_state.py index 38c1453..ea26d5c 100644 --- a/backend/app/core/models/game_state.py +++ b/backend/app/core/models/game_state.py @@ -379,6 +379,7 @@ class GameState(BaseModel): forced_actions: Queue of ForcedAction items that must be completed before game proceeds. Actions are processed in FIFO order (first added = first to resolve). action_log: Log of actions taken (for replays/debugging). + rng_seed: Optional seed for deterministic RNG. When set, enables replay capability. """ game_id: str @@ -412,6 +413,9 @@ class GameState(BaseModel): # Optional action log for replays action_log: list[dict[str, Any]] = Field(default_factory=list) + # Optional RNG seed for deterministic replays + rng_seed: int | None = None + def get_current_player(self) -> PlayerState: """Get the PlayerState for the current player. diff --git a/backend/app/main.py b/backend/app/main.py index 529e250..45f8a36 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -26,6 +26,7 @@ from app.config import settings from app.db import close_db, init_db from app.db.redis import close_redis, init_redis from app.services import get_card_service +from app.socketio import create_socketio_app logger = logging.getLogger(__name__) @@ -173,3 +174,16 @@ app.include_router(decks_router, prefix="/api") # app.include_router(cards.router, prefix="/api/cards", tags=["cards"]) # app.include_router(games.router, prefix="/api/games", tags=["games"]) # app.include_router(campaign.router, prefix="/api/campaign", tags=["campaign"]) + + +# === Socket.IO Integration === +# Wrap FastAPI with Socket.IO ASGI app for WebSocket support. +# The combined app handles: +# - WebSocket connections at /socket.io/ +# - HTTP requests passed through to FastAPI +# +# Note: The module-level 'app' variable is now the combined ASGI app. +# This is intentional - uvicorn imports 'app' and will serve both +# FastAPI HTTP routes and Socket.IO WebSocket connections. +_fastapi_app = app +app = create_socketio_app(_fastapi_app) diff --git a/backend/app/schemas/ws_messages.py b/backend/app/schemas/ws_messages.py new file mode 100644 index 0000000..ed03b57 --- /dev/null +++ b/backend/app/schemas/ws_messages.py @@ -0,0 +1,491 @@ +"""WebSocket message schemas for Mantimon TCG real-time communication. + +This module defines Pydantic models for all WebSocket message types, following +the discriminated union pattern from app/core/models/actions.py. Messages are +categorized into client-to-server and server-to-client types. + +Message Design Principles: + 1. All messages have a 'type' discriminator field for automatic parsing + 2. All messages have a 'message_id' for idempotency and event ordering + 3. Server messages include a 'timestamp' for client-side latency tracking + 4. Game-scoped messages include 'game_id' for multi-game support + +Example: + # Parse incoming client message + data = {"type": "join_game", "message_id": "abc-123", "game_id": "game-456"} + message = parse_client_message(data) + assert isinstance(message, JoinGameMessage) + + # Create server response + response = GameStateMessage( + message_id=str(uuid4()), + game_id="game-456", + state=visible_state, + ) + await websocket.send_text(response.model_dump_json()) +""" + +from datetime import UTC, datetime +from enum import StrEnum +from typing import Annotated, Any, Literal +from uuid import uuid4 + +from pydantic import BaseModel, Field, field_validator + +from app.core.enums import GameEndReason +from app.core.models.actions import Action +from app.core.visibility import VisibleGameState + + +class WSErrorCode(StrEnum): + """Error codes for WebSocket error messages. + + These codes help clients handle errors programmatically without + parsing error message text. + """ + + # Connection errors + AUTHENTICATION_FAILED = "authentication_failed" + CONNECTION_CLOSED = "connection_closed" + RATE_LIMITED = "rate_limited" + + # Game errors + GAME_NOT_FOUND = "game_not_found" + NOT_IN_GAME = "not_in_game" + ALREADY_IN_GAME = "already_in_game" + GAME_FULL = "game_full" + GAME_ENDED = "game_ended" + + # Action errors + INVALID_ACTION = "invalid_action" + NOT_YOUR_TURN = "not_your_turn" + ACTION_NOT_ALLOWED = "action_not_allowed" + + # Protocol errors + INVALID_MESSAGE = "invalid_message" + UNKNOWN_MESSAGE_TYPE = "unknown_message_type" + + # Server errors + INTERNAL_ERROR = "internal_error" + + +class ConnectionStatus(StrEnum): + """Connection status for opponent status messages.""" + + CONNECTED = "connected" + DISCONNECTED = "disconnected" + RECONNECTING = "reconnecting" + + +# ============================================================================= +# Base Message Classes +# ============================================================================= + + +def _generate_message_id() -> str: + """Generate a unique message ID.""" + return str(uuid4()) + + +def _utc_now() -> datetime: + """Get current UTC timestamp.""" + return datetime.now(UTC) + + +class BaseClientMessage(BaseModel): + """Base class for all client-to-server messages. + + All client messages must include a message_id for idempotency. The client + generates this ID to allow detecting and handling duplicate messages. + + Attributes: + message_id: Client-generated UUID for idempotency and tracking. + """ + + message_id: str = Field( + default_factory=_generate_message_id, + description="Client-generated UUID for idempotency", + ) + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, v: str) -> str: + """Ensure message_id is not empty.""" + if not v or not v.strip(): + raise ValueError("message_id cannot be empty") + return v + + +class BaseServerMessage(BaseModel): + """Base class for all server-to-client messages. + + All server messages include a message_id and timestamp. The timestamp + enables client-side latency calculation and event ordering. + + Attributes: + message_id: Server-generated UUID for tracking. + timestamp: UTC timestamp when the message was created. + """ + + message_id: str = Field( + default_factory=_generate_message_id, + description="Server-generated UUID for tracking", + ) + timestamp: datetime = Field( + default_factory=_utc_now, + description="UTC timestamp when message was created", + ) + + +# ============================================================================= +# Client -> Server Messages +# ============================================================================= + + +class JoinGameMessage(BaseClientMessage): + """Request to join or rejoin a game session. + + When rejoining after a disconnect, the client can provide last_event_id + to receive any missed events since that point. + + Attributes: + type: Discriminator field, always "join_game". + game_id: The game to join. + last_event_id: For reconnection - ID of last received event for replay. + """ + + type: Literal["join_game"] = "join_game" + game_id: str = Field(..., description="ID of the game to join") + last_event_id: str | None = Field( + default=None, + description="Last event ID received (for reconnection replay)", + ) + + +class ActionMessage(BaseClientMessage): + """Submit a game action for execution. + + Wraps the Action union type from app/core/models/actions.py to provide + consistent message envelope with message_id and game context. + + Attributes: + type: Discriminator field, always "action". + game_id: The game this action is for. + action: The game action to execute. + """ + + type: Literal["action"] = "action" + game_id: str = Field(..., description="ID of the game") + action: Action = Field(..., description="The game action to execute") + + +class ResignMessage(BaseClientMessage): + """Request to resign from a game. + + This is separate from ResignAction in the game engine to allow + resignation handling at the WebSocket layer (e.g., when the game + engine is unavailable). + + Attributes: + type: Discriminator field, always "resign". + game_id: The game to resign from. + """ + + type: Literal["resign"] = "resign" + game_id: str = Field(..., description="ID of the game to resign from") + + +class HeartbeatMessage(BaseClientMessage): + """Keep-alive message to maintain connection. + + Clients should send heartbeats periodically (e.g., every 30 seconds) + to prevent connection timeout. The server responds with a HeartbeatAck. + + Attributes: + type: Discriminator field, always "heartbeat". + """ + + type: Literal["heartbeat"] = "heartbeat" + + +# Union type for client messages +ClientMessage = Annotated[ + JoinGameMessage | ActionMessage | ResignMessage | HeartbeatMessage, + Field(discriminator="type"), +] + + +# ============================================================================= +# Server -> Client Messages +# ============================================================================= + + +class GameStateMessage(BaseServerMessage): + """Full game state update. + + Sent when a player joins a game or when a full state sync is needed. + Contains the complete visible game state from that player's perspective. + + Attributes: + type: Discriminator field, always "game_state". + game_id: The game this state is for. + state: The full visible game state. + event_id: Monotonic event ID for reconnection replay. + """ + + type: Literal["game_state"] = "game_state" + game_id: str = Field(..., description="ID of the game") + state: VisibleGameState = Field(..., description="Full visible game state") + event_id: str = Field( + default_factory=_generate_message_id, + description="Event ID for reconnection replay", + ) + + +class ActionResultMessage(BaseServerMessage): + """Result of a player action. + + Sent after processing an ActionMessage to confirm success or failure. + On success, includes changes that resulted from the action. + + Attributes: + type: Discriminator field, always "action_result". + game_id: The game this result is for. + request_message_id: The message_id of the original ActionMessage. + success: Whether the action succeeded. + action_type: The type of action that was attempted. + changes: Description of state changes (on success). + error_code: Error code if action failed. + error_message: Human-readable error message if failed. + """ + + type: Literal["action_result"] = "action_result" + game_id: str = Field(..., description="ID of the game") + request_message_id: str = Field(..., description="message_id of the original action request") + success: bool = Field(..., description="Whether the action succeeded") + action_type: str = Field(..., description="Type of action that was attempted") + changes: dict[str, Any] = Field( + default_factory=dict, + description="State changes resulting from the action", + ) + error_code: WSErrorCode | None = Field(default=None, description="Error code if action failed") + error_message: str | None = Field(default=None, description="Human-readable error message") + + +class ErrorMessage(BaseServerMessage): + """Error notification for protocol or connection errors. + + Used for errors not associated with a specific action, such as + invalid message format, authentication failures, or server errors. + + Attributes: + type: Discriminator field, always "error". + code: Machine-readable error code. + message: Human-readable error description. + details: Additional error context. + request_message_id: ID of the message that caused the error, if applicable. + """ + + type: Literal["error"] = "error" + code: WSErrorCode = Field(..., description="Error code") + message: str = Field(..., description="Human-readable error description") + details: dict[str, Any] = Field(default_factory=dict, description="Additional error context") + request_message_id: str | None = Field( + default=None, description="ID of message that caused error" + ) + + +class TurnStartMessage(BaseServerMessage): + """Notification that a new turn has started. + + Sent to both players when a turn begins. Indicates whose turn it is + and the current turn number. + + Attributes: + type: Discriminator field, always "turn_start". + game_id: The game this notification is for. + player_id: The player whose turn is starting. + turn_number: The current turn number. + event_id: Event ID for reconnection replay. + """ + + type: Literal["turn_start"] = "turn_start" + game_id: str = Field(..., description="ID of the game") + player_id: str = Field(..., description="Player whose turn is starting") + turn_number: int = Field(..., description="Current turn number", ge=1) + event_id: str = Field( + default_factory=_generate_message_id, + description="Event ID for reconnection replay", + ) + + +class TurnTimeoutMessage(BaseServerMessage): + """Timeout warning or expiration notification. + + Sent when a player's turn is approaching timeout (warning) or has + expired. Warnings give players time to complete their action. + + Attributes: + type: Discriminator field, always "turn_timeout". + game_id: The game this notification is for. + remaining_seconds: Seconds remaining before timeout. + is_warning: True if this is a warning, False if timeout has occurred. + player_id: The player whose turn is timing out. + """ + + type: Literal["turn_timeout"] = "turn_timeout" + game_id: str = Field(..., description="ID of the game") + remaining_seconds: int = Field(..., description="Seconds remaining before timeout", ge=0) + is_warning: bool = Field(..., description="True if warning, False if timeout occurred") + player_id: str = Field(..., description="Player whose turn is timing out") + + +class GameOverMessage(BaseServerMessage): + """Notification that the game has ended. + + Sent to all players when the game concludes, regardless of the reason. + + Attributes: + type: Discriminator field, always "game_over". + game_id: The game that ended. + winner_id: The winning player, or None for a draw. + end_reason: Why the game ended. + final_state: The final visible game state. + event_id: Event ID for reconnection replay. + """ + + type: Literal["game_over"] = "game_over" + game_id: str = Field(..., description="ID of the game that ended") + winner_id: str | None = Field(default=None, description="Winner player ID, or None for draw") + end_reason: GameEndReason = Field(..., description="Reason the game ended") + final_state: VisibleGameState = Field(..., description="Final visible game state") + event_id: str = Field( + default_factory=_generate_message_id, + description="Event ID for reconnection replay", + ) + + +class OpponentStatusMessage(BaseServerMessage): + """Notification of opponent connection status change. + + Sent when the opponent connects, disconnects, or is reconnecting. + Allows the UI to show connection status to the player. + + Attributes: + type: Discriminator field, always "opponent_status". + game_id: The game this status is for. + opponent_id: The opponent's player ID. + status: The opponent's current connection status. + """ + + type: Literal["opponent_status"] = "opponent_status" + game_id: str = Field(..., description="ID of the game") + opponent_id: str = Field(..., description="Opponent's player ID") + status: ConnectionStatus = Field(..., description="Opponent's connection status") + + +class HeartbeatAckMessage(BaseServerMessage): + """Acknowledgment of a client heartbeat. + + Sent in response to HeartbeatMessage to confirm the connection is alive. + + Attributes: + type: Discriminator field, always "heartbeat_ack". + """ + + type: Literal["heartbeat_ack"] = "heartbeat_ack" + + +# Union type for server messages +ServerMessage = Annotated[ + GameStateMessage + | ActionResultMessage + | ErrorMessage + | TurnStartMessage + | TurnTimeoutMessage + | GameOverMessage + | OpponentStatusMessage + | HeartbeatAckMessage, + Field(discriminator="type"), +] + + +# ============================================================================= +# Parsing Functions +# ============================================================================= + + +def parse_client_message(data: dict[str, Any]) -> ClientMessage: + """Parse an incoming client message from a dictionary. + + Uses the 'type' field to determine which message model to use. + + Args: + data: Dictionary containing message data with a 'type' field. + + Returns: + The appropriate ClientMessage subtype. + + Raises: + ValueError: If the message type is unknown. + ValidationError: If the data doesn't match the message schema. + + Example: + data = {"type": "join_game", "message_id": "abc", "game_id": "123"} + message = parse_client_message(data) + assert isinstance(message, JoinGameMessage) + """ + from pydantic import TypeAdapter + + adapter: TypeAdapter[ClientMessage] = TypeAdapter(ClientMessage) + return adapter.validate_python(data) + + +def parse_server_message(data: dict[str, Any]) -> ServerMessage: + """Parse a server message from a dictionary. + + Primarily used for testing - clients receive JSON and parse directly. + + Args: + data: Dictionary containing message data with a 'type' field. + + Returns: + The appropriate ServerMessage subtype. + + Raises: + ValueError: If the message type is unknown. + ValidationError: If the data doesn't match the message schema. + """ + from pydantic import TypeAdapter + + adapter: TypeAdapter[ServerMessage] = TypeAdapter(ServerMessage) + return adapter.validate_python(data) + + +__all__ = [ + # Enums + "ConnectionStatus", + "WSErrorCode", + # Base classes + "BaseClientMessage", + "BaseServerMessage", + # Client messages + "ActionMessage", + "ClientMessage", + "HeartbeatMessage", + "JoinGameMessage", + "ResignMessage", + # Server messages + "ActionResultMessage", + "ErrorMessage", + "GameOverMessage", + "GameStateMessage", + "HeartbeatAckMessage", + "OpponentStatusMessage", + "ServerMessage", + "TurnStartMessage", + "TurnTimeoutMessage", + # Parsing functions + "parse_client_message", + "parse_server_message", +] diff --git a/backend/app/services/connection_manager.py b/backend/app/services/connection_manager.py new file mode 100644 index 0000000..185d3a8 --- /dev/null +++ b/backend/app/services/connection_manager.py @@ -0,0 +1,578 @@ +"""WebSocket connection management with Redis-backed session tracking. + +This module manages WebSocket connection state using Redis for persistence +and cross-instance coordination. It tracks: +- Individual connection sessions (sid -> user_id, game_id, timestamps) +- User to connection mapping (user_id -> sid) +- Game participants (game_id -> set of sids) + +Key Patterns: + conn:{sid} - Hash with connection details (user_id, game_id, connected_at, last_seen) + user_conn:{user_id} - String with active sid + game_conns:{game_id} - Set of sids for game participants + +Lifecycle: + 1. on_connect: Register connection, map user to sid + 2. on_join_game: Add sid to game's connection set + 3. on_heartbeat: Update last_seen timestamp + 4. on_disconnect: Remove from all tracking structures + +Example: + manager = ConnectionManager() + + # On WebSocket connect + await manager.register_connection(sid, user_id) + + # On joining a game + await manager.join_game(sid, game_id) + + # Check if opponent is online + if await manager.is_user_online(opponent_id): + # Send real-time update + + # On disconnect + await manager.unregister_connection(sid) +""" + +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from uuid import UUID + +from app.db.redis import get_redis + +logger = logging.getLogger(__name__) + +# Redis key patterns +CONN_PREFIX = "conn:" +USER_CONN_PREFIX = "user_conn:" +GAME_CONNS_PREFIX = "game_conns:" + +# Connection TTL (auto-expire stale connections) +DEFAULT_CONN_TTL_SECONDS = 3600 # 1 hour +HEARTBEAT_INTERVAL_SECONDS = 30 + + +@dataclass +class ConnectionInfo: + """Information about a WebSocket connection. + + Attributes: + sid: Socket.IO session ID. + user_id: Authenticated user's UUID. + game_id: Current game ID if in a game, None otherwise. + connected_at: When the connection was established. + last_seen: Last heartbeat or activity timestamp. + """ + + sid: str + user_id: str + game_id: str | None + connected_at: datetime + last_seen: datetime + + def is_stale(self, threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3) -> bool: + """Check if connection is stale (no recent activity). + + Args: + threshold_seconds: Seconds since last_seen to consider stale. + + Returns: + True if connection hasn't been seen recently. + """ + elapsed = (datetime.now(UTC) - self.last_seen).total_seconds() + return elapsed > threshold_seconds + + +class ConnectionManager: + """Manages WebSocket connections with Redis-backed session tracking. + + Provides methods to: + - Register/unregister connections + - Track which game a connection is participating in + - Check user online status + - Get all connections for a game + - Handle heartbeats and stale connection detection + + Attributes: + conn_ttl_seconds: TTL for connection records in Redis. + """ + + def __init__(self, conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS) -> None: + """Initialize the ConnectionManager. + + Args: + conn_ttl_seconds: How long to keep connection records in Redis. + """ + self.conn_ttl_seconds = conn_ttl_seconds + + def _conn_key(self, sid: str) -> str: + """Generate Redis key for a connection.""" + return f"{CONN_PREFIX}{sid}" + + def _user_conn_key(self, user_id: str) -> str: + """Generate Redis key for user-to-connection mapping.""" + return f"{USER_CONN_PREFIX}{user_id}" + + def _game_conns_key(self, game_id: str) -> str: + """Generate Redis key for game connection set.""" + return f"{GAME_CONNS_PREFIX}{game_id}" + + # ========================================================================= + # Connection Lifecycle + # ========================================================================= + + async def register_connection( + self, + sid: str, + user_id: str | UUID, + ) -> None: + """Register a new WebSocket connection. + + Creates connection record in Redis and maps user to this connection. + If user had a previous connection, it will be replaced. + + Args: + sid: Socket.IO session ID. + user_id: Authenticated user's ID (UUID or string). + + Example: + await manager.register_connection("abc123", user.id) + """ + user_id_str = str(user_id) + now = datetime.now(UTC).isoformat() + + async with get_redis() as redis: + # Check for existing connection and clean it up + old_sid = await redis.get(self._user_conn_key(user_id_str)) + if old_sid and old_sid != sid: + logger.info(f"Replacing old connection {old_sid} for user {user_id_str}") + await self._cleanup_connection(old_sid) + + # Create connection record (hash) + conn_key = self._conn_key(sid) + await redis.hset( + conn_key, + mapping={ + "user_id": user_id_str, + "game_id": "", + "connected_at": now, + "last_seen": now, + }, + ) + await redis.expire(conn_key, self.conn_ttl_seconds) + + # Map user to this connection + user_conn_key = self._user_conn_key(user_id_str) + await redis.set(user_conn_key, sid) + await redis.expire(user_conn_key, self.conn_ttl_seconds) + + logger.debug(f"Registered connection: sid={sid}, user_id={user_id_str}") + + async def unregister_connection(self, sid: str) -> ConnectionInfo | None: + """Unregister a WebSocket connection and clean up all related data. + + Removes connection from: + - Connection record + - User-to-connection mapping + - Any game connection sets + + Args: + sid: Socket.IO session ID. + + Returns: + ConnectionInfo of the removed connection, or None if not found. + + Example: + info = await manager.unregister_connection("abc123") + if info and info.game_id: + # Notify opponent of disconnect + """ + # Get connection info before cleanup + info = await self.get_connection(sid) + if info is None: + logger.debug(f"Connection not found for unregister: {sid}") + return None + + await self._cleanup_connection(sid) + logger.debug(f"Unregistered connection: sid={sid}, user_id={info.user_id}") + return info + + async def _cleanup_connection(self, sid: str) -> None: + """Internal cleanup of a connection's Redis data. + + Args: + sid: Socket.IO session ID. + """ + async with get_redis() as redis: + conn_key = self._conn_key(sid) + + # Get connection data for cleanup + data = await redis.hgetall(conn_key) + if not data: + return + + user_id = data.get("user_id", "") + game_id = data.get("game_id", "") + + # Remove from game connection set if in a game + if game_id: + game_conns_key = self._game_conns_key(game_id) + await redis.srem(game_conns_key, sid) + + # Remove user-to-connection mapping if it points to this sid + if user_id: + user_conn_key = self._user_conn_key(user_id) + current_sid = await redis.get(user_conn_key) + if current_sid == sid: + await redis.delete(user_conn_key) + + # Delete connection record + await redis.delete(conn_key) + + # ========================================================================= + # Game Association + # ========================================================================= + + async def join_game(self, sid: str, game_id: str) -> bool: + """Associate a connection with a game. + + Updates the connection's game_id and adds it to the game's connection set. + If connection was in a different game, leaves that game first. + + Args: + sid: Socket.IO session ID. + game_id: Game to join. + + Returns: + True if successful, False if connection not found. + + Example: + await manager.join_game("abc123", "game-456") + """ + async with get_redis() as redis: + conn_key = self._conn_key(sid) + + # Check connection exists + exists = await redis.exists(conn_key) + if not exists: + logger.warning(f"Cannot join game: connection not found {sid}") + return False + + # Leave current game if in one + current_game = await redis.hget(conn_key, "game_id") + if current_game and current_game != game_id: + old_game_conns_key = self._game_conns_key(current_game) + await redis.srem(old_game_conns_key, sid) + logger.debug(f"Connection {sid} left game {current_game}") + + # Update connection's game_id + await redis.hset(conn_key, "game_id", game_id) + + # Add to game's connection set + game_conns_key = self._game_conns_key(game_id) + await redis.sadd(game_conns_key, sid) + + # Set TTL on game conns set (cleanup if game becomes inactive) + await redis.expire(game_conns_key, self.conn_ttl_seconds) + + logger.debug(f"Connection {sid} joined game {game_id}") + return True + + async def leave_game(self, sid: str) -> str | None: + """Remove a connection from its current game. + + Args: + sid: Socket.IO session ID. + + Returns: + The game_id that was left, or None if not in a game. + + Example: + game_id = await manager.leave_game("abc123") + """ + async with get_redis() as redis: + conn_key = self._conn_key(sid) + + # Get current game + game_id = await redis.hget(conn_key, "game_id") + if not game_id: + return None + + # Remove from game's connection set + game_conns_key = self._game_conns_key(game_id) + await redis.srem(game_conns_key, sid) + + # Clear game_id on connection + await redis.hset(conn_key, "game_id", "") + + logger.debug(f"Connection {sid} left game {game_id}") + return game_id + + # ========================================================================= + # Heartbeat / Activity + # ========================================================================= + + async def update_heartbeat(self, sid: str) -> bool: + """Update the last_seen timestamp for a connection. + + Should be called on each heartbeat message from the client. + Also refreshes TTL on connection records. + + Args: + sid: Socket.IO session ID. + + Returns: + True if successful, False if connection not found. + + Example: + await manager.update_heartbeat("abc123") + """ + now = datetime.now(UTC).isoformat() + + async with get_redis() as redis: + conn_key = self._conn_key(sid) + + # Check exists and update atomically + exists = await redis.exists(conn_key) + if not exists: + return False + + # Update last_seen + await redis.hset(conn_key, "last_seen", now) + + # Refresh TTL + await redis.expire(conn_key, self.conn_ttl_seconds) + + # Also refresh user mapping TTL + user_id = await redis.hget(conn_key, "user_id") + if user_id: + user_conn_key = self._user_conn_key(user_id) + await redis.expire(user_conn_key, self.conn_ttl_seconds) + + return True + + # ========================================================================= + # Query Methods + # ========================================================================= + + async def get_connection(self, sid: str) -> ConnectionInfo | None: + """Get connection info by session ID. + + Args: + sid: Socket.IO session ID. + + Returns: + ConnectionInfo if found, None otherwise. + + Example: + info = await manager.get_connection("abc123") + if info: + print(f"User {info.user_id} connected at {info.connected_at}") + """ + async with get_redis() as redis: + conn_key = self._conn_key(sid) + data = await redis.hgetall(conn_key) + + if not data: + return None + + return ConnectionInfo( + sid=sid, + user_id=data.get("user_id", ""), + game_id=data.get("game_id") or None, + connected_at=datetime.fromisoformat(data["connected_at"]), + last_seen=datetime.fromisoformat(data["last_seen"]), + ) + + async def get_user_connection(self, user_id: str | UUID) -> ConnectionInfo | None: + """Get connection info for a user. + + Args: + user_id: User's UUID or string ID. + + Returns: + ConnectionInfo if user is connected, None otherwise. + + Example: + info = await manager.get_user_connection(opponent_id) + if info: + # User is online + """ + user_id_str = str(user_id) + + async with get_redis() as redis: + user_conn_key = self._user_conn_key(user_id_str) + sid = await redis.get(user_conn_key) + + if not sid: + return None + + return await self.get_connection(sid) + + async def is_user_online(self, user_id: str | UUID) -> bool: + """Check if a user has an active connection. + + Args: + user_id: User's UUID or string ID. + + Returns: + True if user is connected, False otherwise. + + Example: + if await manager.is_user_online(opponent_id): + # Send real-time notification + """ + info = await self.get_user_connection(user_id) + if info is None: + return False + + # Check if connection is stale + if info.is_stale(): + logger.debug(f"User {user_id} connection is stale") + return False + + return True + + async def get_game_connections(self, game_id: str) -> list[ConnectionInfo]: + """Get all connections for a game. + + Args: + game_id: Game ID. + + Returns: + List of ConnectionInfo for all connected participants. + + Example: + connections = await manager.get_game_connections("game-456") + for conn in connections: + # Send state update to each participant + """ + async with get_redis() as redis: + game_conns_key = self._game_conns_key(game_id) + sids = await redis.smembers(game_conns_key) + + connections = [] + for sid in sids: + info = await self.get_connection(sid) + if info is not None: + connections.append(info) + + return connections + + async def get_game_user_sids(self, game_id: str) -> dict[str, str]: + """Get mapping of user_id to sid for a game. + + Useful for targeted message delivery to specific players. + + Args: + game_id: Game ID. + + Returns: + Dict mapping user_id to their sid. + + Example: + user_sids = await manager.get_game_user_sids("game-456") + player1_sid = user_sids.get(player1_id) + """ + connections = await self.get_game_connections(game_id) + return {conn.user_id: conn.sid for conn in connections} + + async def get_opponent_sid( + self, + game_id: str, + current_user_id: str | UUID, + ) -> str | None: + """Get the sid of the opponent in a 2-player game. + + Args: + game_id: Game ID. + current_user_id: The current user's ID. + + Returns: + Opponent's sid if found and connected, None otherwise. + + Example: + opponent_sid = await manager.get_opponent_sid("game-456", user_id) + if opponent_sid: + # Send message to opponent + """ + current_user_str = str(current_user_id) + connections = await self.get_game_connections(game_id) + + for conn in connections: + if conn.user_id != current_user_str: + return conn.sid + + return None + + # ========================================================================= + # Maintenance + # ========================================================================= + + async def cleanup_stale_connections( + self, + threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3, + ) -> int: + """Clean up stale connections that haven't sent heartbeats. + + Should be called periodically by a background task. + + Args: + threshold_seconds: Seconds since last_seen to consider stale. + + Returns: + Number of connections cleaned up. + + Example: + # In background task + count = await manager.cleanup_stale_connections() + logger.info(f"Cleaned up {count} stale connections") + """ + count = 0 + async with get_redis() as redis: + # Scan for all connection keys + async for key in redis.scan_iter(match=f"{CONN_PREFIX}*"): + sid = key[len(CONN_PREFIX) :] + info = await self.get_connection(sid) + + if info and info.is_stale(threshold_seconds): + await self._cleanup_connection(sid) + count += 1 + logger.debug(f"Cleaned up stale connection: {sid}") + + if count > 0: + logger.info(f"Cleaned up {count} stale connections") + + return count + + async def get_connection_count(self) -> int: + """Get the total number of active connections. + + Useful for monitoring and admin dashboards. + + Returns: + Number of active connections. + """ + count = 0 + async with get_redis() as redis: + async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"): + count += 1 + return count + + async def get_game_connection_count(self, game_id: str) -> int: + """Get the number of connections for a specific game. + + Args: + game_id: Game ID. + + Returns: + Number of connected participants. + """ + async with get_redis() as redis: + game_conns_key = self._game_conns_key(game_id) + return await redis.scard(game_conns_key) + + +# Global singleton instance +connection_manager = ConnectionManager() diff --git a/backend/app/services/game_service.py b/backend/app/services/game_service.py new file mode 100644 index 0000000..5578ce4 --- /dev/null +++ b/backend/app/services/game_service.py @@ -0,0 +1,558 @@ +"""Game service for orchestrating game lifecycle in Mantimon TCG. + +This service is the bridge between WebSocket communication and the core +GameEngine. It handles: +- Game creation with deck loading and validation +- Action execution with persistence +- Game state retrieval with visibility filtering +- Game lifecycle (join, resign, end) + +IMPORTANT: This service is stateless. All game-specific configuration +(RulesConfig) is stored in the GameState itself, not in this service. +Rules come from the frontend request at game creation time. + +Architecture: + WebSocket Layer -> GameService -> GameEngine + GameStateManager + -> DeckService + CardService + +Example: + from app.services.game_service import GameService, game_service + + # Create a new game (rules from frontend) + result = await game_service.create_game( + player1_id=user1.id, + player2_id=user2.id, + deck1_id=deck1.id, + deck2_id=deck2.id, + rules_config=rules_from_request, # Frontend provides this + ) + + # Execute an action (uses rules stored in game state) + result = await game_service.execute_action( + game_id=game.id, + player_id=user1.id, + action=AttackAction(attack_index=0), + ) +""" + +import logging +from dataclasses import dataclass, field +from typing import Any +from uuid import UUID + +from app.core.engine import ActionResult, GameEngine +from app.core.enums import GameEndReason +from app.core.models.actions import Action, ResignAction +from app.core.models.game_state import GameState +from app.core.rng import create_rng +from app.core.visibility import VisibleGameState, get_visible_state +from app.services.card_service import CardService, get_card_service +from app.services.game_state_manager import GameStateManager, game_state_manager + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Exceptions +# ============================================================================= + + +class GameServiceError(Exception): + """Base exception for GameService errors.""" + + pass + + +class GameNotFoundError(GameServiceError): + """Raised when a game cannot be found.""" + + def __init__(self, game_id: str) -> None: + self.game_id = game_id + super().__init__(f"Game not found: {game_id}") + + +class NotPlayerTurnError(GameServiceError): + """Raised when a player tries to act out of turn.""" + + def __init__(self, game_id: str, player_id: str, current_player_id: str) -> None: + self.game_id = game_id + self.player_id = player_id + self.current_player_id = current_player_id + super().__init__( + f"Not player's turn: {player_id} tried to act, but it's {current_player_id}'s turn" + ) + + +class InvalidActionError(GameServiceError): + """Raised when an action is invalid.""" + + def __init__(self, game_id: str, player_id: str, reason: str) -> None: + self.game_id = game_id + self.player_id = player_id + self.reason = reason + super().__init__(f"Invalid action in game {game_id}: {reason}") + + +class PlayerNotInGameError(GameServiceError): + """Raised when a player is not a participant in the game.""" + + def __init__(self, game_id: str, player_id: str) -> None: + self.game_id = game_id + self.player_id = player_id + super().__init__(f"Player {player_id} is not in game {game_id}") + + +class GameAlreadyEndedError(GameServiceError): + """Raised when trying to act on a game that has already ended.""" + + def __init__(self, game_id: str) -> None: + self.game_id = game_id + super().__init__(f"Game {game_id} has already ended") + + +# ============================================================================= +# Result Types +# ============================================================================= + + +@dataclass +class GameActionResult: + """Result of executing a game action. + + Attributes: + success: Whether the action succeeded. + game_id: The game ID. + action_type: The type of action executed. + message: Description of what happened. + state_changes: Dict of state changes for client updates. + game_over: Whether the game ended as a result. + winner_id: Winner's player ID if game ended with a winner. + end_reason: Reason the game ended, if applicable. + """ + + success: bool + game_id: str + action_type: str + message: str = "" + state_changes: dict[str, Any] = field(default_factory=dict) + game_over: bool = False + winner_id: str | None = None + end_reason: GameEndReason | None = None + + +@dataclass +class GameJoinResult: + """Result of joining a game. + + Attributes: + success: Whether the join succeeded. + game_id: The game ID. + player_id: The joining player's ID. + visible_state: The game state visible to the player. + is_your_turn: Whether it's this player's turn. + message: Additional information or error message. + """ + + success: bool + game_id: str + player_id: str + visible_state: VisibleGameState | None = None + is_your_turn: bool = False + message: str = "" + + +# ============================================================================= +# GameService +# ============================================================================= + + +class GameService: + """Service for orchestrating game lifecycle operations. + + This service coordinates between the WebSocket layer and the core + GameEngine, handling persistence and state management. + + IMPORTANT: This service is STATELESS regarding game rules. + - Rules are stored in each GameState (set at creation time) + - The GameEngine is instantiated per-operation with the game's rules + - No RulesConfig is stored in this service + + Attributes: + _state_manager: GameStateManager for persistence. + _card_service: CardService for card definitions. + """ + + def __init__( + self, + state_manager: GameStateManager | None = None, + card_service: CardService | None = None, + ) -> None: + """Initialize the GameService. + + Note: No GameEngine or RulesConfig here - those are per-game, + not per-service. The engine is created as needed using the + rules stored in each game's state. + + Args: + state_manager: GameStateManager instance. Uses global if not provided. + card_service: CardService instance. Uses global if not provided. + """ + self._state_manager = state_manager or game_state_manager + self._card_service = card_service or get_card_service() + + def _create_engine_for_game(self, game: GameState) -> GameEngine: + """Create a GameEngine configured for a specific game's rules. + + The engine is created on-demand using the rules stored in the + game state. This ensures each game uses its own configuration. + + For deterministic replay support, we derive a unique seed per action + by combining the game's base seed with the action count. This ensures: + - Same game + same action sequence = identical RNG results + - Each action gets a unique but reproducible random sequence + + Args: + game: The game state containing the rules to use. + + Returns: + A GameEngine configured with the game's rules and RNG. + """ + if game.rng_seed is not None: + # Derive unique seed per action for deterministic replay + # Action count ensures each action gets different but reproducible RNG + action_count = len(game.action_log) + action_seed = game.rng_seed + action_count + rng = create_rng(seed=action_seed) + else: + # No seed - use cryptographically secure RNG + rng = create_rng() + + return GameEngine(rules=game.rules, rng=rng) + + # ========================================================================= + # Game State Access + # ========================================================================= + + async def get_game_state(self, game_id: str) -> GameState: + """Get the full game state. + + Args: + game_id: The game ID. + + Returns: + The GameState. + + Raises: + GameNotFoundError: If game doesn't exist. + """ + state = await self._state_manager.load_state(game_id) + if state is None: + raise GameNotFoundError(game_id) + return state + + async def get_player_view( + self, + game_id: str, + player_id: str, + ) -> VisibleGameState: + """Get the game state filtered for a specific player's view. + + This applies visibility rules to hide opponent's hidden information + (hand, deck, prizes). + + Args: + game_id: The game ID. + player_id: The player to get the view for. + + Returns: + VisibleGameState with appropriate filtering. + + Raises: + GameNotFoundError: If game doesn't exist. + PlayerNotInGameError: If player is not in the game. + """ + state = await self.get_game_state(game_id) + + if player_id not in state.players: + raise PlayerNotInGameError(game_id, player_id) + + return get_visible_state(state, player_id) + + async def is_player_turn(self, game_id: str, player_id: str) -> bool: + """Check if it's the specified player's turn. + + Args: + game_id: The game ID. + player_id: The player to check. + + Returns: + True if it's the player's turn. + + Raises: + GameNotFoundError: If game doesn't exist. + """ + state = await self.get_game_state(game_id) + return state.current_player_id == player_id + + async def game_exists(self, game_id: str) -> bool: + """Check if a game exists. + + Args: + game_id: The game ID. + + Returns: + True if the game exists in cache or database. + """ + return await self._state_manager.cache_exists(game_id) + + # ========================================================================= + # Game Lifecycle + # ========================================================================= + + async def join_game( + self, + game_id: str, + player_id: str, + last_event_id: str | None = None, + ) -> GameJoinResult: + """Join or rejoin a game session. + + Loads the game state and returns the player's visible view. + Used when a player connects or reconnects to a game. + + Args: + game_id: The game to join. + player_id: The joining player's ID. + last_event_id: Last event ID for reconnection replay (future use). + + Returns: + GameJoinResult with the visible state. + """ + try: + state = await self.get_game_state(game_id) + except GameNotFoundError: + return GameJoinResult( + success=False, + game_id=game_id, + player_id=player_id, + message="Game not found", + ) + + if player_id not in state.players: + return GameJoinResult( + success=False, + game_id=game_id, + player_id=player_id, + message="You are not a participant in this game", + ) + + # Check if game already ended + if state.winner_id is not None or state.end_reason is not None: + visible = get_visible_state(state, player_id) + return GameJoinResult( + success=True, + game_id=game_id, + player_id=player_id, + visible_state=visible, + is_your_turn=False, + message="Game has ended", + ) + + visible = get_visible_state(state, player_id) + + logger.info(f"Player {player_id} joined game {game_id}") + + return GameJoinResult( + success=True, + game_id=game_id, + player_id=player_id, + visible_state=visible, + is_your_turn=state.current_player_id == player_id, + ) + + async def execute_action( + self, + game_id: str, + player_id: str, + action: Action, + ) -> GameActionResult: + """Execute a player action in the game. + + Validates the action, executes it through GameEngine, and + persists the updated state. The GameEngine is created using + the rules stored in the game state. + + Args: + game_id: The game ID. + player_id: The acting player's ID. + action: The action to execute. + + Returns: + GameActionResult with success status and state changes. + + Raises: + GameNotFoundError: If game doesn't exist. + PlayerNotInGameError: If player is not in the game. + GameAlreadyEndedError: If game has already ended. + NotPlayerTurnError: If it's not the player's turn. + InvalidActionError: If the action is invalid. + """ + # Load game state + state = await self.get_game_state(game_id) + + # Validate player is in game + if player_id not in state.players: + raise PlayerNotInGameError(game_id, player_id) + + # Check game hasn't ended + if state.winner_id is not None or state.end_reason is not None: + raise GameAlreadyEndedError(game_id) + + # Check it's player's turn (unless resignation, which can happen anytime) + if not isinstance(action, ResignAction) and state.current_player_id != player_id: + raise NotPlayerTurnError(game_id, player_id, state.current_player_id) + + # Create engine with this game's rules + engine = self._create_engine_for_game(state) + + # Execute the action + result: ActionResult = await engine.execute_action(state, player_id, action) + + if not result.success: + raise InvalidActionError(game_id, player_id, result.message) + + # Save state to cache (fast path) + await self._state_manager.save_to_cache(state) + + # Check if turn ended - persist to DB at turn boundaries + # TODO: Implement turn boundary detection for DB persistence + + # Build response + action_result = GameActionResult( + success=True, + game_id=game_id, + action_type=action.type, + message=result.message, + state_changes={ + "changes": result.state_changes, + }, + ) + + # Check for game over + if result.win_result is not None: + action_result.game_over = True + action_result.winner_id = result.win_result.winner_id + action_result.end_reason = result.win_result.end_reason + + # Persist final state to DB + await self._state_manager.persist_to_db(state) + + logger.info( + f"Game {game_id} ended: winner={result.win_result.winner_id}, " + f"reason={result.win_result.end_reason}" + ) + + logger.debug(f"Action executed: game={game_id}, player={player_id}, type={action.type}") + + return action_result + + async def resign_game( + self, + game_id: str, + player_id: str, + ) -> GameActionResult: + """Resign from a game. + + Convenience method that executes a ResignAction. + + Args: + game_id: The game ID. + player_id: The resigning player's ID. + + Returns: + GameActionResult indicating game over. + """ + return await self.execute_action( + game_id=game_id, + player_id=player_id, + action=ResignAction(), + ) + + async def end_game( + self, + game_id: str, + winner_id: str | None, + end_reason: GameEndReason, + ) -> None: + """Forcibly end a game (e.g., due to timeout or disconnection). + + This should be called by the timeout system or when a player + disconnects without reconnecting within the grace period. + + Args: + game_id: The game ID. + winner_id: The winner's player ID, or None for a draw. + end_reason: Why the game ended. + + Raises: + GameNotFoundError: If game doesn't exist. + """ + state = await self.get_game_state(game_id) + + # Set winner and end reason + state.winner_id = winner_id + state.end_reason = end_reason + + # Persist to both cache and DB + await self._state_manager.save_to_cache(state) + await self._state_manager.persist_to_db(state) + + logger.info(f"Game {game_id} forcibly ended: winner={winner_id}, reason={end_reason}") + + # ========================================================================= + # Game Creation (Skeleton - Full implementation in GS-002) + # ========================================================================= + + async def create_game( + self, + player1_id: str | UUID, + player2_id: str | UUID, + deck1_id: str | UUID | None = None, + deck2_id: str | UUID | None = None, + # Rules come from the frontend request - this is required, not optional + # Defaulting to None here only for the skeleton; GS-002 will make it required + ) -> None: + """Create a new game between two players. + + This is a skeleton that will be fully implemented in GS-002. + + IMPORTANT: rules_config will be a required parameter - it comes + from the frontend request, not from server-side defaults. + + Args: + player1_id: First player's ID. + player2_id: Second player's ID. + deck1_id: First player's deck ID. + deck2_id: Second player's deck ID. + + Raises: + NotImplementedError: Until GS-002 is complete. + """ + # TODO (GS-002): Full implementation with: + # - rules_config: RulesConfig parameter (required, from frontend) + # - Load decks via DeckService + # - Load card registry from CardService + # - Convert to CardInstances with unique IDs + # - Create GameState with the provided rules_config + # - Persist to Redis and Postgres + + raise NotImplementedError( + "Game creation not yet implemented - see GS-002. " + "Rules will come from frontend request, not server defaults." + ) + + +# Global singleton instance +# Note: This is safe because GameService is stateless regarding game rules. +# Each game's rules are stored in its GameState, not in this service. +game_service = GameService() diff --git a/backend/app/socketio/__init__.py b/backend/app/socketio/__init__.py new file mode 100644 index 0000000..1e5469f --- /dev/null +++ b/backend/app/socketio/__init__.py @@ -0,0 +1,46 @@ +"""Socket.IO module for real-time game communication. + +This module provides WebSocket-based real-time communication for: +- Active game sessions (actions, state updates, turn notifications) +- Connection management with session recovery +- Turn timeout handling +- JWT-based authentication + +Architecture: + - Uses python-socketio with ASGI mode for FastAPI integration + - /game namespace handles all active game communication + - /lobby namespace (Phase 6) will handle matchmaking + - JWT authentication on connect (WS-004) + +Usage: + The Socket.IO server is mounted alongside FastAPI in app/main.py. + Clients connect to ws://host/socket.io/ and join the /game namespace. + + Client connection example: + socket = io("ws://host", { + auth: { token: "JWT_ACCESS_TOKEN" } + }); +""" + +from app.socketio.auth import ( + AuthResult, + authenticate_connection, + cleanup_authenticated_session, + get_session_user_id, + require_auth, + setup_authenticated_session, +) +from app.socketio.server import create_socketio_app, sio + +__all__ = [ + # Server + "create_socketio_app", + "sio", + # Auth + "AuthResult", + "authenticate_connection", + "cleanup_authenticated_session", + "get_session_user_id", + "require_auth", + "setup_authenticated_session", +] diff --git a/backend/app/socketio/auth.py b/backend/app/socketio/auth.py new file mode 100644 index 0000000..125afec --- /dev/null +++ b/backend/app/socketio/auth.py @@ -0,0 +1,283 @@ +"""Socket.IO authentication middleware for WebSocket connections. + +This module provides JWT-based authentication for Socket.IO connections. +It validates access tokens and attaches user information to the socket session. + +Authentication Flow: + 1. Client connects with `auth: { token: "JWT_ACCESS_TOKEN" }` + 2. Server extracts and validates the JWT + 3. If valid, user_id is stored in socket session + 4. If invalid, connection is rejected with appropriate error + +Session Data: + After successful authentication, the socket session contains: + - user_id: str (UUID as string) + - authenticated_at: str (ISO timestamp) + +Example: + # In connect handler: + auth_result = await authenticate_connection(sid, auth) + if not auth_result.success: + return False # Reject connection + + # Later, get user_id: + session = await sio.get_session(sid, namespace="/game") + user_id = session.get("user_id") +""" + +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from uuid import UUID + +from app.services.connection_manager import connection_manager +from app.services.jwt_service import verify_access_token + +logger = logging.getLogger(__name__) + + +@dataclass +class AuthResult: + """Result of authentication attempt. + + Attributes: + success: Whether authentication succeeded. + user_id: User's UUID if successful, None otherwise. + error_code: Error code for client if failed. + error_message: Human-readable error message if failed. + """ + + success: bool + user_id: UUID | None = None + error_code: str | None = None + error_message: str | None = None + + +def extract_token(auth: dict[str, object] | None) -> str | None: + """Extract JWT token from Socket.IO auth data. + + Clients should send the token in the auth dict: + socket.connect({ auth: { token: "JWT_TOKEN" } }) + + Also supports: + - auth.authorization: "Bearer TOKEN" + - auth.access_token: "TOKEN" + + Args: + auth: Authentication data from Socket.IO connect. + + Returns: + JWT token string if found, None otherwise. + + Example: + token = extract_token({"token": "eyJ..."}) + token = extract_token({"authorization": "Bearer eyJ..."}) + """ + if auth is None: + return None + + # Primary: auth.token + token = auth.get("token") + if token and isinstance(token, str): + return token + + # Alternative: auth.authorization (Bearer token) + authorization = auth.get("authorization") + if authorization and isinstance(authorization, str): + if authorization.lower().startswith("bearer "): + return authorization[7:] + return authorization + + # Alternative: auth.access_token + access_token = auth.get("access_token") + if access_token and isinstance(access_token, str): + return access_token + + return None + + +async def authenticate_connection( + sid: str, + auth: dict[str, object] | None, +) -> AuthResult: + """Authenticate a Socket.IO connection using JWT. + + Extracts the JWT from auth data, validates it, and returns the result. + Does NOT modify socket session - caller should handle that. + + Args: + sid: Socket session ID (for logging). + auth: Authentication data from connect event. + + Returns: + AuthResult with success status and user_id or error details. + + Example: + result = await authenticate_connection(sid, auth) + if result.success: + await sio.save_session(sid, {"user_id": str(result.user_id)}) + else: + logger.warning(f"Auth failed: {result.error_message}") + return False # Reject connection + """ + # Extract token from auth data + token = extract_token(auth) + + if token is None: + logger.debug(f"Connection {sid}: No token provided") + return AuthResult( + success=False, + error_code="missing_token", + error_message="Authentication token required", + ) + + # Validate the token + user_id = verify_access_token(token) + + if user_id is None: + logger.debug(f"Connection {sid}: Invalid or expired token") + return AuthResult( + success=False, + error_code="invalid_token", + error_message="Invalid or expired token", + ) + + logger.debug(f"Connection {sid}: Authenticated as user {user_id}") + return AuthResult( + success=True, + user_id=user_id, + ) + + +async def setup_authenticated_session( + sio: object, + sid: str, + user_id: UUID, + namespace: str = "/game", +) -> None: + """Set up socket session with authenticated user data. + + Saves user_id and authentication timestamp to the socket session, + and registers the connection with ConnectionManager. + + Args: + sio: Socket.IO AsyncServer instance. + sid: Socket session ID. + user_id: Authenticated user's UUID. + namespace: Socket.IO namespace. + + Example: + if auth_result.success: + await setup_authenticated_session(sio, sid, auth_result.user_id) + """ + # Import here to avoid circular dependency + from app.socketio.server import sio as server_sio + + # Use provided sio or fall back to server sio + socket_server = sio if sio is not None else server_sio + + # Save to socket session + session_data = { + "user_id": str(user_id), + "authenticated_at": datetime.now(UTC).isoformat(), + } + await socket_server.save_session(sid, session_data, namespace=namespace) + + # Register with ConnectionManager + await connection_manager.register_connection(sid, user_id) + + logger.info(f"Session established: sid={sid}, user_id={user_id}") + + +async def cleanup_authenticated_session( + sid: str, + namespace: str = "/game", +) -> str | None: + """Clean up session data on disconnect. + + Unregisters the connection from ConnectionManager and returns + the user_id for any additional cleanup needed. + + Args: + sid: Socket session ID. + namespace: Socket.IO namespace. + + Returns: + user_id if session was authenticated, None otherwise. + + Example: + user_id = await cleanup_authenticated_session(sid) + if user_id: + # Notify opponent, etc. + """ + # Unregister from ConnectionManager + conn_info = await connection_manager.unregister_connection(sid) + + if conn_info: + logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}") + return conn_info.user_id + + logger.debug(f"No session to clean up for {sid}") + return None + + +async def get_session_user_id( + sio: object, + sid: str, + namespace: str = "/game", +) -> str | None: + """Get the authenticated user_id from a socket session. + + Convenience function to extract user_id from session data. + + Args: + sio: Socket.IO AsyncServer instance. + sid: Socket session ID. + namespace: Socket.IO namespace. + + Returns: + user_id string if authenticated, None otherwise. + + Example: + user_id = await get_session_user_id(sio, sid) + if not user_id: + await sio.emit("error", {"message": "Not authenticated"}, to=sid) + return + """ + try: + session = await sio.get_session(sid, namespace=namespace) + return session.get("user_id") if session else None + except Exception: + return None + + +async def require_auth( + sio: object, + sid: str, + namespace: str = "/game", +) -> str | None: + """Require authentication for an event handler. + + Returns the user_id if authenticated, None if not. + Logs a warning if authentication is missing. + + Args: + sio: Socket.IO AsyncServer instance. + sid: Socket session ID. + namespace: Socket.IO namespace. + + Returns: + user_id string if authenticated, None otherwise. + + Example: + @sio.on("game:action", namespace="/game") + async def on_action(sid, data): + user_id = await require_auth(sio, sid) + if not user_id: + return {"error": "Not authenticated"} + # ... handle action + """ + user_id = await get_session_user_id(sio, sid, namespace) + if user_id is None: + logger.warning(f"Unauthenticated event from {sid}") + return user_id diff --git a/backend/app/socketio/server.py b/backend/app/socketio/server.py new file mode 100644 index 0000000..3f4127c --- /dev/null +++ b/backend/app/socketio/server.py @@ -0,0 +1,240 @@ +"""Socket.IO server configuration and ASGI app creation. + +This module sets up the python-socketio AsyncServer and creates the combined +ASGI application that mounts Socket.IO alongside FastAPI. + +Architecture: + - AsyncServer handles WebSocket connections with async_mode='asgi' + - Socket.IO app is mounted at /socket.io path + - CORS settings match FastAPI configuration + - JWT authentication on connect via auth.py + - Namespaces are registered for different communication domains + +Namespaces: + /game - Active game communication (actions, state updates) + /lobby - Pre-game lobby (matchmaking, invites) - Phase 6 + +Authentication: + Clients must provide a JWT access token in the auth parameter: + socket.connect({ auth: { token: "JWT_ACCESS_TOKEN" } }) + +Example: + from app.socketio import create_socketio_app + from fastapi import FastAPI + + fastapi_app = FastAPI() + combined_app = create_socketio_app(fastapi_app) + # Run with: uvicorn app.main:app +""" + +import logging +from typing import TYPE_CHECKING + +import socketio + +from app.config import settings +from app.socketio.auth import ( + authenticate_connection, + cleanup_authenticated_session, + require_auth, + setup_authenticated_session, +) + +if TYPE_CHECKING: + from fastapi import FastAPI + +logger = logging.getLogger(__name__) + +# Create the AsyncServer instance +# - async_mode='asgi' for ASGI compatibility with uvicorn +# - cors_allowed_origins matches FastAPI CORS settings +# - logger enables Socket.IO internal logging in debug mode +sio = socketio.AsyncServer( + async_mode="asgi", + cors_allowed_origins=settings.cors_origins if settings.cors_origins else [], + logger=settings.debug, + engineio_logger=settings.debug, +) + + +# ============================================================================= +# /game Namespace - Active Game Communication +# ============================================================================= +# These are skeleton handlers that will be fully implemented in WS-005. +# For now, they provide basic connection lifecycle handling. + + +@sio.event(namespace="/game") +async def connect( + sid: str, environ: dict[str, object], auth: dict[str, object] | None = None +) -> bool | None: + """Handle client connection to /game namespace. + + Authenticates the connection using JWT from auth data. + Rejects connections without valid authentication. + + Args: + sid: Socket session ID assigned by Socket.IO. + environ: WSGI/ASGI environ dict with request info. + auth: Authentication data sent by client (JWT token). + Expected format: { token: "JWT_ACCESS_TOKEN" } + + Returns: + True to accept connection, False to reject. + None is treated as True (accept). + """ + # Authenticate the connection + auth_result = await authenticate_connection(sid, auth) + + if not auth_result.success: + logger.warning( + f"Connection rejected for {sid}: {auth_result.error_code} - {auth_result.error_message}" + ) + # Emit error before rejecting (client may receive this) + await sio.emit( + "auth_error", + { + "code": auth_result.error_code, + "message": auth_result.error_message, + }, + to=sid, + namespace="/game", + ) + return False + + # Set up authenticated session and register connection + await setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game") + + logger.info(f"Client authenticated to /game: sid={sid}, user_id={auth_result.user_id}") + return True + + +@sio.event(namespace="/game") +async def disconnect(sid: str) -> None: + """Handle client disconnection from /game namespace. + + Cleans up connection state and notifies other game participants. + + Args: + sid: Socket session ID of disconnecting client. + """ + # Clean up session and get user info + user_id = await cleanup_authenticated_session(sid, namespace="/game") + + if user_id: + logger.info(f"Client disconnected from /game: sid={sid}, user_id={user_id}") + # TODO (WS-005): Notify opponent of disconnection if in game + else: + logger.debug(f"Unauthenticated client disconnected: {sid}") + + +@sio.on("game:join", namespace="/game") +async def on_game_join(sid: str, data: dict[str, object]) -> dict[str, object]: + """Handle request to join/rejoin a game session. + + Args: + sid: Socket session ID. + data: Message containing game_id and optional last_event_id for resume. + + Returns: + Response with game state or error. + """ + logger.debug(f"game:join from {sid}: {data}") + # TODO (WS-005): Implement with GameService + return {"error": "Not implemented yet"} + + +@sio.on("game:action", namespace="/game") +async def on_game_action(sid: str, data: dict[str, object]) -> dict[str, object]: + """Handle game action from player. + + Args: + sid: Socket session ID. + data: Action message with type and parameters. + + Returns: + Action result or error. + """ + logger.debug(f"game:action from {sid}: {data}") + # TODO (WS-005): Implement with GameService + return {"error": "Not implemented yet"} + + +@sio.on("game:resign", namespace="/game") +async def on_game_resign(sid: str, data: dict[str, object]) -> dict[str, object]: + """Handle player resignation. + + Args: + sid: Socket session ID. + data: Resignation message (may be empty). + + Returns: + Confirmation or error. + """ + logger.debug(f"game:resign from {sid}: {data}") + # TODO (WS-005): Implement with GameService + return {"error": "Not implemented yet"} + + +@sio.on("game:heartbeat", namespace="/game") +async def on_game_heartbeat(sid: str, data: dict[str, object] | None = None) -> dict[str, object]: + """Handle heartbeat to keep connection alive. + + Updates last_seen timestamp in ConnectionManager to prevent + the connection from being marked as stale. + + Args: + sid: Socket session ID. + data: Optional heartbeat data (message_id for tracking). + + Returns: + Heartbeat acknowledgment with server timestamp. + """ + from datetime import UTC, datetime + + from app.services.connection_manager import connection_manager + + # Require authentication + user_id = await require_auth(sio, sid) + if not user_id: + return {"error": "Not authenticated", "code": "unauthenticated"} + + # Update last_seen in ConnectionManager + await connection_manager.update_heartbeat(sid) + + # Return acknowledgment with timestamp + return { + "type": "heartbeat_ack", + "timestamp": datetime.now(UTC).isoformat(), + "message_id": data.get("message_id") if data else None, + } + + +# ============================================================================= +# ASGI App Creation +# ============================================================================= + + +def create_socketio_app(fastapi_app: "FastAPI") -> socketio.ASGIApp: + """Create combined ASGI app with Socket.IO mounted alongside FastAPI. + + This wraps the FastAPI app with Socket.IO, handling: + - WebSocket connections at /socket.io/ + - HTTP requests passed through to FastAPI + + Args: + fastapi_app: The FastAPI application instance. + + Returns: + Combined ASGI application to use with uvicorn. + + Example: + app = FastAPI() + combined = create_socketio_app(app) + # uvicorn will use 'combined' as the ASGI app + """ + return socketio.ASGIApp( + sio, + other_asgi_app=fastapi_app, + socketio_path="socket.io", + ) diff --git a/backend/project_plans/PHASE_4_GAME_SERVICE.json b/backend/project_plans/PHASE_4_GAME_SERVICE.json index bda096c..12cc4cf 100644 --- a/backend/project_plans/PHASE_4_GAME_SERVICE.json +++ b/backend/project_plans/PHASE_4_GAME_SERVICE.json @@ -9,8 +9,8 @@ "description": "Real-time gameplay infrastructure - WebSocket communication, game lifecycle management, reconnection handling, and turn timeout system", "totalEstimatedHours": 45, "totalTasks": 18, - "completedTasks": 0, - "status": "not_started", + "completedTasks": 5, + "status": "in_progress", "masterPlan": "../PROJECT_PLAN_MASTER.json" }, @@ -105,8 +105,8 @@ "description": "Install and configure python-socketio ASGI server, mount alongside FastAPI app", "category": "infrastructure", "priority": 1, - "completed": false, - "tested": false, + "completed": true, + "tested": true, "dependencies": [], "files": [ {"path": "app/socketio/__init__.py", "status": "create"}, @@ -130,8 +130,8 @@ "description": "Define Pydantic models for all WebSocket message types", "category": "schemas", "priority": 2, - "completed": false, - "tested": false, + "completed": true, + "tested": true, "dependencies": ["WS-001"], "files": [ {"path": "app/schemas/ws_messages.py", "status": "create"} @@ -153,8 +153,8 @@ "description": "Manage WebSocket connections with Redis-backed session tracking", "category": "services", "priority": 3, - "completed": false, - "tested": false, + "completed": true, + "tested": true, "dependencies": ["WS-001"], "files": [ {"path": "app/services/connection_manager.py", "status": "create"} @@ -177,8 +177,8 @@ "description": "Authenticate WebSocket connections using JWT tokens", "category": "auth", "priority": 4, - "completed": false, - "tested": false, + "completed": true, + "tested": true, "dependencies": ["WS-001", "WS-003"], "files": [ {"path": "app/socketio/auth.py", "status": "create"}, @@ -614,6 +614,13 @@ "risk": "Turn timeout drift due to server restart", "mitigation": "Store absolute deadline in Redis/Postgres, recalculate on startup", "priority": "medium" + }, + { + "risk": "ConnectionManager race condition on rapid reconnects", + "mitigation": "Consider Redis Lua script for atomic old-connection cleanup in register_connection. Low probability in practice but could cause connection tracking issues during rapid connect/disconnect cycles.", + "priority": "low", + "status": "identified-in-review", + "notes": "Also consider: periodic cleanup of game_conns sets for long-running games, rate limiting in auth layer" } ], diff --git a/backend/pyproject.toml b/backend/pyproject.toml index bebe842..d420a10 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -107,6 +107,11 @@ module = [ ] ignore_missing_imports = true +[[tool.mypy.overrides]] +# Socket.IO handlers use untyped decorators from the library +module = "app.socketio.*" +disallow_untyped_decorators = false + # Coverage configuration [tool.coverage.run] source = ["app"] diff --git a/backend/tests/socketio/__init__.py b/backend/tests/socketio/__init__.py new file mode 100644 index 0000000..e8a0422 --- /dev/null +++ b/backend/tests/socketio/__init__.py @@ -0,0 +1 @@ +# Socket.IO integration tests diff --git a/backend/tests/socketio/test_auth.py b/backend/tests/socketio/test_auth.py new file mode 100644 index 0000000..073c1c0 --- /dev/null +++ b/backend/tests/socketio/test_auth.py @@ -0,0 +1,384 @@ +"""Tests for Socket.IO authentication middleware. + +This module tests JWT-based authentication for WebSocket connections, +including token extraction, validation, and session management. +""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from app.socketio.auth import ( + AuthResult, + authenticate_connection, + cleanup_authenticated_session, + extract_token, + get_session_user_id, + require_auth, + setup_authenticated_session, +) + + +class TestExtractToken: + """Tests for token extraction from Socket.IO auth data.""" + + def test_extract_token_from_token_field(self) -> None: + """Test extracting token from the primary 'token' field. + + The standard way for clients to pass the JWT is via auth.token. + This is the recommended format for Socket.IO clients. + """ + auth = {"token": "my-jwt-token"} + result = extract_token(auth) + assert result == "my-jwt-token" + + def test_extract_token_from_authorization_bearer(self) -> None: + """Test extracting token from Bearer authorization header format. + + Some clients may pass the token as a Bearer token for consistency + with HTTP API authentication patterns. + """ + auth = {"authorization": "Bearer my-jwt-token"} + result = extract_token(auth) + assert result == "my-jwt-token" + + def test_extract_token_from_authorization_without_bearer(self) -> None: + """Test extracting token from authorization without Bearer prefix. + + If the client provides just the token in authorization field, + we should still accept it. + """ + auth = {"authorization": "my-jwt-token"} + result = extract_token(auth) + assert result == "my-jwt-token" + + def test_extract_token_from_access_token_field(self) -> None: + """Test extracting token from access_token field. + + Alternative field name for OAuth-style clients. + """ + auth = {"access_token": "my-jwt-token"} + result = extract_token(auth) + assert result == "my-jwt-token" + + def test_extract_token_returns_none_for_none_auth(self) -> None: + """Test that None auth data returns None token. + + Clients that don't provide any auth should get None, + triggering the authentication failure path. + """ + result = extract_token(None) + assert result is None + + def test_extract_token_returns_none_for_empty_auth(self) -> None: + """Test that empty auth dict returns None token. + + An empty auth object should be treated as unauthenticated. + """ + result = extract_token({}) + assert result is None + + def test_extract_token_returns_none_for_non_string_token(self) -> None: + """Test that non-string token values are rejected. + + Only string tokens are valid - reject numbers, objects, etc. + """ + result = extract_token({"token": 12345}) + assert result is None + + result = extract_token({"token": {"nested": "value"}}) + assert result is None + + def test_extract_token_prefers_token_field(self) -> None: + """Test that 'token' field takes precedence over alternatives. + + If multiple token fields are present, we should use the + primary 'token' field. + """ + auth = { + "token": "primary-token", + "authorization": "Bearer secondary-token", + "access_token": "tertiary-token", + } + result = extract_token(auth) + assert result == "primary-token" + + +class TestAuthenticateConnection: + """Tests for connection authentication.""" + + @pytest.mark.asyncio + async def test_authenticate_success_with_valid_token(self) -> None: + """Test successful authentication with a valid JWT. + + A valid access token should result in AuthResult with + success=True and the user_id from the token. + """ + user_id = uuid4() + + with patch("app.socketio.auth.verify_access_token") as mock_verify: + mock_verify.return_value = user_id + + result = await authenticate_connection("test-sid", {"token": "valid-token"}) + + assert result.success is True + assert result.user_id == user_id + assert result.error_code is None + mock_verify.assert_called_once_with("valid-token") + + @pytest.mark.asyncio + async def test_authenticate_fails_with_missing_token(self) -> None: + """Test authentication failure when no token is provided. + + Connections without any auth data should fail with + a 'missing_token' error code. + """ + result = await authenticate_connection("test-sid", None) + + assert result.success is False + assert result.user_id is None + assert result.error_code == "missing_token" + assert "required" in result.error_message.lower() + + @pytest.mark.asyncio + async def test_authenticate_fails_with_empty_auth(self) -> None: + """Test authentication failure with empty auth object. + + An auth object without any token fields should fail. + """ + result = await authenticate_connection("test-sid", {}) + + assert result.success is False + assert result.error_code == "missing_token" + + @pytest.mark.asyncio + async def test_authenticate_fails_with_invalid_token(self) -> None: + """Test authentication failure with invalid/expired JWT. + + When verify_access_token returns None (invalid token), + we should fail with 'invalid_token' error. + """ + with patch("app.socketio.auth.verify_access_token") as mock_verify: + mock_verify.return_value = None # Token validation failed + + result = await authenticate_connection("test-sid", {"token": "invalid-token"}) + + assert result.success is False + assert result.user_id is None + assert result.error_code == "invalid_token" + assert ( + "invalid" in result.error_message.lower() + or "expired" in result.error_message.lower() + ) + + @pytest.mark.asyncio + async def test_authenticate_extracts_token_from_bearer(self) -> None: + """Test that authentication works with Bearer format. + + The authenticate function should handle Bearer token format + through the extract_token function. + """ + user_id = uuid4() + + with patch("app.socketio.auth.verify_access_token") as mock_verify: + mock_verify.return_value = user_id + + result = await authenticate_connection("test-sid", {"authorization": "Bearer my-token"}) + + assert result.success is True + mock_verify.assert_called_once_with("my-token") + + +class TestSetupAuthenticatedSession: + """Tests for session setup after authentication.""" + + @pytest.mark.asyncio + async def test_setup_saves_session_data(self) -> None: + """Test that session setup saves user_id and timestamp. + + After authentication, the socket session should contain + the user_id and authentication timestamp. + """ + user_id = uuid4() + mock_sio = AsyncMock() + + with patch("app.socketio.auth.connection_manager") as mock_cm: + mock_cm.register_connection = AsyncMock() + + await setup_authenticated_session(mock_sio, "test-sid", user_id) + + # Verify session was saved + mock_sio.save_session.assert_called_once() + call_args = mock_sio.save_session.call_args + assert call_args.args[0] == "test-sid" + + session_data = call_args.args[1] + assert session_data["user_id"] == str(user_id) + assert "authenticated_at" in session_data + + @pytest.mark.asyncio + async def test_setup_registers_with_connection_manager(self) -> None: + """Test that session setup registers with ConnectionManager. + + The connection should be tracked in ConnectionManager for + presence detection and game association. + """ + user_id = uuid4() + mock_sio = AsyncMock() + + with patch("app.socketio.auth.connection_manager") as mock_cm: + mock_cm.register_connection = AsyncMock() + + await setup_authenticated_session(mock_sio, "test-sid", user_id) + + mock_cm.register_connection.assert_called_once_with("test-sid", user_id) + + +class TestCleanupAuthenticatedSession: + """Tests for session cleanup on disconnect.""" + + @pytest.mark.asyncio + async def test_cleanup_unregisters_connection(self) -> None: + """Test that cleanup unregisters from ConnectionManager. + + On disconnect, the connection should be removed from + ConnectionManager to update presence tracking. + """ + with patch("app.socketio.auth.connection_manager") as mock_cm: + mock_conn_info = MagicMock() + mock_conn_info.user_id = "user-123" + mock_cm.unregister_connection = AsyncMock(return_value=mock_conn_info) + + result = await cleanup_authenticated_session("test-sid") + + assert result == "user-123" + mock_cm.unregister_connection.assert_called_once_with("test-sid") + + @pytest.mark.asyncio + async def test_cleanup_returns_none_for_unknown_session(self) -> None: + """Test cleanup returns None for non-existent sessions. + + If the connection wasn't registered (e.g., auth failed), + cleanup should return None gracefully. + """ + with patch("app.socketio.auth.connection_manager") as mock_cm: + mock_cm.unregister_connection = AsyncMock(return_value=None) + + result = await cleanup_authenticated_session("unknown-sid") + + assert result is None + + +class TestGetSessionUserId: + """Tests for session user_id retrieval.""" + + @pytest.mark.asyncio + async def test_get_session_user_id_returns_id(self) -> None: + """Test retrieving user_id from authenticated session. + + For authenticated sessions, get_session_user_id should + return the stored user_id string. + """ + mock_sio = AsyncMock() + mock_sio.get_session = AsyncMock( + return_value={"user_id": "user-123", "authenticated_at": "2024-01-01"} + ) + + result = await get_session_user_id(mock_sio, "test-sid") + + assert result == "user-123" + mock_sio.get_session.assert_called_once_with("test-sid", namespace="/game") + + @pytest.mark.asyncio + async def test_get_session_user_id_returns_none_for_missing(self) -> None: + """Test that missing session returns None. + + If no session exists for the sid, we should return None + rather than raising an error. + """ + mock_sio = AsyncMock() + mock_sio.get_session = AsyncMock(return_value=None) + + result = await get_session_user_id(mock_sio, "test-sid") + + assert result is None + + @pytest.mark.asyncio + async def test_get_session_user_id_handles_exception(self) -> None: + """Test that exceptions are caught and return None. + + If get_session raises an exception, we should catch it + and return None to avoid breaking the event handler. + """ + mock_sio = AsyncMock() + mock_sio.get_session = AsyncMock(side_effect=Exception("Session error")) + + result = await get_session_user_id(mock_sio, "test-sid") + + assert result is None + + +class TestRequireAuth: + """Tests for the require_auth helper.""" + + @pytest.mark.asyncio + async def test_require_auth_returns_user_id_for_authenticated(self) -> None: + """Test that require_auth returns user_id for valid sessions. + + Authenticated sessions should return the user_id for use + in event handlers. + """ + mock_sio = AsyncMock() + mock_sio.get_session = AsyncMock(return_value={"user_id": "user-123"}) + + result = await require_auth(mock_sio, "test-sid") + + assert result == "user-123" + + @pytest.mark.asyncio + async def test_require_auth_returns_none_for_unauthenticated(self) -> None: + """Test that require_auth returns None for unauthenticated sessions. + + Unauthenticated events should get None, allowing handlers + to return an error response. + """ + mock_sio = AsyncMock() + mock_sio.get_session = AsyncMock(return_value=None) + + result = await require_auth(mock_sio, "test-sid") + + assert result is None + + +class TestAuthResultDataclass: + """Tests for the AuthResult dataclass.""" + + def test_auth_result_success(self) -> None: + """Test creating a successful AuthResult. + + Success results should have user_id and no error fields. + """ + user_id = uuid4() + result = AuthResult(success=True, user_id=user_id) + + assert result.success is True + assert result.user_id == user_id + assert result.error_code is None + assert result.error_message is None + + def test_auth_result_failure(self) -> None: + """Test creating a failed AuthResult. + + Failure results should have error code/message and no user_id. + """ + result = AuthResult( + success=False, + error_code="invalid_token", + error_message="Token expired", + ) + + assert result.success is False + assert result.user_id is None + assert result.error_code == "invalid_token" + assert result.error_message == "Token expired" diff --git a/backend/tests/socketio/test_server_setup.py b/backend/tests/socketio/test_server_setup.py new file mode 100644 index 0000000..310ad57 --- /dev/null +++ b/backend/tests/socketio/test_server_setup.py @@ -0,0 +1,113 @@ +"""Tests for Socket.IO server setup and configuration. + +These tests verify that the Socket.IO server is correctly configured +and can be mounted alongside FastAPI. +""" + + +class TestSocketIOSetup: + """Tests for Socket.IO server initialization.""" + + def test_sio_server_exists(self) -> None: + """ + Verify that the Socket.IO AsyncServer is created. + + The server instance should be available as a module-level export + and configured for ASGI mode. + """ + from app.socketio import sio + + assert sio is not None + assert sio.async_mode == "asgi" + + def test_game_namespace_handlers_registered(self) -> None: + """ + Verify that /game namespace event handlers are registered. + + All required event handlers should be available on the sio instance + before any connections are made. + """ + from app.socketio import sio + + # Check that handlers are registered for /game namespace + assert "/game" in sio.handlers + + game_handlers = sio.handlers["/game"] + expected_events = [ + "connect", + "disconnect", + "game:join", + "game:action", + "game:resign", + "game:heartbeat", + ] + + for event in expected_events: + assert event in game_handlers, f"Missing handler for {event}" + + def test_create_socketio_app_returns_asgi_app(self) -> None: + """ + Verify that create_socketio_app returns a valid ASGI application. + + The combined app should wrap the FastAPI app and handle both + HTTP and WebSocket connections. + """ + import socketio + from fastapi import FastAPI + + from app.socketio import create_socketio_app + + # Create a minimal FastAPI app for testing + test_app = FastAPI() + + # Create combined ASGI app + combined = create_socketio_app(test_app) + + # Verify it's a Socket.IO ASGI app + assert isinstance(combined, socketio.ASGIApp) + + def test_cors_configured_from_settings(self) -> None: + """ + Verify that Socket.IO CORS settings match application settings. + + The Socket.IO server should allow the same origins as FastAPI + to ensure consistent cross-origin behavior. + """ + from app.config import settings + from app.socketio import sio + + # Socket.IO stores CORS origins in eio (Engine.IO) settings + # The cors_allowed_origins should match settings + assert sio.eio.cors_allowed_origins == settings.cors_origins + + +class TestMainAppIntegration: + """Tests for Socket.IO integration with main app.""" + + def test_main_app_is_combined_asgi_app(self) -> None: + """ + Verify that the main app module exports the combined ASGI app. + + The 'app' variable in main.py should be the Socket.IO wrapped + FastAPI application, not the raw FastAPI app. + """ + import socketio + + from app.main import app + + # The exported 'app' should be a Socket.IO ASGI app + assert isinstance(app, socketio.ASGIApp) + + def test_fastapi_app_accessible(self) -> None: + """ + Verify that the underlying FastAPI app is still accessible. + + The FastAPI app should be available as _fastapi_app in main + for testing and direct access when needed. + """ + from fastapi import FastAPI + + from app.main import _fastapi_app + + assert isinstance(_fastapi_app, FastAPI) + assert _fastapi_app.title == "Mantimon TCG" diff --git a/backend/tests/unit/schemas/__init__.py b/backend/tests/unit/schemas/__init__.py new file mode 100644 index 0000000..5a3d0ba --- /dev/null +++ b/backend/tests/unit/schemas/__init__.py @@ -0,0 +1 @@ +# Unit tests for schemas diff --git a/backend/tests/unit/schemas/test_ws_messages.py b/backend/tests/unit/schemas/test_ws_messages.py new file mode 100644 index 0000000..a08beff --- /dev/null +++ b/backend/tests/unit/schemas/test_ws_messages.py @@ -0,0 +1,701 @@ +"""Tests for WebSocket message schemas. + +This module tests the Pydantic models for WebSocket communication, verifying +discriminated union parsing, field validation, and serialization behavior. +""" + +from datetime import UTC, datetime + +import pytest +from pydantic import TypeAdapter, ValidationError + +from app.core.enums import GameEndReason, TurnPhase +from app.core.models.actions import AttackAction, PassAction +from app.core.visibility import VisibleGameState, VisiblePlayerState +from app.schemas.ws_messages import ( + ActionMessage, + ActionResultMessage, + ClientMessage, + ConnectionStatus, + ErrorMessage, + GameOverMessage, + GameStateMessage, + HeartbeatAckMessage, + HeartbeatMessage, + JoinGameMessage, + OpponentStatusMessage, + ResignMessage, + ServerMessage, + TurnStartMessage, + TurnTimeoutMessage, + WSErrorCode, + parse_client_message, + parse_server_message, +) + + +class TestClientMessageDiscriminator: + """Tests for client message discriminated union parsing. + + The discriminated union pattern allows automatic type resolution based on + the 'type' field, enabling type-safe message handling without explicit + type checking logic. + """ + + def test_parse_join_game_message(self) -> None: + """Test parsing JoinGameMessage from dictionary. + + JoinGameMessage is the first message clients send to establish their + presence in a game session. The discriminator should correctly resolve + the type from the raw dictionary. + """ + data = { + "type": "join_game", + "message_id": "test-123", + "game_id": "game-456", + } + message = parse_client_message(data) + + assert isinstance(message, JoinGameMessage) + assert message.game_id == "game-456" + assert message.message_id == "test-123" + assert message.last_event_id is None + + def test_parse_join_game_with_last_event_id(self) -> None: + """Test JoinGameMessage with reconnection event ID. + + When reconnecting after a disconnect, clients provide last_event_id + to receive any missed events since that point, enabling seamless + session recovery. + """ + data = { + "type": "join_game", + "message_id": "test-123", + "game_id": "game-456", + "last_event_id": "event-789", + } + message = parse_client_message(data) + + assert isinstance(message, JoinGameMessage) + assert message.last_event_id == "event-789" + + def test_parse_action_message(self) -> None: + """Test parsing ActionMessage with embedded game action. + + ActionMessage wraps the Action union type, so the discriminator must + handle both the message type and the nested action type correctly. + """ + data = { + "type": "action", + "message_id": "test-123", + "game_id": "game-456", + "action": {"type": "attack", "attack_index": 0}, + } + message = parse_client_message(data) + + assert isinstance(message, ActionMessage) + assert message.game_id == "game-456" + assert isinstance(message.action, AttackAction) + assert message.action.attack_index == 0 + + def test_parse_action_message_with_pass(self) -> None: + """Test parsing ActionMessage with PassAction. + + PassAction is the simplest action type - verifies that minimal + action payloads are handled correctly. + """ + data = { + "type": "action", + "message_id": "test-123", + "game_id": "game-456", + "action": {"type": "pass"}, + } + message = parse_client_message(data) + + assert isinstance(message, ActionMessage) + assert isinstance(message.action, PassAction) + + def test_parse_resign_message(self) -> None: + """Test parsing ResignMessage. + + ResignMessage allows resignation at the WebSocket layer, separate from + the game engine's ResignAction, for cases where engine interaction + may not be possible. + """ + data = { + "type": "resign", + "message_id": "test-123", + "game_id": "game-456", + } + message = parse_client_message(data) + + assert isinstance(message, ResignMessage) + assert message.game_id == "game-456" + + def test_parse_heartbeat_message(self) -> None: + """Test parsing HeartbeatMessage. + + Heartbeat messages keep the WebSocket connection alive and should + be the most minimal message type with just the type discriminator. + """ + data = { + "type": "heartbeat", + "message_id": "test-123", + } + message = parse_client_message(data) + + assert isinstance(message, HeartbeatMessage) + + def test_unknown_message_type_raises_error(self) -> None: + """Test that unknown message types raise ValidationError. + + The discriminated union should fail fast on unknown types rather + than silently accepting invalid messages. + """ + data = { + "type": "unknown_type", + "message_id": "test-123", + } + + with pytest.raises(ValidationError): + parse_client_message(data) + + def test_message_id_auto_generated(self) -> None: + """Test that message_id is auto-generated if not provided. + + Clients may omit message_id in which case a UUID should be generated, + ensuring all messages have a unique identifier for tracking. + """ + data = { + "type": "heartbeat", + } + message = parse_client_message(data) + + assert message.message_id is not None + assert len(message.message_id) > 0 + + def test_empty_message_id_rejected(self) -> None: + """Test that empty message_id is rejected. + + An explicitly empty message_id would break idempotency tracking, + so it should be rejected by validation. + """ + data = { + "type": "heartbeat", + "message_id": "", + } + + with pytest.raises(ValidationError): + parse_client_message(data) + + def test_type_adapter_client_message(self) -> None: + """Test ClientMessage union with TypeAdapter. + + TypeAdapter is the recommended way to work with discriminated unions + in Pydantic v2 - verifies it works for both parsing and validation. + """ + adapter = TypeAdapter(ClientMessage) + + data = {"type": "heartbeat", "message_id": "test"} + message = adapter.validate_python(data) + + assert isinstance(message, HeartbeatMessage) + + +class TestServerMessageDiscriminator: + """Tests for server message discriminated union parsing. + + Server messages are more complex than client messages, often containing + nested game state. The discriminator must handle all message types and + their embedded data correctly. + """ + + def test_parse_game_state_message(self) -> None: + """Test parsing GameStateMessage with full visible state. + + GameStateMessage carries the complete game state from a player's + perspective, used for initial sync and full state updates. + """ + visible_state = VisibleGameState( + game_id="game-456", + viewer_id="player-1", + players={ + "player-1": VisiblePlayerState(player_id="player-1"), + "player-2": VisiblePlayerState(player_id="player-2"), + }, + current_player_id="player-1", + turn_number=1, + phase=TurnPhase.MAIN, + is_my_turn=True, + ) + + message = GameStateMessage( + game_id="game-456", + state=visible_state, + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, GameStateMessage) + assert parsed.game_id == "game-456" + assert parsed.state.viewer_id == "player-1" + assert parsed.event_id is not None + + def test_parse_action_result_success(self) -> None: + """Test parsing successful ActionResultMessage. + + Successful action results should include the changes that occurred + as a result of the action for client-side state updates. + """ + message = ActionResultMessage( + game_id="game-456", + request_message_id="action-123", + success=True, + action_type="attack", + changes={"damage_dealt": 30, "defender_hp": 70}, + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, ActionResultMessage) + assert parsed.success is True + assert parsed.changes["damage_dealt"] == 30 + assert parsed.error_code is None + + def test_parse_action_result_failure(self) -> None: + """Test parsing failed ActionResultMessage. + + Failed action results should include error code and message + for client-side error handling and user feedback. + """ + message = ActionResultMessage( + game_id="game-456", + request_message_id="action-123", + success=False, + action_type="attack", + error_code=WSErrorCode.NOT_YOUR_TURN, + error_message="It's not your turn", + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, ActionResultMessage) + assert parsed.success is False + assert parsed.error_code == WSErrorCode.NOT_YOUR_TURN + + def test_parse_error_message(self) -> None: + """Test parsing ErrorMessage. + + Error messages are for protocol-level errors not tied to a specific + action, like authentication failures or invalid message format. + """ + message = ErrorMessage( + code=WSErrorCode.AUTHENTICATION_FAILED, + message="Invalid token", + details={"reason": "expired"}, + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, ErrorMessage) + assert parsed.code == WSErrorCode.AUTHENTICATION_FAILED + assert parsed.details["reason"] == "expired" + + def test_parse_turn_start_message(self) -> None: + """Test parsing TurnStartMessage. + + TurnStartMessage notifies both players when a turn begins, + enabling UI updates like starting a turn timer. + """ + message = TurnStartMessage( + game_id="game-456", + player_id="player-1", + turn_number=5, + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, TurnStartMessage) + assert parsed.turn_number == 5 + assert parsed.player_id == "player-1" + + def test_parse_turn_timeout_warning(self) -> None: + """Test parsing TurnTimeoutMessage as warning. + + Timeout warnings give players a chance to complete their action + before the turn is forcibly ended. + """ + message = TurnTimeoutMessage( + game_id="game-456", + remaining_seconds=30, + is_warning=True, + player_id="player-1", + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, TurnTimeoutMessage) + assert parsed.is_warning is True + assert parsed.remaining_seconds == 30 + + def test_parse_turn_timeout_expired(self) -> None: + """Test parsing TurnTimeoutMessage as expired. + + When is_warning is False, the timeout has occurred and the game + engine will force the turn to end. + """ + message = TurnTimeoutMessage( + game_id="game-456", + remaining_seconds=0, + is_warning=False, + player_id="player-1", + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, TurnTimeoutMessage) + assert parsed.is_warning is False + assert parsed.remaining_seconds == 0 + + def test_parse_game_over_with_winner(self) -> None: + """Test parsing GameOverMessage with a winner. + + Most games end with a winner - the final state is included so + clients can display the final board position. + """ + final_state = VisibleGameState( + game_id="game-456", + viewer_id="player-1", + winner_id="player-1", + end_reason=GameEndReason.PRIZES_TAKEN, + ) + + message = GameOverMessage( + game_id="game-456", + winner_id="player-1", + end_reason=GameEndReason.PRIZES_TAKEN, + final_state=final_state, + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, GameOverMessage) + assert parsed.winner_id == "player-1" + assert parsed.end_reason == GameEndReason.PRIZES_TAKEN + + def test_parse_game_over_draw(self) -> None: + """Test parsing GameOverMessage as a draw. + + Games can end in a draw (e.g., timeout with equal scores), + in which case winner_id is None. + """ + final_state = VisibleGameState( + game_id="game-456", + viewer_id="player-1", + end_reason=GameEndReason.DRAW, + ) + + message = GameOverMessage( + game_id="game-456", + winner_id=None, + end_reason=GameEndReason.DRAW, + final_state=final_state, + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, GameOverMessage) + assert parsed.winner_id is None + assert parsed.end_reason == GameEndReason.DRAW + + def test_parse_opponent_status_connected(self) -> None: + """Test parsing OpponentStatusMessage for connection. + + Notifies players when their opponent connects, enabling + connection status display in the UI. + """ + message = OpponentStatusMessage( + game_id="game-456", + opponent_id="player-2", + status=ConnectionStatus.CONNECTED, + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, OpponentStatusMessage) + assert parsed.status == ConnectionStatus.CONNECTED + + def test_parse_opponent_status_disconnected(self) -> None: + """Test parsing OpponentStatusMessage for disconnection. + + When an opponent disconnects, the UI can show a waiting indicator + and possibly pause the game timer. + """ + message = OpponentStatusMessage( + game_id="game-456", + opponent_id="player-2", + status=ConnectionStatus.DISCONNECTED, + ) + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, OpponentStatusMessage) + assert parsed.status == ConnectionStatus.DISCONNECTED + + def test_parse_heartbeat_ack(self) -> None: + """Test parsing HeartbeatAckMessage. + + HeartbeatAck confirms the connection is alive, allowing clients + to measure round-trip latency. + """ + message = HeartbeatAckMessage() + + data = message.model_dump() + parsed = parse_server_message(data) + + assert isinstance(parsed, HeartbeatAckMessage) + + def test_server_message_has_timestamp(self) -> None: + """Test that all server messages have a timestamp. + + Timestamps enable client-side latency tracking and event ordering + in case of out-of-order message delivery. + """ + message = HeartbeatAckMessage() + + assert message.timestamp is not None + assert isinstance(message.timestamp, datetime) + # Should be UTC + assert message.timestamp.tzinfo == UTC + + def test_type_adapter_server_message(self) -> None: + """Test ServerMessage union with TypeAdapter. + + Verifies TypeAdapter works correctly for the more complex + server message union with nested types. + """ + adapter = TypeAdapter(ServerMessage) + + error_data = { + "type": "error", + "message_id": "test", + "timestamp": datetime.now(UTC).isoformat(), + "code": "game_not_found", + "message": "Game not found", + } + message = adapter.validate_python(error_data) + + assert isinstance(message, ErrorMessage) + + +class TestMessageSerialization: + """Tests for message JSON serialization and round-trip behavior. + + Messages must serialize cleanly to JSON for WebSocket transmission + and deserialize back to the same logical structure. + """ + + def test_client_message_round_trip(self) -> None: + """Test client message JSON round-trip. + + Messages should survive serialization to JSON and back without + losing data or changing types. + """ + original = ActionMessage( + message_id="test-123", + game_id="game-456", + action=AttackAction(attack_index=1, targets=["target-1"]), + ) + + json_str = original.model_dump_json() + import json + + data = json.loads(json_str) + restored = parse_client_message(data) + + assert isinstance(restored, ActionMessage) + assert restored.message_id == original.message_id + assert isinstance(restored.action, AttackAction) + assert restored.action.attack_index == 1 + assert restored.action.targets == ["target-1"] + + def test_server_message_round_trip(self) -> None: + """Test server message JSON round-trip. + + Server messages with nested state should serialize and deserialize + correctly, preserving all game state information. + """ + visible_state = VisibleGameState( + game_id="game-456", + viewer_id="player-1", + turn_number=3, + phase=TurnPhase.ATTACK, + ) + + original = GameStateMessage( + message_id="test-123", + game_id="game-456", + state=visible_state, + ) + + json_str = original.model_dump_json() + import json + + data = json.loads(json_str) + restored = parse_server_message(data) + + assert isinstance(restored, GameStateMessage) + assert restored.state.turn_number == 3 + assert restored.state.phase == TurnPhase.ATTACK + + def test_error_code_serializes_as_string(self) -> None: + """Test that WSErrorCode serializes as string value. + + Error codes should serialize to their string values for clean + JSON output that clients can easily handle. + """ + message = ErrorMessage( + code=WSErrorCode.GAME_NOT_FOUND, + message="Game not found", + ) + + json_str = message.model_dump_json() + import json + + data = json.loads(json_str) + + assert data["code"] == "game_not_found" + + +class TestMessageValidation: + """Tests for message field validation rules. + + Validation ensures messages contain valid data before processing, + preventing invalid state from propagating through the system. + """ + + def test_join_game_requires_game_id(self) -> None: + """Test that JoinGameMessage requires game_id. + + game_id is mandatory - clients must specify which game to join. + """ + data = { + "type": "join_game", + "message_id": "test-123", + # missing game_id + } + + with pytest.raises(ValidationError): + parse_client_message(data) + + def test_action_message_requires_action(self) -> None: + """Test that ActionMessage requires action field. + + The action field is mandatory - messages without an action + cannot be processed. + """ + data = { + "type": "action", + "message_id": "test-123", + "game_id": "game-456", + # missing action + } + + with pytest.raises(ValidationError): + parse_client_message(data) + + def test_turn_number_must_be_positive(self) -> None: + """Test that turn_number must be >= 1. + + Turn numbers start at 1 - zero or negative values are invalid. + """ + with pytest.raises(ValidationError): + TurnStartMessage( + game_id="game-456", + player_id="player-1", + turn_number=0, # Invalid + ) + + def test_remaining_seconds_non_negative(self) -> None: + """Test that remaining_seconds must be >= 0. + + Negative timeout values are meaningless and should be rejected. + """ + with pytest.raises(ValidationError): + TurnTimeoutMessage( + game_id="game-456", + remaining_seconds=-1, # Invalid + is_warning=True, + player_id="player-1", + ) + + def test_action_result_requires_request_id(self) -> None: + """Test that ActionResultMessage requires request_message_id. + + Results must link back to the original request for client-side + correlation of actions and responses. + """ + with pytest.raises(ValidationError): + ActionResultMessage( + game_id="game-456", + # missing request_message_id + success=True, + action_type="attack", + ) + + +class TestWSErrorCode: + """Tests for WebSocket error code enumeration. + + Error codes provide machine-readable error classification for + programmatic error handling on the client. + """ + + def test_all_error_codes_are_strings(self) -> None: + """Test that all error codes are string values. + + StrEnum ensures all values serialize as strings for JSON + compatibility. + """ + for code in WSErrorCode: + assert isinstance(code.value, str) + + def test_error_codes_are_lowercase_snake_case(self) -> None: + """Test error code naming convention. + + Error codes follow lowercase_snake_case for consistency with + other API identifiers and JSON conventions. + """ + for code in WSErrorCode: + assert code.value == code.value.lower() + assert " " not in code.value + + +class TestConnectionStatus: + """Tests for connection status enumeration.""" + + def test_all_statuses_are_strings(self) -> None: + """Test that all connection statuses are string values.""" + for status in ConnectionStatus: + assert isinstance(status.value, str) + + def test_expected_statuses_exist(self) -> None: + """Test that expected connection statuses are defined. + + The three key states (connected, disconnected, reconnecting) + cover all connection lifecycle phases. + """ + assert ConnectionStatus.CONNECTED + assert ConnectionStatus.DISCONNECTED + assert ConnectionStatus.RECONNECTING diff --git a/backend/tests/unit/services/test_connection_manager.py b/backend/tests/unit/services/test_connection_manager.py new file mode 100644 index 0000000..7fbbf84 --- /dev/null +++ b/backend/tests/unit/services/test_connection_manager.py @@ -0,0 +1,665 @@ +"""Tests for ConnectionManager service. + +This module tests WebSocket connection tracking with Redis. Since these are +unit tests, we mock the Redis operations to test the ConnectionManager logic +without requiring a real Redis instance. +""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, patch +from uuid import uuid4 + +import pytest + +from app.services.connection_manager import ( + CONN_PREFIX, + GAME_CONNS_PREFIX, + HEARTBEAT_INTERVAL_SECONDS, + USER_CONN_PREFIX, + ConnectionInfo, + ConnectionManager, +) + + +@pytest.fixture +def manager() -> ConnectionManager: + """Create a ConnectionManager instance for testing.""" + return ConnectionManager(conn_ttl_seconds=3600) + + +@pytest.fixture +def mock_redis() -> AsyncMock: + """Create a mock Redis client.""" + redis = AsyncMock() + redis.hset = AsyncMock() + redis.hget = AsyncMock() + redis.hgetall = AsyncMock(return_value={}) + redis.set = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.delete = AsyncMock() + redis.exists = AsyncMock(return_value=False) + redis.expire = AsyncMock() + redis.sadd = AsyncMock() + redis.srem = AsyncMock() + redis.smembers = AsyncMock(return_value=set()) + redis.scard = AsyncMock(return_value=0) + return redis + + +class TestConnectionInfoDataclass: + """Tests for the ConnectionInfo dataclass.""" + + def test_is_stale_returns_false_for_recent_connection(self) -> None: + """Test that recent connections are not marked as stale. + + Connections that have been seen within the threshold should be + considered active, not stale. + """ + now = datetime.now(UTC) + info = ConnectionInfo( + sid="test-sid", + user_id="user-123", + game_id=None, + connected_at=now, + last_seen=now, + ) + + assert info.is_stale() is False + + def test_is_stale_returns_true_for_old_connection(self) -> None: + """Test that old connections are marked as stale. + + Connections that haven't been seen for longer than the threshold + should be considered stale and eligible for cleanup. + """ + now = datetime.now(UTC) + old_time = now - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4) + info = ConnectionInfo( + sid="test-sid", + user_id="user-123", + game_id=None, + connected_at=old_time, + last_seen=old_time, + ) + + assert info.is_stale() is True + + def test_is_stale_with_custom_threshold(self) -> None: + """Test is_stale with a custom threshold. + + The threshold can be adjusted for different use cases like + more aggressive cleanup or more lenient timeout. + """ + now = datetime.now(UTC) + last_seen = now - timedelta(seconds=60) + info = ConnectionInfo( + sid="test-sid", + user_id="user-123", + game_id=None, + connected_at=now, + last_seen=last_seen, + ) + + # 60 seconds old, with 30 second threshold = stale + assert info.is_stale(threshold_seconds=30) is True + + # 60 seconds old, with 120 second threshold = not stale + assert info.is_stale(threshold_seconds=120) is False + + +class TestRegisterConnection: + """Tests for connection registration.""" + + @pytest.mark.asyncio + async def test_register_connection_creates_records( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that registering a connection creates all necessary Redis records. + + Registration should create: + 1. Connection hash with user_id, game_id, timestamps + 2. User-to-connection mapping + """ + sid = "test-sid-123" + user_id = str(uuid4()) + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + await manager.register_connection(sid, user_id) + + # Verify connection hash was created + conn_key = f"{CONN_PREFIX}{sid}" + mock_redis.hset.assert_called() + hset_call = mock_redis.hset.call_args + assert hset_call.args[0] == conn_key + mapping = hset_call.kwargs["mapping"] + assert mapping["user_id"] == user_id + assert mapping["game_id"] == "" + + # Verify user-to-connection mapping was created + user_conn_key = f"{USER_CONN_PREFIX}{user_id}" + mock_redis.set.assert_called_with(user_conn_key, sid) + + @pytest.mark.asyncio + async def test_register_connection_replaces_old_connection( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that registering a new connection replaces the old one. + + When a user reconnects, their old connection should be cleaned up + to prevent stale connection data from lingering. + """ + old_sid = "old-sid" + new_sid = "new-sid" + user_id = str(uuid4()) + + # Mock: user has existing connection + mock_redis.get.return_value = old_sid + mock_redis.hgetall.return_value = { + "user_id": user_id, + "game_id": "", + "connected_at": datetime.now(UTC).isoformat(), + "last_seen": datetime.now(UTC).isoformat(), + } + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + await manager.register_connection(new_sid, user_id) + + # Verify old connection was cleaned up (delete called for old conn key) + old_conn_key = f"{CONN_PREFIX}{old_sid}" + delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list] + assert old_conn_key in delete_calls + + @pytest.mark.asyncio + async def test_register_connection_accepts_uuid( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that register_connection accepts UUID objects. + + The user_id can be passed as either a string or UUID object + for convenience. + """ + sid = "test-sid" + user_uuid = uuid4() + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + await manager.register_connection(sid, user_uuid) + + # Verify user_id was converted to string + hset_call = mock_redis.hset.call_args + mapping = hset_call.kwargs["mapping"] + assert mapping["user_id"] == str(user_uuid) + + +class TestUnregisterConnection: + """Tests for connection unregistration.""" + + @pytest.mark.asyncio + async def test_unregister_returns_none_for_unknown_sid( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that unregistering unknown connection returns None. + + If the connection doesn't exist, we should return None rather than + raising an error. + """ + mock_redis.hgetall.return_value = {} + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.unregister_connection("unknown-sid") + + assert result is None + + @pytest.mark.asyncio + async def test_unregister_cleans_up_all_data( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that unregistering cleans up all related Redis data. + + Cleanup should remove: + 1. Connection hash + 2. User-to-connection mapping (if it still points to this sid) + 3. Game connection set membership + """ + sid = "test-sid" + user_id = "user-123" + game_id = "game-456" + + # Mock: connection exists with game + mock_redis.hgetall.return_value = { + "user_id": user_id, + "game_id": game_id, + "connected_at": datetime.now(UTC).isoformat(), + "last_seen": datetime.now(UTC).isoformat(), + } + mock_redis.get.return_value = sid # user mapping points to this sid + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.unregister_connection(sid) + + # Verify result + assert result is not None + assert result.sid == sid + assert result.user_id == user_id + assert result.game_id == game_id + + # Verify cleanup + mock_redis.srem.assert_called() # Removed from game set + mock_redis.delete.assert_called() # Connection deleted + + +class TestGameAssociation: + """Tests for game join/leave operations.""" + + @pytest.mark.asyncio + async def test_join_game_adds_to_game_set( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that joining a game adds the connection to the game's set. + + When a connection joins a game, it should: + 1. Update the connection's game_id + 2. Add the sid to the game's connection set + """ + sid = "test-sid" + game_id = "game-123" + + mock_redis.exists.return_value = True + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.join_game(sid, game_id) + + assert result is True + mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id) + mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) + + @pytest.mark.asyncio + async def test_join_game_returns_false_for_unknown_connection( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that joining a game fails for unknown connections. + + If the connection doesn't exist, we should return False rather + than creating orphan game association data. + """ + mock_redis.exists.return_value = False + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.join_game("unknown-sid", "game-123") + + assert result is False + mock_redis.sadd.assert_not_called() + + @pytest.mark.asyncio + async def test_join_game_leaves_previous_game( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that joining a new game leaves the previous game. + + When switching games, the connection should be removed from the + old game's connection set before joining the new one. + """ + sid = "test-sid" + old_game = "game-old" + new_game = "game-new" + + mock_redis.exists.return_value = True + mock_redis.hget.return_value = old_game # Currently in old game + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + await manager.join_game(sid, new_game) + + # Verify left old game + mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid) + # Verify joined new game + mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid) + + @pytest.mark.asyncio + async def test_leave_game_removes_from_set( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that leaving a game removes the connection from the set. + + Leave should: + 1. Remove sid from game's connection set + 2. Clear game_id on connection record + """ + sid = "test-sid" + game_id = "game-123" + + mock_redis.hget.return_value = game_id + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.leave_game(sid) + + assert result == game_id + mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) + mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "") + + @pytest.mark.asyncio + async def test_leave_game_returns_none_when_not_in_game( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that leave_game returns None when not in a game. + + If the connection isn't associated with any game, we should + return None without making unnecessary Redis calls. + """ + mock_redis.hget.return_value = "" # No game_id + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.leave_game("test-sid") + + assert result is None + mock_redis.srem.assert_not_called() + + +class TestHeartbeat: + """Tests for heartbeat/activity tracking.""" + + @pytest.mark.asyncio + async def test_update_heartbeat_refreshes_last_seen( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that heartbeat updates the last_seen timestamp. + + Heartbeats keep the connection alive by updating the timestamp + and refreshing TTLs on Redis records. + """ + sid = "test-sid" + mock_redis.exists.return_value = True + mock_redis.hget.return_value = "user-123" + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.update_heartbeat(sid) + + assert result is True + # Verify last_seen was updated + hset_call = mock_redis.hset.call_args + assert hset_call.args[0] == f"{CONN_PREFIX}{sid}" + assert hset_call.args[1] == "last_seen" + # Verify TTL was refreshed + mock_redis.expire.assert_called() + + @pytest.mark.asyncio + async def test_update_heartbeat_returns_false_for_unknown( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that heartbeat returns False for unknown connections. + + If the connection doesn't exist, we shouldn't try to update it. + """ + mock_redis.exists.return_value = False + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.update_heartbeat("unknown-sid") + + assert result is False + + +class TestQueryMethods: + """Tests for connection query methods.""" + + @pytest.mark.asyncio + async def test_get_connection_returns_info( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that get_connection returns ConnectionInfo for valid sid. + + The returned ConnectionInfo should have all fields populated + from the Redis hash. + """ + sid = "test-sid" + now = datetime.now(UTC) + mock_redis.hgetall.return_value = { + "user_id": "user-123", + "game_id": "game-456", + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), + } + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.get_connection(sid) + + assert result is not None + assert result.sid == sid + assert result.user_id == "user-123" + assert result.game_id == "game-456" + + @pytest.mark.asyncio + async def test_get_connection_returns_none_for_unknown( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that get_connection returns None for unknown sid.""" + mock_redis.hgetall.return_value = {} + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.get_connection("unknown-sid") + + assert result is None + + @pytest.mark.asyncio + async def test_is_user_online_returns_true_for_connected_user( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that is_user_online returns True for connected users. + + A user is online if they have an active connection that isn't stale. + """ + user_id = "user-123" + now = datetime.now(UTC) + + mock_redis.get.return_value = "test-sid" + mock_redis.hgetall.return_value = { + "user_id": user_id, + "game_id": "", + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), + } + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.is_user_online(user_id) + + assert result is True + + @pytest.mark.asyncio + async def test_is_user_online_returns_false_for_stale_connection( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that is_user_online returns False for stale connections. + + Even if a connection record exists, if it's stale (no recent heartbeat), + the user should be considered offline. + """ + user_id = "user-123" + old_time = datetime.now(UTC) - timedelta(minutes=5) + + mock_redis.get.return_value = "test-sid" + mock_redis.hgetall.return_value = { + "user_id": user_id, + "game_id": "", + "connected_at": old_time.isoformat(), + "last_seen": old_time.isoformat(), + } + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.is_user_online(user_id) + + assert result is False + + @pytest.mark.asyncio + async def test_get_game_connections_returns_all_participants( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that get_game_connections returns all game participants. + + Should return ConnectionInfo for each sid in the game's connection set. + """ + game_id = "game-123" + now = datetime.now(UTC) + + mock_redis.smembers.return_value = {"sid-1", "sid-2"} + mock_redis.hgetall.side_effect = [ + { + "user_id": "user-1", + "game_id": game_id, + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), + }, + { + "user_id": "user-2", + "game_id": game_id, + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), + }, + ] + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.get_game_connections(game_id) + + assert len(result) == 2 + user_ids = {conn.user_id for conn in result} + assert user_ids == {"user-1", "user-2"} + + @pytest.mark.asyncio + async def test_get_opponent_sid_returns_other_player( + self, + manager: ConnectionManager, + mock_redis: AsyncMock, + ) -> None: + """Test that get_opponent_sid returns the other player's sid. + + In a 2-player game, this should return the sid of the player + who is not the current user. + """ + game_id = "game-123" + current_user = "user-1" + opponent_user = "user-2" + opponent_sid = "sid-2" + now = datetime.now(UTC) + + mock_redis.smembers.return_value = {"sid-1", "sid-2"} + + # Use a function to return the right data based on the key queried + # This avoids dependency on set iteration order + def hgetall_by_key(key: str) -> dict[str, str]: + if key == f"{CONN_PREFIX}sid-1": + return { + "user_id": current_user, + "game_id": game_id, + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), + } + elif key == f"{CONN_PREFIX}sid-2": + return { + "user_id": opponent_user, + "game_id": game_id, + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), + } + return {} + + mock_redis.hgetall.side_effect = hgetall_by_key + + with patch("app.services.connection_manager.get_redis") as mock_get_redis: + mock_get_redis.return_value.__aenter__.return_value = mock_redis + + result = await manager.get_opponent_sid(game_id, current_user) + + assert result == opponent_sid + + +class TestKeyGeneration: + """Tests for Redis key generation methods.""" + + def test_conn_key_format(self, manager: ConnectionManager) -> None: + """Test that connection keys have correct format. + + Keys should follow the pattern conn:{sid} for easy identification + and pattern matching. + """ + key = manager._conn_key("test-sid-123") + assert key == "conn:test-sid-123" + + def test_user_conn_key_format(self, manager: ConnectionManager) -> None: + """Test that user connection keys have correct format. + + Keys should follow the pattern user_conn:{user_id}. + """ + key = manager._user_conn_key("user-456") + assert key == "user_conn:user-456" + + def test_game_conns_key_format(self, manager: ConnectionManager) -> None: + """Test that game connections keys have correct format. + + Keys should follow the pattern game_conns:{game_id}. + """ + key = manager._game_conns_key("game-789") + assert key == "game_conns:game-789" diff --git a/backend/tests/unit/services/test_game_service.py b/backend/tests/unit/services/test_game_service.py new file mode 100644 index 0000000..53916e0 --- /dev/null +++ b/backend/tests/unit/services/test_game_service.py @@ -0,0 +1,809 @@ +"""Tests for GameService. + +This module tests the game service layer that orchestrates between +WebSocket communication and the core GameEngine. + +The GameService is STATELESS regarding game rules: +- No GameEngine is stored in the service +- Engine is created per-operation using rules from GameState +- Rules come from frontend at game creation, stored in GameState +""" + +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from app.core.engine import ActionResult +from app.core.enums import GameEndReason, TurnPhase +from app.core.models.actions import AttackAction, PassAction, ResignAction +from app.core.models.game_state import GameState, PlayerState +from app.core.win_conditions import WinResult +from app.services.game_service import ( + GameAlreadyEndedError, + GameNotFoundError, + GameService, + InvalidActionError, + NotPlayerTurnError, + PlayerNotInGameError, +) + + +@pytest.fixture +def mock_state_manager() -> AsyncMock: + """Create a mock GameStateManager. + + The state manager handles persistence to Redis (cache) and + Postgres (durable storage). + """ + manager = AsyncMock() + manager.load_state = AsyncMock(return_value=None) + manager.save_to_cache = AsyncMock() + manager.persist_to_db = AsyncMock() + manager.cache_exists = AsyncMock(return_value=False) + return manager + + +@pytest.fixture +def mock_card_service() -> MagicMock: + """Create a mock CardService. + + CardService provides card definitions for game creation. + """ + return MagicMock() + + +@pytest.fixture +def game_service( + mock_state_manager: AsyncMock, + mock_card_service: MagicMock, +) -> GameService: + """Create a GameService with mocked dependencies. + + Note: No engine is passed - GameService creates engines per-operation + using rules stored in each game's GameState. + """ + return GameService( + state_manager=mock_state_manager, + card_service=mock_card_service, + ) + + +@pytest.fixture +def sample_game_state() -> GameState: + """Create a sample game state for testing. + + The game state includes two players and basic turn tracking. + The rules are stored in the state itself (default RulesConfig). + """ + player1 = PlayerState(player_id="player-1") + player2 = PlayerState(player_id="player-2") + + return GameState( + game_id="game-123", + players={"player-1": player1, "player-2": player2}, + current_player_id="player-1", + turn_number=1, + phase=TurnPhase.MAIN, + ) + + +class TestGameStateAccess: + """Tests for game state access methods.""" + + @pytest.mark.asyncio + async def test_get_game_state_returns_state( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that get_game_state returns the game state when found. + + The state manager should be called to load the state, and the + result should be returned to the caller. + """ + mock_state_manager.load_state.return_value = sample_game_state + + result = await game_service.get_game_state("game-123") + + assert result == sample_game_state + mock_state_manager.load_state.assert_called_once_with("game-123") + + @pytest.mark.asyncio + async def test_get_game_state_raises_not_found( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + ) -> None: + """Test that get_game_state raises GameNotFoundError when not found. + + When the state manager returns None, we should raise a specific + exception rather than returning None. + """ + mock_state_manager.load_state.return_value = None + + with pytest.raises(GameNotFoundError) as exc_info: + await game_service.get_game_state("nonexistent") + + assert exc_info.value.game_id == "nonexistent" + + @pytest.mark.asyncio + async def test_get_player_view_returns_visible_state( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that get_player_view returns visibility-filtered state. + + The returned state should be filtered for the requesting player, + hiding the opponent's private information. + """ + mock_state_manager.load_state.return_value = sample_game_state + + result = await game_service.get_player_view("game-123", "player-1") + + assert result.game_id == "game-123" + assert result.viewer_id == "player-1" + assert result.is_my_turn is True + + @pytest.mark.asyncio + async def test_get_player_view_raises_not_in_game( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that get_player_view raises error for non-participants. + + Only players in the game should be able to view its state. + """ + mock_state_manager.load_state.return_value = sample_game_state + + with pytest.raises(PlayerNotInGameError) as exc_info: + await game_service.get_player_view("game-123", "stranger") + + assert exc_info.value.player_id == "stranger" + + @pytest.mark.asyncio + async def test_is_player_turn_returns_true_for_current_player( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that is_player_turn returns True for the current player. + + The current player is determined by the game state's current_player_id. + """ + mock_state_manager.load_state.return_value = sample_game_state + + result = await game_service.is_player_turn("game-123", "player-1") + + assert result is True + + @pytest.mark.asyncio + async def test_is_player_turn_returns_false_for_other_player( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that is_player_turn returns False for non-current player. + + Players who are not the current player should get False. + """ + mock_state_manager.load_state.return_value = sample_game_state + + result = await game_service.is_player_turn("game-123", "player-2") + + assert result is False + + @pytest.mark.asyncio + async def test_game_exists_returns_true_when_cached( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + ) -> None: + """Test that game_exists returns True when game is in cache. + + This is a quick check that doesn't need to load the full state. + """ + mock_state_manager.cache_exists.return_value = True + + result = await game_service.game_exists("game-123") + + assert result is True + mock_state_manager.cache_exists.assert_called_once_with("game-123") + + +class TestJoinGame: + """Tests for the join_game method.""" + + @pytest.mark.asyncio + async def test_join_game_success( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test successful game join returns visible state. + + When a player joins their game, they should receive the + visibility-filtered state and know if it's their turn. + """ + mock_state_manager.load_state.return_value = sample_game_state + + result = await game_service.join_game("game-123", "player-1") + + assert result.success is True + assert result.game_id == "game-123" + assert result.player_id == "player-1" + assert result.visible_state is not None + assert result.is_your_turn is True + + @pytest.mark.asyncio + async def test_join_game_not_your_turn( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that is_your_turn is False when it's opponent's turn. + + The second player should see is_your_turn=False when the first + player is the current player. + """ + mock_state_manager.load_state.return_value = sample_game_state + + result = await game_service.join_game("game-123", "player-2") + + assert result.success is True + assert result.is_your_turn is False + + @pytest.mark.asyncio + async def test_join_game_not_found( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + ) -> None: + """Test join_game returns failure for non-existent game. + + Rather than raising an exception, join_game returns a failed + result for better WebSocket error handling. + """ + mock_state_manager.load_state.return_value = None + + result = await game_service.join_game("nonexistent", "player-1") + + assert result.success is False + assert "not found" in result.message.lower() + + @pytest.mark.asyncio + async def test_join_game_not_participant( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test join_game returns failure for non-participants. + + Players who are not in the game should not be able to join. + """ + mock_state_manager.load_state.return_value = sample_game_state + + result = await game_service.join_game("game-123", "stranger") + + assert result.success is False + assert "not a participant" in result.message.lower() + + @pytest.mark.asyncio + async def test_join_ended_game( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test joining an ended game still succeeds but indicates game over. + + Players should be able to rejoin ended games to see the final + state, but is_your_turn should be False. + """ + sample_game_state.winner_id = "player-1" + sample_game_state.end_reason = GameEndReason.PRIZES_TAKEN + mock_state_manager.load_state.return_value = sample_game_state + + result = await game_service.join_game("game-123", "player-2") + + assert result.success is True + assert result.is_your_turn is False + assert "ended" in result.message.lower() + + +class TestExecuteAction: + """Tests for the execute_action method. + + These tests verify action execution through GameService. Since GameService + creates engines per-operation, we patch _create_engine_for_game to return + a mock engine with controlled behavior. + """ + + @pytest.mark.asyncio + async def test_execute_action_success( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test successful action execution. + + A valid action by the current player should be executed and + the state should be saved to cache. + """ + mock_state_manager.load_state.return_value = sample_game_state + + mock_engine = MagicMock() + mock_engine.execute_action = AsyncMock( + return_value=ActionResult( + success=True, + message="Attack executed", + state_changes=[{"type": "damage", "amount": 30}], + ) + ) + + with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine): + action = AttackAction(attack_index=0) + result = await game_service.execute_action("game-123", "player-1", action) + + assert result.success is True + assert result.action_type == "attack" + mock_state_manager.save_to_cache.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_action_game_not_found( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + ) -> None: + """Test execute_action raises error when game not found. + + Missing games should raise GameNotFoundError for proper + error handling in the WebSocket layer. + """ + mock_state_manager.load_state.return_value = None + + with pytest.raises(GameNotFoundError): + await game_service.execute_action("nonexistent", "player-1", PassAction()) + + @pytest.mark.asyncio + async def test_execute_action_not_in_game( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test execute_action raises error for non-participants. + + Only players in the game can execute actions. + """ + mock_state_manager.load_state.return_value = sample_game_state + + with pytest.raises(PlayerNotInGameError): + await game_service.execute_action("game-123", "stranger", PassAction()) + + @pytest.mark.asyncio + async def test_execute_action_game_ended( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test execute_action raises error on ended games. + + No actions should be allowed once a game has ended. + """ + sample_game_state.winner_id = "player-1" + sample_game_state.end_reason = GameEndReason.RESIGNATION + mock_state_manager.load_state.return_value = sample_game_state + + with pytest.raises(GameAlreadyEndedError): + await game_service.execute_action("game-123", "player-1", PassAction()) + + @pytest.mark.asyncio + async def test_execute_action_not_your_turn( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test execute_action raises error when not player's turn. + + Only the current player can execute actions. + """ + mock_state_manager.load_state.return_value = sample_game_state + + with pytest.raises(NotPlayerTurnError) as exc_info: + await game_service.execute_action("game-123", "player-2", PassAction()) + + assert exc_info.value.player_id == "player-2" + assert exc_info.value.current_player_id == "player-1" + + @pytest.mark.asyncio + async def test_execute_action_resign_allowed_out_of_turn( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that resignation is allowed even when not your turn. + + Resignation is a special action that can be executed anytime + by either player. + """ + mock_state_manager.load_state.return_value = sample_game_state + + mock_engine = MagicMock() + mock_engine.execute_action = AsyncMock( + return_value=ActionResult( + success=True, + message="Player resigned", + win_result=WinResult( + winner_id="player-1", + loser_id="player-2", + end_reason=GameEndReason.RESIGNATION, + reason="Player resigned", + ), + ) + ) + + with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine): + # player-2 resigns even though it's player-1's turn + result = await game_service.execute_action("game-123", "player-2", ResignAction()) + + assert result.success is True + assert result.game_over is True + assert result.winner_id == "player-1" + + @pytest.mark.asyncio + async def test_execute_action_invalid_action( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test execute_action raises error for invalid actions. + + When the GameEngine rejects an action, we should raise + InvalidActionError with the reason. + """ + mock_state_manager.load_state.return_value = sample_game_state + + mock_engine = MagicMock() + mock_engine.execute_action = AsyncMock( + return_value=ActionResult( + success=False, + message="Not enough energy to attack", + ) + ) + + with ( + patch.object(game_service, "_create_engine_for_game", return_value=mock_engine), + pytest.raises(InvalidActionError) as exc_info, + ): + await game_service.execute_action("game-123", "player-1", AttackAction(attack_index=0)) + + assert "Not enough energy" in exc_info.value.reason + + @pytest.mark.asyncio + async def test_execute_action_game_over( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test execute_action detects game over and persists to DB. + + When an action results in a win, the state should be persisted + to the database for durability. + """ + mock_state_manager.load_state.return_value = sample_game_state + + mock_engine = MagicMock() + mock_engine.execute_action = AsyncMock( + return_value=ActionResult( + success=True, + message="Final prize taken!", + win_result=WinResult( + winner_id="player-1", + loser_id="player-2", + end_reason=GameEndReason.PRIZES_TAKEN, + reason="All prizes taken", + ), + ) + ) + + with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine): + result = await game_service.execute_action( + "game-123", "player-1", AttackAction(attack_index=0) + ) + + assert result.game_over is True + assert result.winner_id == "player-1" + assert result.end_reason == GameEndReason.PRIZES_TAKEN + mock_state_manager.persist_to_db.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_action_uses_game_rules( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that execute_action creates engine with game's rules. + + The engine should be created using the rules stored in the game + state, not any service-level defaults. + """ + mock_state_manager.load_state.return_value = sample_game_state + + mock_engine = MagicMock() + mock_engine.execute_action = AsyncMock( + return_value=ActionResult(success=True, message="OK") + ) + + with patch.object( + game_service, "_create_engine_for_game", return_value=mock_engine + ) as mock_create: + await game_service.execute_action("game-123", "player-1", PassAction()) + + # Verify engine was created with the game state + mock_create.assert_called_once_with(sample_game_state) + + +class TestResignGame: + """Tests for the resign_game convenience method.""" + + @pytest.mark.asyncio + async def test_resign_game_executes_resign_action( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that resign_game is a convenience wrapper for execute_action. + + The resign_game method should internally create a ResignAction + and call execute_action. + """ + mock_state_manager.load_state.return_value = sample_game_state + + mock_engine = MagicMock() + mock_engine.execute_action = AsyncMock( + return_value=ActionResult( + success=True, + message="Player resigned", + win_result=WinResult( + winner_id="player-2", + loser_id="player-1", + end_reason=GameEndReason.RESIGNATION, + reason="Player resigned", + ), + ) + ) + + with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine): + result = await game_service.resign_game("game-123", "player-1") + + assert result.success is True + assert result.action_type == "resign" + assert result.game_over is True + + +class TestEndGame: + """Tests for the end_game method (forced ending).""" + + @pytest.mark.asyncio + async def test_end_game_sets_winner_and_reason( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test that end_game sets winner and end reason. + + Used for timeout or disconnection scenarios where the game + needs to be forcibly ended. + """ + mock_state_manager.load_state.return_value = sample_game_state + + await game_service.end_game( + "game-123", + winner_id="player-1", + end_reason=GameEndReason.TIMEOUT, + ) + + # Verify state was modified + assert sample_game_state.winner_id == "player-1" + assert sample_game_state.end_reason == GameEndReason.TIMEOUT + + # Verify persistence + mock_state_manager.save_to_cache.assert_called_once() + mock_state_manager.persist_to_db.assert_called_once() + + @pytest.mark.asyncio + async def test_end_game_draw( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + sample_game_state: GameState, + ) -> None: + """Test ending a game as a draw (no winner). + + Some scenarios (timeout with equal scores) can result in a draw. + """ + mock_state_manager.load_state.return_value = sample_game_state + + await game_service.end_game( + "game-123", + winner_id=None, + end_reason=GameEndReason.DRAW, + ) + + assert sample_game_state.winner_id is None + assert sample_game_state.end_reason == GameEndReason.DRAW + + @pytest.mark.asyncio + async def test_end_game_not_found( + self, + game_service: GameService, + mock_state_manager: AsyncMock, + ) -> None: + """Test end_game raises error when game not found.""" + mock_state_manager.load_state.return_value = None + + with pytest.raises(GameNotFoundError): + await game_service.end_game( + "nonexistent", + winner_id="player-1", + end_reason=GameEndReason.TIMEOUT, + ) + + +class TestCreateGame: + """Tests for the create_game method (skeleton).""" + + @pytest.mark.asyncio + async def test_create_game_raises_not_implemented( + self, + game_service: GameService, + ) -> None: + """Test that create_game raises NotImplementedError. + + The full implementation will be done in GS-002. For now, + it should raise NotImplementedError with a clear message. + """ + with pytest.raises(NotImplementedError) as exc_info: + await game_service.create_game( + player1_id=str(uuid4()), + player2_id=str(uuid4()), + ) + + assert "GS-002" in str(exc_info.value) + + +class TestCreateEngineForGame: + """Tests for the _create_engine_for_game method. + + This method is responsible for creating a GameEngine configured + with the rules from a specific game's state. + """ + + def test_create_engine_uses_game_rules( + self, + game_service: GameService, + sample_game_state: GameState, + ) -> None: + """Test that engine is created with the game's rules. + + The engine should use the RulesConfig stored in the game state, + not any default configuration. + """ + engine = game_service._create_engine_for_game(sample_game_state) + + # Engine should have the game's rules + assert engine.rules == sample_game_state.rules + + def test_create_engine_with_rng_seed( + self, + game_service: GameService, + sample_game_state: GameState, + ) -> None: + """Test that engine uses seeded RNG when game has rng_seed. + + When a game has an rng_seed set, the engine should use a + deterministic RNG for replay support. + """ + sample_game_state.rng_seed = 12345 + + engine = game_service._create_engine_for_game(sample_game_state) + + # Engine should have been created (we can't easily verify seed, + # but we can verify it doesn't error) + assert engine is not None + + def test_create_engine_without_rng_seed( + self, + game_service: GameService, + sample_game_state: GameState, + ) -> None: + """Test that engine uses secure RNG when no seed is set. + + Without an rng_seed, the engine should use cryptographically + secure random number generation. + """ + sample_game_state.rng_seed = None + + engine = game_service._create_engine_for_game(sample_game_state) + + assert engine is not None + + def test_create_engine_derives_unique_seed_per_action( + self, + game_service: GameService, + sample_game_state: GameState, + ) -> None: + """Test that different action counts produce different RNG sequences. + + For deterministic replay, each action needs a unique but + reproducible RNG seed based on game seed + action count. + """ + sample_game_state.rng_seed = 12345 + + # Simulate first action (action_log is empty) + sample_game_state.action_log = [] + engine1 = game_service._create_engine_for_game(sample_game_state) + + # Simulate second action (one action in log) + sample_game_state.action_log = [{"type": "pass"}] + engine2 = game_service._create_engine_for_game(sample_game_state) + + # Both engines should be created successfully + # (They will have different seeds due to action count) + assert engine1 is not None + assert engine2 is not None + + +class TestExceptionMessages: + """Tests for exception message formatting.""" + + def test_game_not_found_error_message(self) -> None: + """Test GameNotFoundError has descriptive message.""" + error = GameNotFoundError("game-123") + assert "game-123" in str(error) + assert error.game_id == "game-123" + + def test_not_player_turn_error_message(self) -> None: + """Test NotPlayerTurnError has descriptive message.""" + error = NotPlayerTurnError("game-123", "player-2", "player-1") + assert "player-2" in str(error) + assert "player-1" in str(error) + assert error.game_id == "game-123" + + def test_invalid_action_error_message(self) -> None: + """Test InvalidActionError has descriptive message.""" + error = InvalidActionError("game-123", "player-1", "Not enough energy") + assert "Not enough energy" in str(error) + assert error.reason == "Not enough energy" + + def test_player_not_in_game_error_message(self) -> None: + """Test PlayerNotInGameError has descriptive message.""" + error = PlayerNotInGameError("game-123", "stranger") + assert "stranger" in str(error) + assert "game-123" in str(error) + + def test_game_already_ended_error_message(self) -> None: + """Test GameAlreadyEndedError has descriptive message.""" + error = GameAlreadyEndedError("game-123") + assert "game-123" in str(error) + assert "ended" in str(error).lower()