mantimon-tcg/backend/app/schemas/ws_messages.py
Cal Corum 0c810e5b30 Add Phase 4 WebSocket infrastructure (WS-001 through GS-001)
WebSocket Message Schemas (WS-002):
- Add Pydantic models for all client/server WebSocket messages
- Implement discriminated unions for message type parsing
- Include JoinGame, Action, Resign, Heartbeat client messages
- Include GameState, ActionResult, Error, TurnStart server messages

Connection Manager (WS-003):
- Add Redis-backed WebSocket connection tracking
- Implement user-to-sid mapping with TTL management
- Support game room association and opponent lookup
- Add heartbeat tracking for connection health

Socket.IO Authentication (WS-004):
- Add JWT-based authentication middleware
- Support token extraction from multiple formats
- Implement session setup with ConnectionManager integration
- Add require_auth helper for event handlers

Socket.IO Server Setup (WS-001):
- Configure AsyncServer with ASGI mode
- Register /game namespace with event handlers
- Integrate with FastAPI via ASGIApp wrapper
- Configure CORS from application settings

Game Service (GS-001):
- Add stateless GameService for game lifecycle orchestration
- Create engine per-operation using rules from GameState
- Implement action-based RNG seeding for deterministic replay
- Add rng_seed field to GameState for replay support

Architecture verified:
- Core module independence (no forbidden imports)
- Config from request pattern (rules in GameState)
- Dependency injection (constructor deps, method config)
- All 1090 tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 22:21:20 -06:00

492 lines
16 KiB
Python

"""WebSocket message schemas for Mantimon TCG real-time communication.

This module defines Pydantic models for all WebSocket message types, following
the discriminated union pattern from app/core/models/actions.py. Messages are
categorized into client-to-server and server-to-client types.

Message Design Principles:
    1. All messages have a 'type' discriminator field for automatic parsing
    2. All messages have a 'message_id' for idempotency and event ordering
    3. Server messages include a 'timestamp' for client-side latency tracking
    4. Game-scoped messages include 'game_id' for multi-game support

Example:
    # Parse incoming client message
    data = {"type": "join_game", "message_id": "abc-123", "game_id": "game-456"}
    message = parse_client_message(data)
    assert isinstance(message, JoinGameMessage)

    # Create server response
    response = GameStateMessage(
        message_id=str(uuid4()),
        game_id="game-456",
        state=visible_state,
    )
    await websocket.send_text(response.model_dump_json())
"""
from datetime import UTC, datetime
from enum import StrEnum
from typing import Annotated, Any, Literal
from uuid import uuid4

from pydantic import BaseModel, Field, TypeAdapter, field_validator

from app.core.enums import GameEndReason
from app.core.models.actions import Action
from app.core.visibility import VisibleGameState
class WSErrorCode(StrEnum):
"""Error codes for WebSocket error messages.
These codes help clients handle errors programmatically without
parsing error message text.
"""
# Connection errors
AUTHENTICATION_FAILED = "authentication_failed"
CONNECTION_CLOSED = "connection_closed"
RATE_LIMITED = "rate_limited"
# Game errors
GAME_NOT_FOUND = "game_not_found"
NOT_IN_GAME = "not_in_game"
ALREADY_IN_GAME = "already_in_game"
GAME_FULL = "game_full"
GAME_ENDED = "game_ended"
# Action errors
INVALID_ACTION = "invalid_action"
NOT_YOUR_TURN = "not_your_turn"
ACTION_NOT_ALLOWED = "action_not_allowed"
# Protocol errors
INVALID_MESSAGE = "invalid_message"
UNKNOWN_MESSAGE_TYPE = "unknown_message_type"
# Server errors
INTERNAL_ERROR = "internal_error"
class ConnectionStatus(StrEnum):
"""Connection status for opponent status messages."""
CONNECTED = "connected"
DISCONNECTED = "disconnected"
RECONNECTING = "reconnecting"
# =============================================================================
# Base Message Classes
# =============================================================================
def _generate_message_id() -> str:
"""Generate a unique message ID."""
return str(uuid4())
def _utc_now() -> datetime:
"""Get current UTC timestamp."""
return datetime.now(UTC)
class BaseClientMessage(BaseModel):
    """Common envelope shared by every client-to-server message.

    The envelope carries a single field, ``message_id``, intended to let
    duplicate deliveries be detected. Clients are expected to supply it;
    when it is omitted, a fresh UUID string is generated as the default.

    Attributes:
        message_id: Identifier used for idempotency and tracking. Must
            contain at least one non-whitespace character.
    """

    message_id: str = Field(
        description="Client-generated UUID for idempotency",
        default_factory=_generate_message_id,
    )

    @field_validator("message_id")
    @classmethod
    def validate_message_id(cls, v: str) -> str:
        """Reject identifiers that are empty or whitespace-only."""
        # An empty string strips to an empty string, so a single check on
        # the stripped value covers both the empty and the blank case.
        if v.strip():
            return v
        raise ValueError("message_id cannot be empty")
class BaseServerMessage(BaseModel):
    """Common envelope shared by every server-to-client message.

    Both fields are filled in automatically when not supplied: the ID is a
    new UUID string and the timestamp is the current UTC time at
    construction. The timestamp is meant for client-side latency
    calculation and event ordering.

    Attributes:
        message_id: Identifier for tracking this message.
        timestamp: Timezone-aware UTC creation time.
    """

    message_id: str = Field(
        description="Server-generated UUID for tracking",
        default_factory=_generate_message_id,
    )
    timestamp: datetime = Field(
        description="UTC timestamp when message was created",
        default_factory=_utc_now,
    )
# =============================================================================
# Client -> Server Messages
# =============================================================================
class JoinGameMessage(BaseClientMessage):
    """Client request to enter a game session, for the first time or again.

    A reconnecting client may pass ``last_event_id`` so that events it
    missed after that point can be replayed to it.

    Attributes:
        type: Discriminator, fixed to "join_game".
        game_id: Identifier of the game to join (required).
        last_event_id: ID of the last event the client received, or None
            when no replay is requested.
    """

    type: Literal["join_game"] = "join_game"
    game_id: str = Field(
        ...,
        description="ID of the game to join",
    )
    last_event_id: str | None = Field(
        description="Last event ID received (for reconnection replay)",
        default=None,
    )
class ActionMessage(BaseClientMessage):
    """Client request to execute a game action.

    This is an envelope around the ``Action`` union imported from
    app/core/models/actions.py: it adds the message ID inherited from the
    base class and the game the action targets.

    Attributes:
        type: Discriminator, fixed to "action".
        game_id: Identifier of the game the action applies to (required).
        action: The game action to execute (required).
    """

    type: Literal["action"] = "action"
    game_id: str = Field(
        ...,
        description="ID of the game",
    )
    action: Action = Field(
        ...,
        description="The game action to execute",
    )
class ResignMessage(BaseClientMessage):
    """Client request to resign from a game.

    Kept distinct from the game engine's ResignAction so that a
    resignation can be handled at the WebSocket layer, for example when
    the game engine is unavailable.

    Attributes:
        type: Discriminator, fixed to "resign".
        game_id: Identifier of the game being resigned (required).
    """

    type: Literal["resign"] = "resign"
    game_id: str = Field(
        ...,
        description="ID of the game to resign from",
    )
class HeartbeatMessage(BaseClientMessage):
    """Client keep-alive ping.

    Clients are expected to send this periodically (e.g., every 30
    seconds) so the connection does not time out; the intended reply is a
    HeartbeatAck. It carries nothing beyond the base envelope.

    Attributes:
        type: Discriminator, fixed to "heartbeat".
    """

    type: Literal["heartbeat"] = "heartbeat"
# Discriminated union of every client-to-server message; the concrete model
# is selected by the value of each message's "type" field.
ClientMessage = Annotated[
    (
        JoinGameMessage
        | ActionMessage
        | ResignMessage
        | HeartbeatMessage
    ),
    Field(discriminator="type"),
]
# =============================================================================
# Server -> Client Messages
# =============================================================================
class GameStateMessage(BaseServerMessage):
    """Full snapshot of the game as seen by one player.

    Intended for when a player joins a game or a complete resync is
    needed.

    Attributes:
        type: Discriminator, fixed to "game_state".
        game_id: Identifier of the game the snapshot belongs to (required).
        state: The complete visible game state (required).
        event_id: Identifier used for reconnection replay. Defaults to a
            freshly generated UUID string when not supplied.
    """

    type: Literal["game_state"] = "game_state"
    game_id: str = Field(
        ...,
        description="ID of the game",
    )
    state: VisibleGameState = Field(
        ...,
        description="Full visible game state",
    )
    event_id: str = Field(
        description="Event ID for reconnection replay",
        default_factory=_generate_message_id,
    )
class ActionResultMessage(BaseServerMessage):
    """Outcome of a previously submitted ActionMessage.

    Reports whether the action succeeded. A successful result may describe
    the resulting changes; a failed one may carry an error code and
    message. The model itself does not enforce that pairing.

    Attributes:
        type: Discriminator, fixed to "action_result".
        game_id: Identifier of the game the result belongs to (required).
        request_message_id: message_id of the ActionMessage being answered
            (required).
        success: True when the action succeeded (required).
        action_type: Name of the action type that was attempted (required).
        changes: State changes produced by the action; empty by default.
        error_code: Machine-readable failure code, or None.
        error_message: Human-readable failure text, or None.
    """

    type: Literal["action_result"] = "action_result"
    game_id: str = Field(
        ...,
        description="ID of the game",
    )
    request_message_id: str = Field(
        ...,
        description="message_id of the original action request",
    )
    success: bool = Field(
        ...,
        description="Whether the action succeeded",
    )
    action_type: str = Field(
        ...,
        description="Type of action that was attempted",
    )
    changes: dict[str, Any] = Field(
        description="State changes resulting from the action",
        default_factory=dict,
    )
    error_code: WSErrorCode | None = Field(
        description="Error code if action failed",
        default=None,
    )
    error_message: str | None = Field(
        description="Human-readable error message",
        default=None,
    )
class ErrorMessage(BaseServerMessage):
    """Error report that is not tied to a specific game action.

    Meant for protocol and connection problems such as a malformed
    message, an authentication failure, or a server-side error.

    Attributes:
        type: Discriminator, fixed to "error".
        code: Machine-readable error code (required).
        message: Human-readable description (required).
        details: Extra context about the error; empty by default.
        request_message_id: message_id of the message that triggered the
            error, or None when not applicable.
    """

    type: Literal["error"] = "error"
    code: WSErrorCode = Field(
        ...,
        description="Error code",
    )
    message: str = Field(
        ...,
        description="Human-readable error description",
    )
    details: dict[str, Any] = Field(
        description="Additional error context",
        default_factory=dict,
    )
    request_message_id: str | None = Field(
        description="ID of message that caused error",
        default=None,
    )
class TurnStartMessage(BaseServerMessage):
    """Announcement that a new turn has begun.

    Intended for both players; it names the player whose turn it is and
    the turn number.

    Attributes:
        type: Discriminator, fixed to "turn_start".
        game_id: Identifier of the game (required).
        player_id: Player whose turn is starting (required).
        turn_number: Current turn number; must be at least 1 (required).
        event_id: Identifier used for reconnection replay. Defaults to a
            freshly generated UUID string when not supplied.
    """

    type: Literal["turn_start"] = "turn_start"
    game_id: str = Field(
        ...,
        description="ID of the game",
    )
    player_id: str = Field(
        ...,
        description="Player whose turn is starting",
    )
    turn_number: int = Field(
        ...,
        ge=1,
        description="Current turn number",
    )
    event_id: str = Field(
        description="Event ID for reconnection replay",
        default_factory=_generate_message_id,
    )
class TurnTimeoutMessage(BaseServerMessage):
    """Turn-timer notice: either an advance warning or the expiry itself.

    ``is_warning`` distinguishes the two cases. A warning is meant to give
    the player time to finish their action before the turn times out.

    Attributes:
        type: Discriminator, fixed to "turn_timeout".
        game_id: Identifier of the game (required).
        remaining_seconds: Seconds left before timeout; must be zero or
            greater (required).
        is_warning: True for a warning, False once the timeout has
            occurred (required).
        player_id: Player whose turn is timing out (required).
    """

    type: Literal["turn_timeout"] = "turn_timeout"
    game_id: str = Field(
        ...,
        description="ID of the game",
    )
    remaining_seconds: int = Field(
        ...,
        ge=0,
        description="Seconds remaining before timeout",
    )
    is_warning: bool = Field(
        ...,
        description="True if warning, False if timeout occurred",
    )
    player_id: str = Field(
        ...,
        description="Player whose turn is timing out",
    )
class GameOverMessage(BaseServerMessage):
    """Announcement that a game has finished.

    Intended for all players, whatever the reason the game ended.

    Attributes:
        type: Discriminator, fixed to "game_over".
        game_id: Identifier of the finished game (required).
        winner_id: Winning player's ID, or None for a draw.
        end_reason: Why the game ended (required).
        final_state: The final visible game state (required).
        event_id: Identifier used for reconnection replay. Defaults to a
            freshly generated UUID string when not supplied.
    """

    type: Literal["game_over"] = "game_over"
    game_id: str = Field(
        ...,
        description="ID of the game that ended",
    )
    winner_id: str | None = Field(
        description="Winner player ID, or None for draw",
        default=None,
    )
    end_reason: GameEndReason = Field(
        ...,
        description="Reason the game ended",
    )
    final_state: VisibleGameState = Field(
        ...,
        description="Final visible game state",
    )
    event_id: str = Field(
        description="Event ID for reconnection replay",
        default_factory=_generate_message_id,
    )
class OpponentStatusMessage(BaseServerMessage):
    """Notice that the opponent's connection status has changed.

    Lets the UI show whether the opponent is connected, disconnected, or
    reconnecting.

    Attributes:
        type: Discriminator, fixed to "opponent_status".
        game_id: Identifier of the game (required).
        opponent_id: The opponent's player ID (required).
        status: The opponent's current connection status (required).
    """

    type: Literal["opponent_status"] = "opponent_status"
    game_id: str = Field(
        ...,
        description="ID of the game",
    )
    opponent_id: str = Field(
        ...,
        description="Opponent's player ID",
    )
    status: ConnectionStatus = Field(
        ...,
        description="Opponent's connection status",
    )
class HeartbeatAckMessage(BaseServerMessage):
    """Server reply to a client HeartbeatMessage.

    Confirms the connection is alive. It carries nothing beyond the base
    server envelope.

    Attributes:
        type: Discriminator, fixed to "heartbeat_ack".
    """

    type: Literal["heartbeat_ack"] = "heartbeat_ack"
# Discriminated union of every server-to-client message; the concrete model
# is selected by the value of each message's "type" field.
ServerMessage = Annotated[
    (
        GameStateMessage
        | ActionResultMessage
        | ErrorMessage
        | TurnStartMessage
        | TurnTimeoutMessage
        | GameOverMessage
        | OpponentStatusMessage
        | HeartbeatAckMessage
    ),
    Field(discriminator="type"),
]
# =============================================================================
# Parsing Functions
# =============================================================================
# Built once at import time. Constructing a TypeAdapter builds the validator
# for the whole discriminated union, which is too costly to repeat for every
# incoming message.
_CLIENT_MESSAGE_ADAPTER = TypeAdapter(ClientMessage)


def parse_client_message(data: dict[str, Any]) -> ClientMessage:
    """Parse an incoming client message from a dictionary.

    The 'type' field selects which message model is used to validate the
    rest of the payload.

    Args:
        data: Dictionary containing message data with a 'type' field.

    Returns:
        The matching ClientMessage subtype (JoinGameMessage, ActionMessage,
        ResignMessage, or HeartbeatMessage).

    Raises:
        pydantic.ValidationError: If 'type' is missing or unknown, or if the
            data doesn't match the selected message schema. In Pydantic v2
            this is a ValueError subclass, so callers that catch ValueError
            keep working.

    Example:
        data = {"type": "join_game", "message_id": "abc", "game_id": "123"}
        message = parse_client_message(data)
        assert isinstance(message, JoinGameMessage)
    """
    return _CLIENT_MESSAGE_ADAPTER.validate_python(data)
# Built once at import time. Constructing a TypeAdapter builds the validator
# for the whole discriminated union, so it is shared by every call instead
# of being rebuilt each time.
_SERVER_MESSAGE_ADAPTER = TypeAdapter(ServerMessage)


def parse_server_message(data: dict[str, Any]) -> ServerMessage:
    """Parse a server message from a dictionary.

    Primarily used for testing - clients receive JSON and parse directly.
    The 'type' field selects which message model is used to validate the
    rest of the payload.

    Args:
        data: Dictionary containing message data with a 'type' field.

    Returns:
        The matching ServerMessage subtype.

    Raises:
        pydantic.ValidationError: If 'type' is missing or unknown, or if the
            data doesn't match the selected message schema. In Pydantic v2
            this is a ValueError subclass, so callers that catch ValueError
            keep working.
    """
    return _SERVER_MESSAGE_ADAPTER.validate_python(data)
# Public API of this module: the names exported by `from ... import *`.
# Kept as a plain literal list, grouped by category.
__all__ = [
    # Enums
    "ConnectionStatus",
    "WSErrorCode",
    # Base classes
    "BaseClientMessage",
    "BaseServerMessage",
    # Client messages (models plus the ClientMessage union alias)
    "ActionMessage",
    "ClientMessage",
    "HeartbeatMessage",
    "JoinGameMessage",
    "ResignMessage",
    # Server messages (models plus the ServerMessage union alias)
    "ActionResultMessage",
    "ErrorMessage",
    "GameOverMessage",
    "GameStateMessage",
    "HeartbeatAckMessage",
    "OpponentStatusMessage",
    "ServerMessage",
    "TurnStartMessage",
    "TurnTimeoutMessage",
    # Parsing functions
    "parse_client_message",
    "parse_server_message",
]