"""WebSocket message schemas for Mantimon TCG real-time communication.

This module defines Pydantic models for all WebSocket message types,
following the discriminated union pattern from app/core/models/actions.py.
Messages are categorized into client-to-server and server-to-client types.

Message Design Principles:
1. All messages have a 'type' discriminator field for automatic parsing
2. All messages have a 'message_id' for idempotency and event ordering
3. Server messages include a 'timestamp' for client-side latency tracking
4. Game-scoped messages include 'game_id' for multi-game support

Example:
    # Parse incoming client message
    data = {"type": "join_game", "message_id": "abc-123", "game_id": "game-456"}
    message = parse_client_message(data)
    assert isinstance(message, JoinGameMessage)

    # Create server response
    response = GameStateMessage(
        message_id=str(uuid4()),
        game_id="game-456",
        state=visible_state,
    )
    await websocket.send_text(response.model_dump_json())
"""

from datetime import UTC, datetime
from enum import StrEnum
from functools import cache
from typing import Annotated, Any, Literal
from uuid import uuid4

from pydantic import BaseModel, Field, TypeAdapter, field_validator

from app.core.enums import GameEndReason
from app.core.models.actions import Action
from app.core.visibility import VisibleGameState


class WSErrorCode(StrEnum):
    """Error codes for WebSocket error messages.

    These codes help clients handle errors programmatically without
    parsing error message text.
    """

    # Connection errors
    AUTHENTICATION_FAILED = "authentication_failed"
    CONNECTION_CLOSED = "connection_closed"
    RATE_LIMITED = "rate_limited"

    # Game errors
    GAME_NOT_FOUND = "game_not_found"
    NOT_IN_GAME = "not_in_game"
    ALREADY_IN_GAME = "already_in_game"
    GAME_FULL = "game_full"
    GAME_ENDED = "game_ended"

    # Action errors
    INVALID_ACTION = "invalid_action"
    NOT_YOUR_TURN = "not_your_turn"
    ACTION_NOT_ALLOWED = "action_not_allowed"

    # Protocol errors
    INVALID_MESSAGE = "invalid_message"
    UNKNOWN_MESSAGE_TYPE = "unknown_message_type"

    # Server errors
    INTERNAL_ERROR = "internal_error"


class ConnectionStatus(StrEnum):
    """Connection status for opponent status messages."""

    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    RECONNECTING = "reconnecting"


# =============================================================================
# Base Message Classes
# =============================================================================


def _generate_message_id() -> str:
    """Generate a unique message ID (a random UUID4 rendered as a string)."""
    return str(uuid4())


def _utc_now() -> datetime:
    """Get the current timezone-aware UTC timestamp."""
    return datetime.now(UTC)


class BaseClientMessage(BaseModel):
    """Base class for all client-to-server messages.

    All client messages carry a message_id for idempotency. The client is
    expected to generate this ID so duplicate messages can be detected and
    handled; if the client omits it, a random UUID is filled in on parse.

    Attributes:
        message_id: Client-generated UUID for idempotency and tracking.
    """

    message_id: str = Field(
        default_factory=_generate_message_id,
        description="Client-generated UUID for idempotency",
    )

    @field_validator("message_id")
    @classmethod
    def validate_message_id(cls, v: str) -> str:
        """Reject empty or whitespace-only message IDs.

        The value is returned unchanged (it is not stripped).
        """
        if not v or not v.strip():
            raise ValueError("message_id cannot be empty")
        return v


class BaseServerMessage(BaseModel):
    """Base class for all server-to-client messages.

    All server messages include a message_id and timestamp. The timestamp
    enables client-side latency calculation and event ordering.

    Attributes:
        message_id: Server-generated UUID for tracking.
        timestamp: UTC timestamp when the message was created.
    """

    message_id: str = Field(
        default_factory=_generate_message_id,
        description="Server-generated UUID for tracking",
    )
    timestamp: datetime = Field(
        default_factory=_utc_now,
        description="UTC timestamp when message was created",
    )


# =============================================================================
# Client -> Server Messages
# =============================================================================


class JoinGameMessage(BaseClientMessage):
    """Request to join or rejoin a game session.

    When rejoining after a disconnect, the client can provide last_event_id
    to receive any missed events since that point.

    Attributes:
        type: Discriminator field, always "join_game".
        game_id: The game to join.
        last_event_id: For reconnection - ID of last received event for replay.
    """

    type: Literal["join_game"] = "join_game"
    game_id: str = Field(..., description="ID of the game to join")
    last_event_id: str | None = Field(
        default=None,
        description="Last event ID received (for reconnection replay)",
    )


class ActionMessage(BaseClientMessage):
    """Submit a game action for execution.

    Wraps the Action union type from app/core/models/actions.py to provide
    a consistent message envelope with message_id and game context.

    Attributes:
        type: Discriminator field, always "action".
        game_id: The game this action is for.
        action: The game action to execute.
    """

    type: Literal["action"] = "action"
    game_id: str = Field(..., description="ID of the game")
    action: Action = Field(..., description="The game action to execute")


class ResignMessage(BaseClientMessage):
    """Request to resign from a game.

    This is separate from ResignAction in the game engine to allow
    resignation handling at the WebSocket layer (e.g., when the game
    engine is unavailable).

    Attributes:
        type: Discriminator field, always "resign".
        game_id: The game to resign from.
    """

    type: Literal["resign"] = "resign"
    game_id: str = Field(..., description="ID of the game to resign from")


class HeartbeatMessage(BaseClientMessage):
    """Keep-alive message to maintain connection.

    Clients should send heartbeats periodically (e.g., every 30 seconds)
    to prevent connection timeout. The server responds with a HeartbeatAck.

    Attributes:
        type: Discriminator field, always "heartbeat".
    """

    type: Literal["heartbeat"] = "heartbeat"


# Union type for client messages; the "type" field selects the concrete model.
ClientMessage = Annotated[
    JoinGameMessage | ActionMessage | ResignMessage | HeartbeatMessage,
    Field(discriminator="type"),
]


# =============================================================================
# Server -> Client Messages
# =============================================================================


class GameStateMessage(BaseServerMessage):
    """Full game state update.

    Sent when a player joins a game or when a full state sync is needed.
    Contains the complete visible game state from that player's perspective.

    Attributes:
        type: Discriminator field, always "game_state".
        game_id: The game this state is for.
        state: The full visible game state.
        event_id: Event ID for reconnection replay (defaults to a random UUID).
        spectator_count: Number of users spectating this game.
    """

    type: Literal["game_state"] = "game_state"
    game_id: str = Field(..., description="ID of the game")
    state: VisibleGameState = Field(..., description="Full visible game state")
    event_id: str = Field(
        default_factory=_generate_message_id,
        description="Event ID for reconnection replay",
    )
    spectator_count: int = Field(
        default=0,
        description="Number of users spectating this game",
        ge=0,
    )


class ActionResultMessage(BaseServerMessage):
    """Result of a player action.

    Sent after processing an ActionMessage to confirm success or failure.
    On success, includes changes that resulted from the action.

    Note: consistency between ``success`` and the ``error_*`` fields is not
    validated here; callers are responsible for populating them coherently.

    Attributes:
        type: Discriminator field, always "action_result".
        game_id: The game this result is for.
        request_message_id: The message_id of the original ActionMessage.
        success: Whether the action succeeded.
        action_type: The type of action that was attempted.
        changes: Description of state changes (on success).
        error_code: Error code if action failed.
        error_message: Human-readable error message if failed.
    """

    type: Literal["action_result"] = "action_result"
    game_id: str = Field(..., description="ID of the game")
    request_message_id: str = Field(..., description="message_id of the original action request")
    success: bool = Field(..., description="Whether the action succeeded")
    action_type: str = Field(..., description="Type of action that was attempted")
    changes: dict[str, Any] = Field(
        default_factory=dict,
        description="State changes resulting from the action",
    )
    error_code: WSErrorCode | None = Field(default=None, description="Error code if action failed")
    error_message: str | None = Field(default=None, description="Human-readable error message")


class ErrorMessage(BaseServerMessage):
    """Error notification for protocol or connection errors.

    Used for errors not associated with a specific action, such as invalid
    message format, authentication failures, or server errors.

    Attributes:
        type: Discriminator field, always "error".
        code: Machine-readable error code.
        message: Human-readable error description.
        details: Additional error context.
        request_message_id: ID of the message that caused the error, if applicable.
    """

    type: Literal["error"] = "error"
    code: WSErrorCode = Field(..., description="Error code")
    message: str = Field(..., description="Human-readable error description")
    details: dict[str, Any] = Field(default_factory=dict, description="Additional error context")
    request_message_id: str | None = Field(
        default=None, description="ID of message that caused error"
    )


class TurnStartMessage(BaseServerMessage):
    """Notification that a new turn has started.

    Sent to both players when a turn begins. Indicates whose turn it is
    and the current turn number.

    Attributes:
        type: Discriminator field, always "turn_start".
        game_id: The game this notification is for.
        player_id: The player whose turn is starting.
        turn_number: The current turn number (1-based).
        event_id: Event ID for reconnection replay (defaults to a random UUID).
    """

    type: Literal["turn_start"] = "turn_start"
    game_id: str = Field(..., description="ID of the game")
    player_id: str = Field(..., description="Player whose turn is starting")
    turn_number: int = Field(..., description="Current turn number", ge=1)
    event_id: str = Field(
        default_factory=_generate_message_id,
        description="Event ID for reconnection replay",
    )


class TurnTimeoutMessage(BaseServerMessage):
    """Timeout warning or expiration notification.

    Sent when a player's turn is approaching timeout (warning) or has
    expired. Warnings give players time to complete their action.

    Attributes:
        type: Discriminator field, always "turn_timeout".
        game_id: The game this notification is for.
        remaining_seconds: Seconds remaining before timeout (non-negative).
        is_warning: True if this is a warning, False if timeout has occurred.
        player_id: The player whose turn is timing out.
    """

    type: Literal["turn_timeout"] = "turn_timeout"
    game_id: str = Field(..., description="ID of the game")
    remaining_seconds: int = Field(..., description="Seconds remaining before timeout", ge=0)
    is_warning: bool = Field(..., description="True if warning, False if timeout occurred")
    player_id: str = Field(..., description="Player whose turn is timing out")


class GameOverMessage(BaseServerMessage):
    """Notification that the game has ended.

    Sent to all players when the game concludes, regardless of the reason.

    Attributes:
        type: Discriminator field, always "game_over".
        game_id: The game that ended.
        winner_id: The winning player, or None for a draw.
        end_reason: Why the game ended.
        final_state: The final visible game state.
        event_id: Event ID for reconnection replay (defaults to a random UUID).
    """

    type: Literal["game_over"] = "game_over"
    game_id: str = Field(..., description="ID of the game that ended")
    winner_id: str | None = Field(default=None, description="Winner player ID, or None for draw")
    end_reason: GameEndReason = Field(..., description="Reason the game ended")
    final_state: VisibleGameState = Field(..., description="Final visible game state")
    event_id: str = Field(
        default_factory=_generate_message_id,
        description="Event ID for reconnection replay",
    )


class OpponentStatusMessage(BaseServerMessage):
    """Notification of opponent connection status change.

    Sent when the opponent connects, disconnects, or is reconnecting.
    Allows the UI to show connection status to the player.

    Attributes:
        type: Discriminator field, always "opponent_status".
        game_id: The game this status is for.
        opponent_id: The opponent's player ID.
        status: The opponent's current connection status.
    """

    type: Literal["opponent_status"] = "opponent_status"
    game_id: str = Field(..., description="ID of the game")
    opponent_id: str = Field(..., description="Opponent's player ID")
    status: ConnectionStatus = Field(..., description="Opponent's connection status")


class HeartbeatAckMessage(BaseServerMessage):
    """Acknowledgment of a client heartbeat.

    Sent in response to HeartbeatMessage to confirm the connection is alive.

    Attributes:
        type: Discriminator field, always "heartbeat_ack".
    """

    type: Literal["heartbeat_ack"] = "heartbeat_ack"


# Union type for server messages; the "type" field selects the concrete model.
ServerMessage = Annotated[
    GameStateMessage
    | ActionResultMessage
    | ErrorMessage
    | TurnStartMessage
    | TurnTimeoutMessage
    | GameOverMessage
    | OpponentStatusMessage
    | HeartbeatAckMessage,
    Field(discriminator="type"),
]


# =============================================================================
# Parsing Functions
# =============================================================================


@cache
def _client_message_adapter() -> TypeAdapter[ClientMessage]:
    """Return the shared ClientMessage TypeAdapter, built on first use.

    Constructing a TypeAdapter compiles the full validation schema for the
    discriminated union (including nested Action models), so it is built
    once and reused instead of being rebuilt for every incoming message.
    """
    return TypeAdapter(ClientMessage)


@cache
def _server_message_adapter() -> TypeAdapter[ServerMessage]:
    """Return the shared ServerMessage TypeAdapter, built on first use.

    Cached for the same reason as the client adapter: schema compilation
    is expensive and the union definition never changes at runtime.
    """
    return TypeAdapter(ServerMessage)


def parse_client_message(data: dict[str, Any]) -> ClientMessage:
    """Parse an incoming client message from a dictionary.

    Uses the 'type' field to determine which message model to use.

    Args:
        data: Dictionary containing message data with a 'type' field.

    Returns:
        The appropriate ClientMessage subtype.

    Raises:
        ValidationError: If the 'type' is missing or unknown, or the data
            doesn't match the selected message schema. (Pydantic's
            ValidationError is a subclass of ValueError, so callers catching
            ValueError also handle it.)

    Example:
        data = {"type": "join_game", "message_id": "abc", "game_id": "123"}
        message = parse_client_message(data)
        assert isinstance(message, JoinGameMessage)
    """
    return _client_message_adapter().validate_python(data)


def parse_server_message(data: dict[str, Any]) -> ServerMessage:
    """Parse a server message from a dictionary.

    Primarily used for testing - clients receive JSON and parse directly.

    Args:
        data: Dictionary containing message data with a 'type' field.

    Returns:
        The appropriate ServerMessage subtype.

    Raises:
        ValidationError: If the 'type' is missing or unknown, or the data
            doesn't match the selected message schema. (Pydantic's
            ValidationError is a subclass of ValueError.)
    """
    return _server_message_adapter().validate_python(data)


__all__ = [
    # Enums
    "ConnectionStatus",
    "WSErrorCode",
    # Base classes
    "BaseClientMessage",
    "BaseServerMessage",
    # Client messages
    "ActionMessage",
    "ClientMessage",
    "HeartbeatMessage",
    "JoinGameMessage",
    "ResignMessage",
    # Server messages
    "ActionResultMessage",
    "ErrorMessage",
    "GameOverMessage",
    "GameStateMessage",
    "HeartbeatAckMessage",
    "OpponentStatusMessage",
    "ServerMessage",
    "TurnStartMessage",
    "TurnTimeoutMessage",
    # Parsing functions
    "parse_client_message",
    "parse_server_message",
]