Add Phase 4 WebSocket infrastructure (WS-001 through GS-001)
WebSocket Message Schemas (WS-002): - Add Pydantic models for all client/server WebSocket messages - Implement discriminated unions for message type parsing - Include JoinGame, Action, Resign, Heartbeat client messages - Include GameState, ActionResult, Error, TurnStart server messages Connection Manager (WS-003): - Add Redis-backed WebSocket connection tracking - Implement user-to-sid mapping with TTL management - Support game room association and opponent lookup - Add heartbeat tracking for connection health Socket.IO Authentication (WS-004): - Add JWT-based authentication middleware - Support token extraction from multiple formats - Implement session setup with ConnectionManager integration - Add require_auth helper for event handlers Socket.IO Server Setup (WS-001): - Configure AsyncServer with ASGI mode - Register /game namespace with event handlers - Integrate with FastAPI via ASGIApp wrapper - Configure CORS from application settings Game Service (GS-001): - Add stateless GameService for game lifecycle orchestration - Create engine per-operation using rules from GameState - Implement action-based RNG seeding for deterministic replay - Add rng_seed field to GameState for replay support Architecture verified: - Core module independence (no forbidden imports) - Config from request pattern (rules in GameState) - Dependency injection (constructor deps, method config) - All 1090 tests passing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
c00ee87f25
commit
0c810e5b30
@ -379,6 +379,7 @@ class GameState(BaseModel):
|
|||||||
forced_actions: Queue of ForcedAction items that must be completed before game proceeds.
|
forced_actions: Queue of ForcedAction items that must be completed before game proceeds.
|
||||||
Actions are processed in FIFO order (first added = first to resolve).
|
Actions are processed in FIFO order (first added = first to resolve).
|
||||||
action_log: Log of actions taken (for replays/debugging).
|
action_log: Log of actions taken (for replays/debugging).
|
||||||
|
rng_seed: Optional seed for deterministic RNG. When set, enables replay capability.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
game_id: str
|
game_id: str
|
||||||
@ -412,6 +413,9 @@ class GameState(BaseModel):
|
|||||||
# Optional action log for replays
|
# Optional action log for replays
|
||||||
action_log: list[dict[str, Any]] = Field(default_factory=list)
|
action_log: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
# Optional RNG seed for deterministic replays
|
||||||
|
rng_seed: int | None = None
|
||||||
|
|
||||||
def get_current_player(self) -> PlayerState:
|
def get_current_player(self) -> PlayerState:
|
||||||
"""Get the PlayerState for the current player.
|
"""Get the PlayerState for the current player.
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,7 @@ from app.config import settings
|
|||||||
from app.db import close_db, init_db
|
from app.db import close_db, init_db
|
||||||
from app.db.redis import close_redis, init_redis
|
from app.db.redis import close_redis, init_redis
|
||||||
from app.services import get_card_service
|
from app.services import get_card_service
|
||||||
|
from app.socketio import create_socketio_app
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -173,3 +174,16 @@ app.include_router(decks_router, prefix="/api")
|
|||||||
# app.include_router(cards.router, prefix="/api/cards", tags=["cards"])
|
# app.include_router(cards.router, prefix="/api/cards", tags=["cards"])
|
||||||
# app.include_router(games.router, prefix="/api/games", tags=["games"])
|
# app.include_router(games.router, prefix="/api/games", tags=["games"])
|
||||||
# app.include_router(campaign.router, prefix="/api/campaign", tags=["campaign"])
|
# app.include_router(campaign.router, prefix="/api/campaign", tags=["campaign"])
|
||||||
|
|
||||||
|
|
||||||
|
# === Socket.IO Integration ===
|
||||||
|
# Wrap FastAPI with Socket.IO ASGI app for WebSocket support.
|
||||||
|
# The combined app handles:
|
||||||
|
# - WebSocket connections at /socket.io/
|
||||||
|
# - HTTP requests passed through to FastAPI
|
||||||
|
#
|
||||||
|
# Note: The module-level 'app' variable is now the combined ASGI app.
|
||||||
|
# This is intentional - uvicorn imports 'app' and will serve both
|
||||||
|
# FastAPI HTTP routes and Socket.IO WebSocket connections.
|
||||||
|
_fastapi_app = app
|
||||||
|
app = create_socketio_app(_fastapi_app)
|
||||||
|
|||||||
491
backend/app/schemas/ws_messages.py
Normal file
491
backend/app/schemas/ws_messages.py
Normal file
@ -0,0 +1,491 @@
|
|||||||
|
"""WebSocket message schemas for Mantimon TCG real-time communication.
|
||||||
|
|
||||||
|
This module defines Pydantic models for all WebSocket message types, following
|
||||||
|
the discriminated union pattern from app/core/models/actions.py. Messages are
|
||||||
|
categorized into client-to-server and server-to-client types.
|
||||||
|
|
||||||
|
Message Design Principles:
|
||||||
|
1. All messages have a 'type' discriminator field for automatic parsing
|
||||||
|
2. All messages have a 'message_id' for idempotency and event ordering
|
||||||
|
3. Server messages include a 'timestamp' for client-side latency tracking
|
||||||
|
4. Game-scoped messages include 'game_id' for multi-game support
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Parse incoming client message
|
||||||
|
data = {"type": "join_game", "message_id": "abc-123", "game_id": "game-456"}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
assert isinstance(message, JoinGameMessage)
|
||||||
|
|
||||||
|
# Create server response
|
||||||
|
response = GameStateMessage(
|
||||||
|
message_id=str(uuid4()),
|
||||||
|
game_id="game-456",
|
||||||
|
state=visible_state,
|
||||||
|
)
|
||||||
|
await websocket.send_text(response.model_dump_json())
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Annotated, Any, Literal
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from app.core.enums import GameEndReason
|
||||||
|
from app.core.models.actions import Action
|
||||||
|
from app.core.visibility import VisibleGameState
|
||||||
|
|
||||||
|
|
||||||
|
class WSErrorCode(StrEnum):
|
||||||
|
"""Error codes for WebSocket error messages.
|
||||||
|
|
||||||
|
These codes help clients handle errors programmatically without
|
||||||
|
parsing error message text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Connection errors
|
||||||
|
AUTHENTICATION_FAILED = "authentication_failed"
|
||||||
|
CONNECTION_CLOSED = "connection_closed"
|
||||||
|
RATE_LIMITED = "rate_limited"
|
||||||
|
|
||||||
|
# Game errors
|
||||||
|
GAME_NOT_FOUND = "game_not_found"
|
||||||
|
NOT_IN_GAME = "not_in_game"
|
||||||
|
ALREADY_IN_GAME = "already_in_game"
|
||||||
|
GAME_FULL = "game_full"
|
||||||
|
GAME_ENDED = "game_ended"
|
||||||
|
|
||||||
|
# Action errors
|
||||||
|
INVALID_ACTION = "invalid_action"
|
||||||
|
NOT_YOUR_TURN = "not_your_turn"
|
||||||
|
ACTION_NOT_ALLOWED = "action_not_allowed"
|
||||||
|
|
||||||
|
# Protocol errors
|
||||||
|
INVALID_MESSAGE = "invalid_message"
|
||||||
|
UNKNOWN_MESSAGE_TYPE = "unknown_message_type"
|
||||||
|
|
||||||
|
# Server errors
|
||||||
|
INTERNAL_ERROR = "internal_error"
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionStatus(StrEnum):
|
||||||
|
"""Connection status for opponent status messages."""
|
||||||
|
|
||||||
|
CONNECTED = "connected"
|
||||||
|
DISCONNECTED = "disconnected"
|
||||||
|
RECONNECTING = "reconnecting"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Base Message Classes
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_message_id() -> str:
|
||||||
|
"""Generate a unique message ID."""
|
||||||
|
return str(uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _utc_now() -> datetime:
|
||||||
|
"""Get current UTC timestamp."""
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseClientMessage(BaseModel):
|
||||||
|
"""Base class for all client-to-server messages.
|
||||||
|
|
||||||
|
All client messages must include a message_id for idempotency. The client
|
||||||
|
generates this ID to allow detecting and handling duplicate messages.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
message_id: Client-generated UUID for idempotency and tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
message_id: str = Field(
|
||||||
|
default_factory=_generate_message_id,
|
||||||
|
description="Client-generated UUID for idempotency",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_message_id(cls, v: str) -> str:
|
||||||
|
"""Ensure message_id is not empty."""
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("message_id cannot be empty")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class BaseServerMessage(BaseModel):
|
||||||
|
"""Base class for all server-to-client messages.
|
||||||
|
|
||||||
|
All server messages include a message_id and timestamp. The timestamp
|
||||||
|
enables client-side latency calculation and event ordering.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
message_id: Server-generated UUID for tracking.
|
||||||
|
timestamp: UTC timestamp when the message was created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
message_id: str = Field(
|
||||||
|
default_factory=_generate_message_id,
|
||||||
|
description="Server-generated UUID for tracking",
|
||||||
|
)
|
||||||
|
timestamp: datetime = Field(
|
||||||
|
default_factory=_utc_now,
|
||||||
|
description="UTC timestamp when message was created",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Client -> Server Messages
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class JoinGameMessage(BaseClientMessage):
|
||||||
|
"""Request to join or rejoin a game session.
|
||||||
|
|
||||||
|
When rejoining after a disconnect, the client can provide last_event_id
|
||||||
|
to receive any missed events since that point.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "join_game".
|
||||||
|
game_id: The game to join.
|
||||||
|
last_event_id: For reconnection - ID of last received event for replay.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["join_game"] = "join_game"
|
||||||
|
game_id: str = Field(..., description="ID of the game to join")
|
||||||
|
last_event_id: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Last event ID received (for reconnection replay)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionMessage(BaseClientMessage):
|
||||||
|
"""Submit a game action for execution.
|
||||||
|
|
||||||
|
Wraps the Action union type from app/core/models/actions.py to provide
|
||||||
|
consistent message envelope with message_id and game context.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "action".
|
||||||
|
game_id: The game this action is for.
|
||||||
|
action: The game action to execute.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["action"] = "action"
|
||||||
|
game_id: str = Field(..., description="ID of the game")
|
||||||
|
action: Action = Field(..., description="The game action to execute")
|
||||||
|
|
||||||
|
|
||||||
|
class ResignMessage(BaseClientMessage):
|
||||||
|
"""Request to resign from a game.
|
||||||
|
|
||||||
|
This is separate from ResignAction in the game engine to allow
|
||||||
|
resignation handling at the WebSocket layer (e.g., when the game
|
||||||
|
engine is unavailable).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "resign".
|
||||||
|
game_id: The game to resign from.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["resign"] = "resign"
|
||||||
|
game_id: str = Field(..., description="ID of the game to resign from")
|
||||||
|
|
||||||
|
|
||||||
|
class HeartbeatMessage(BaseClientMessage):
|
||||||
|
"""Keep-alive message to maintain connection.
|
||||||
|
|
||||||
|
Clients should send heartbeats periodically (e.g., every 30 seconds)
|
||||||
|
to prevent connection timeout. The server responds with a HeartbeatAck.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "heartbeat".
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["heartbeat"] = "heartbeat"
|
||||||
|
|
||||||
|
|
||||||
|
# Union type for client messages
|
||||||
|
ClientMessage = Annotated[
|
||||||
|
JoinGameMessage | ActionMessage | ResignMessage | HeartbeatMessage,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Server -> Client Messages
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class GameStateMessage(BaseServerMessage):
|
||||||
|
"""Full game state update.
|
||||||
|
|
||||||
|
Sent when a player joins a game or when a full state sync is needed.
|
||||||
|
Contains the complete visible game state from that player's perspective.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "game_state".
|
||||||
|
game_id: The game this state is for.
|
||||||
|
state: The full visible game state.
|
||||||
|
event_id: Monotonic event ID for reconnection replay.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["game_state"] = "game_state"
|
||||||
|
game_id: str = Field(..., description="ID of the game")
|
||||||
|
state: VisibleGameState = Field(..., description="Full visible game state")
|
||||||
|
event_id: str = Field(
|
||||||
|
default_factory=_generate_message_id,
|
||||||
|
description="Event ID for reconnection replay",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionResultMessage(BaseServerMessage):
|
||||||
|
"""Result of a player action.
|
||||||
|
|
||||||
|
Sent after processing an ActionMessage to confirm success or failure.
|
||||||
|
On success, includes changes that resulted from the action.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "action_result".
|
||||||
|
game_id: The game this result is for.
|
||||||
|
request_message_id: The message_id of the original ActionMessage.
|
||||||
|
success: Whether the action succeeded.
|
||||||
|
action_type: The type of action that was attempted.
|
||||||
|
changes: Description of state changes (on success).
|
||||||
|
error_code: Error code if action failed.
|
||||||
|
error_message: Human-readable error message if failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["action_result"] = "action_result"
|
||||||
|
game_id: str = Field(..., description="ID of the game")
|
||||||
|
request_message_id: str = Field(..., description="message_id of the original action request")
|
||||||
|
success: bool = Field(..., description="Whether the action succeeded")
|
||||||
|
action_type: str = Field(..., description="Type of action that was attempted")
|
||||||
|
changes: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="State changes resulting from the action",
|
||||||
|
)
|
||||||
|
error_code: WSErrorCode | None = Field(default=None, description="Error code if action failed")
|
||||||
|
error_message: str | None = Field(default=None, description="Human-readable error message")
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorMessage(BaseServerMessage):
|
||||||
|
"""Error notification for protocol or connection errors.
|
||||||
|
|
||||||
|
Used for errors not associated with a specific action, such as
|
||||||
|
invalid message format, authentication failures, or server errors.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "error".
|
||||||
|
code: Machine-readable error code.
|
||||||
|
message: Human-readable error description.
|
||||||
|
details: Additional error context.
|
||||||
|
request_message_id: ID of the message that caused the error, if applicable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["error"] = "error"
|
||||||
|
code: WSErrorCode = Field(..., description="Error code")
|
||||||
|
message: str = Field(..., description="Human-readable error description")
|
||||||
|
details: dict[str, Any] = Field(default_factory=dict, description="Additional error context")
|
||||||
|
request_message_id: str | None = Field(
|
||||||
|
default=None, description="ID of message that caused error"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TurnStartMessage(BaseServerMessage):
|
||||||
|
"""Notification that a new turn has started.
|
||||||
|
|
||||||
|
Sent to both players when a turn begins. Indicates whose turn it is
|
||||||
|
and the current turn number.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "turn_start".
|
||||||
|
game_id: The game this notification is for.
|
||||||
|
player_id: The player whose turn is starting.
|
||||||
|
turn_number: The current turn number.
|
||||||
|
event_id: Event ID for reconnection replay.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["turn_start"] = "turn_start"
|
||||||
|
game_id: str = Field(..., description="ID of the game")
|
||||||
|
player_id: str = Field(..., description="Player whose turn is starting")
|
||||||
|
turn_number: int = Field(..., description="Current turn number", ge=1)
|
||||||
|
event_id: str = Field(
|
||||||
|
default_factory=_generate_message_id,
|
||||||
|
description="Event ID for reconnection replay",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TurnTimeoutMessage(BaseServerMessage):
|
||||||
|
"""Timeout warning or expiration notification.
|
||||||
|
|
||||||
|
Sent when a player's turn is approaching timeout (warning) or has
|
||||||
|
expired. Warnings give players time to complete their action.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "turn_timeout".
|
||||||
|
game_id: The game this notification is for.
|
||||||
|
remaining_seconds: Seconds remaining before timeout.
|
||||||
|
is_warning: True if this is a warning, False if timeout has occurred.
|
||||||
|
player_id: The player whose turn is timing out.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["turn_timeout"] = "turn_timeout"
|
||||||
|
game_id: str = Field(..., description="ID of the game")
|
||||||
|
remaining_seconds: int = Field(..., description="Seconds remaining before timeout", ge=0)
|
||||||
|
is_warning: bool = Field(..., description="True if warning, False if timeout occurred")
|
||||||
|
player_id: str = Field(..., description="Player whose turn is timing out")
|
||||||
|
|
||||||
|
|
||||||
|
class GameOverMessage(BaseServerMessage):
|
||||||
|
"""Notification that the game has ended.
|
||||||
|
|
||||||
|
Sent to all players when the game concludes, regardless of the reason.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "game_over".
|
||||||
|
game_id: The game that ended.
|
||||||
|
winner_id: The winning player, or None for a draw.
|
||||||
|
end_reason: Why the game ended.
|
||||||
|
final_state: The final visible game state.
|
||||||
|
event_id: Event ID for reconnection replay.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["game_over"] = "game_over"
|
||||||
|
game_id: str = Field(..., description="ID of the game that ended")
|
||||||
|
winner_id: str | None = Field(default=None, description="Winner player ID, or None for draw")
|
||||||
|
end_reason: GameEndReason = Field(..., description="Reason the game ended")
|
||||||
|
final_state: VisibleGameState = Field(..., description="Final visible game state")
|
||||||
|
event_id: str = Field(
|
||||||
|
default_factory=_generate_message_id,
|
||||||
|
description="Event ID for reconnection replay",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpponentStatusMessage(BaseServerMessage):
|
||||||
|
"""Notification of opponent connection status change.
|
||||||
|
|
||||||
|
Sent when the opponent connects, disconnects, or is reconnecting.
|
||||||
|
Allows the UI to show connection status to the player.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "opponent_status".
|
||||||
|
game_id: The game this status is for.
|
||||||
|
opponent_id: The opponent's player ID.
|
||||||
|
status: The opponent's current connection status.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["opponent_status"] = "opponent_status"
|
||||||
|
game_id: str = Field(..., description="ID of the game")
|
||||||
|
opponent_id: str = Field(..., description="Opponent's player ID")
|
||||||
|
status: ConnectionStatus = Field(..., description="Opponent's connection status")
|
||||||
|
|
||||||
|
|
||||||
|
class HeartbeatAckMessage(BaseServerMessage):
|
||||||
|
"""Acknowledgment of a client heartbeat.
|
||||||
|
|
||||||
|
Sent in response to HeartbeatMessage to confirm the connection is alive.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Discriminator field, always "heartbeat_ack".
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["heartbeat_ack"] = "heartbeat_ack"
|
||||||
|
|
||||||
|
|
||||||
|
# Union type for server messages
|
||||||
|
ServerMessage = Annotated[
|
||||||
|
GameStateMessage
|
||||||
|
| ActionResultMessage
|
||||||
|
| ErrorMessage
|
||||||
|
| TurnStartMessage
|
||||||
|
| TurnTimeoutMessage
|
||||||
|
| GameOverMessage
|
||||||
|
| OpponentStatusMessage
|
||||||
|
| HeartbeatAckMessage,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Parsing Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def parse_client_message(data: dict[str, Any]) -> ClientMessage:
|
||||||
|
"""Parse an incoming client message from a dictionary.
|
||||||
|
|
||||||
|
Uses the 'type' field to determine which message model to use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary containing message data with a 'type' field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The appropriate ClientMessage subtype.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the message type is unknown.
|
||||||
|
ValidationError: If the data doesn't match the message schema.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
data = {"type": "join_game", "message_id": "abc", "game_id": "123"}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
assert isinstance(message, JoinGameMessage)
|
||||||
|
"""
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
adapter: TypeAdapter[ClientMessage] = TypeAdapter(ClientMessage)
|
||||||
|
return adapter.validate_python(data)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_server_message(data: dict[str, Any]) -> ServerMessage:
|
||||||
|
"""Parse a server message from a dictionary.
|
||||||
|
|
||||||
|
Primarily used for testing - clients receive JSON and parse directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary containing message data with a 'type' field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The appropriate ServerMessage subtype.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the message type is unknown.
|
||||||
|
ValidationError: If the data doesn't match the message schema.
|
||||||
|
"""
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
adapter: TypeAdapter[ServerMessage] = TypeAdapter(ServerMessage)
|
||||||
|
return adapter.validate_python(data)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Enums
|
||||||
|
"ConnectionStatus",
|
||||||
|
"WSErrorCode",
|
||||||
|
# Base classes
|
||||||
|
"BaseClientMessage",
|
||||||
|
"BaseServerMessage",
|
||||||
|
# Client messages
|
||||||
|
"ActionMessage",
|
||||||
|
"ClientMessage",
|
||||||
|
"HeartbeatMessage",
|
||||||
|
"JoinGameMessage",
|
||||||
|
"ResignMessage",
|
||||||
|
# Server messages
|
||||||
|
"ActionResultMessage",
|
||||||
|
"ErrorMessage",
|
||||||
|
"GameOverMessage",
|
||||||
|
"GameStateMessage",
|
||||||
|
"HeartbeatAckMessage",
|
||||||
|
"OpponentStatusMessage",
|
||||||
|
"ServerMessage",
|
||||||
|
"TurnStartMessage",
|
||||||
|
"TurnTimeoutMessage",
|
||||||
|
# Parsing functions
|
||||||
|
"parse_client_message",
|
||||||
|
"parse_server_message",
|
||||||
|
]
|
||||||
578
backend/app/services/connection_manager.py
Normal file
578
backend/app/services/connection_manager.py
Normal file
@ -0,0 +1,578 @@
|
|||||||
|
"""WebSocket connection management with Redis-backed session tracking.
|
||||||
|
|
||||||
|
This module manages WebSocket connection state using Redis for persistence
|
||||||
|
and cross-instance coordination. It tracks:
|
||||||
|
- Individual connection sessions (sid -> user_id, game_id, timestamps)
|
||||||
|
- User to connection mapping (user_id -> sid)
|
||||||
|
- Game participants (game_id -> set of sids)
|
||||||
|
|
||||||
|
Key Patterns:
|
||||||
|
conn:{sid} - Hash with connection details (user_id, game_id, connected_at, last_seen)
|
||||||
|
user_conn:{user_id} - String with active sid
|
||||||
|
game_conns:{game_id} - Set of sids for game participants
|
||||||
|
|
||||||
|
Lifecycle:
|
||||||
|
1. on_connect: Register connection, map user to sid
|
||||||
|
2. on_join_game: Add sid to game's connection set
|
||||||
|
3. on_heartbeat: Update last_seen timestamp
|
||||||
|
4. on_disconnect: Remove from all tracking structures
|
||||||
|
|
||||||
|
Example:
|
||||||
|
manager = ConnectionManager()
|
||||||
|
|
||||||
|
# On WebSocket connect
|
||||||
|
await manager.register_connection(sid, user_id)
|
||||||
|
|
||||||
|
# On joining a game
|
||||||
|
await manager.join_game(sid, game_id)
|
||||||
|
|
||||||
|
# Check if opponent is online
|
||||||
|
if await manager.is_user_online(opponent_id):
|
||||||
|
# Send real-time update
|
||||||
|
|
||||||
|
# On disconnect
|
||||||
|
await manager.unregister_connection(sid)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from app.db.redis import get_redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Redis key patterns
|
||||||
|
CONN_PREFIX = "conn:"
|
||||||
|
USER_CONN_PREFIX = "user_conn:"
|
||||||
|
GAME_CONNS_PREFIX = "game_conns:"
|
||||||
|
|
||||||
|
# Connection TTL (auto-expire stale connections)
|
||||||
|
DEFAULT_CONN_TTL_SECONDS = 3600 # 1 hour
|
||||||
|
HEARTBEAT_INTERVAL_SECONDS = 30
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConnectionInfo:
|
||||||
|
"""Information about a WebSocket connection.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sid: Socket.IO session ID.
|
||||||
|
user_id: Authenticated user's UUID.
|
||||||
|
game_id: Current game ID if in a game, None otherwise.
|
||||||
|
connected_at: When the connection was established.
|
||||||
|
last_seen: Last heartbeat or activity timestamp.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sid: str
|
||||||
|
user_id: str
|
||||||
|
game_id: str | None
|
||||||
|
connected_at: datetime
|
||||||
|
last_seen: datetime
|
||||||
|
|
||||||
|
def is_stale(self, threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3) -> bool:
|
||||||
|
"""Check if connection is stale (no recent activity).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
threshold_seconds: Seconds since last_seen to consider stale.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection hasn't been seen recently.
|
||||||
|
"""
|
||||||
|
elapsed = (datetime.now(UTC) - self.last_seen).total_seconds()
|
||||||
|
return elapsed > threshold_seconds
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionManager:
|
||||||
|
"""Manages WebSocket connections with Redis-backed session tracking.
|
||||||
|
|
||||||
|
Provides methods to:
|
||||||
|
- Register/unregister connections
|
||||||
|
- Track which game a connection is participating in
|
||||||
|
- Check user online status
|
||||||
|
- Get all connections for a game
|
||||||
|
- Handle heartbeats and stale connection detection
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
conn_ttl_seconds: TTL for connection records in Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS) -> None:
|
||||||
|
"""Initialize the ConnectionManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conn_ttl_seconds: How long to keep connection records in Redis.
|
||||||
|
"""
|
||||||
|
self.conn_ttl_seconds = conn_ttl_seconds
|
||||||
|
|
||||||
|
def _conn_key(self, sid: str) -> str:
|
||||||
|
"""Generate Redis key for a connection."""
|
||||||
|
return f"{CONN_PREFIX}{sid}"
|
||||||
|
|
||||||
|
def _user_conn_key(self, user_id: str) -> str:
|
||||||
|
"""Generate Redis key for user-to-connection mapping."""
|
||||||
|
return f"{USER_CONN_PREFIX}{user_id}"
|
||||||
|
|
||||||
|
def _game_conns_key(self, game_id: str) -> str:
|
||||||
|
"""Generate Redis key for game connection set."""
|
||||||
|
return f"{GAME_CONNS_PREFIX}{game_id}"
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Connection Lifecycle
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def register_connection(
|
||||||
|
self,
|
||||||
|
sid: str,
|
||||||
|
user_id: str | UUID,
|
||||||
|
) -> None:
|
||||||
|
"""Register a new WebSocket connection.
|
||||||
|
|
||||||
|
Creates connection record in Redis and maps user to this connection.
|
||||||
|
If user had a previous connection, it will be replaced.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket.IO session ID.
|
||||||
|
user_id: Authenticated user's ID (UUID or string).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
await manager.register_connection("abc123", user.id)
|
||||||
|
"""
|
||||||
|
user_id_str = str(user_id)
|
||||||
|
now = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
|
async with get_redis() as redis:
|
||||||
|
# Check for existing connection and clean it up
|
||||||
|
old_sid = await redis.get(self._user_conn_key(user_id_str))
|
||||||
|
if old_sid and old_sid != sid:
|
||||||
|
logger.info(f"Replacing old connection {old_sid} for user {user_id_str}")
|
||||||
|
await self._cleanup_connection(old_sid)
|
||||||
|
|
||||||
|
# Create connection record (hash)
|
||||||
|
conn_key = self._conn_key(sid)
|
||||||
|
await redis.hset(
|
||||||
|
conn_key,
|
||||||
|
mapping={
|
||||||
|
"user_id": user_id_str,
|
||||||
|
"game_id": "",
|
||||||
|
"connected_at": now,
|
||||||
|
"last_seen": now,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await redis.expire(conn_key, self.conn_ttl_seconds)
|
||||||
|
|
||||||
|
# Map user to this connection
|
||||||
|
user_conn_key = self._user_conn_key(user_id_str)
|
||||||
|
await redis.set(user_conn_key, sid)
|
||||||
|
await redis.expire(user_conn_key, self.conn_ttl_seconds)
|
||||||
|
|
||||||
|
logger.debug(f"Registered connection: sid={sid}, user_id={user_id_str}")
|
||||||
|
|
||||||
|
async def unregister_connection(self, sid: str) -> ConnectionInfo | None:
|
||||||
|
"""Unregister a WebSocket connection and clean up all related data.
|
||||||
|
|
||||||
|
Removes connection from:
|
||||||
|
- Connection record
|
||||||
|
- User-to-connection mapping
|
||||||
|
- Any game connection sets
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket.IO session ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConnectionInfo of the removed connection, or None if not found.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
info = await manager.unregister_connection("abc123")
|
||||||
|
if info and info.game_id:
|
||||||
|
# Notify opponent of disconnect
|
||||||
|
"""
|
||||||
|
# Get connection info before cleanup
|
||||||
|
info = await self.get_connection(sid)
|
||||||
|
if info is None:
|
||||||
|
logger.debug(f"Connection not found for unregister: {sid}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
await self._cleanup_connection(sid)
|
||||||
|
logger.debug(f"Unregistered connection: sid={sid}, user_id={info.user_id}")
|
||||||
|
return info
|
||||||
|
|
||||||
|
async def _cleanup_connection(self, sid: str) -> None:
|
||||||
|
"""Internal cleanup of a connection's Redis data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket.IO session ID.
|
||||||
|
"""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
conn_key = self._conn_key(sid)
|
||||||
|
|
||||||
|
# Get connection data for cleanup
|
||||||
|
data = await redis.hgetall(conn_key)
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
user_id = data.get("user_id", "")
|
||||||
|
game_id = data.get("game_id", "")
|
||||||
|
|
||||||
|
# Remove from game connection set if in a game
|
||||||
|
if game_id:
|
||||||
|
game_conns_key = self._game_conns_key(game_id)
|
||||||
|
await redis.srem(game_conns_key, sid)
|
||||||
|
|
||||||
|
# Remove user-to-connection mapping if it points to this sid
|
||||||
|
if user_id:
|
||||||
|
user_conn_key = self._user_conn_key(user_id)
|
||||||
|
current_sid = await redis.get(user_conn_key)
|
||||||
|
if current_sid == sid:
|
||||||
|
await redis.delete(user_conn_key)
|
||||||
|
|
||||||
|
# Delete connection record
|
||||||
|
await redis.delete(conn_key)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Game Association
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def join_game(self, sid: str, game_id: str) -> bool:
|
||||||
|
"""Associate a connection with a game.
|
||||||
|
|
||||||
|
Updates the connection's game_id and adds it to the game's connection set.
|
||||||
|
If connection was in a different game, leaves that game first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket.IO session ID.
|
||||||
|
game_id: Game to join.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False if connection not found.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
await manager.join_game("abc123", "game-456")
|
||||||
|
"""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
conn_key = self._conn_key(sid)
|
||||||
|
|
||||||
|
# Check connection exists
|
||||||
|
exists = await redis.exists(conn_key)
|
||||||
|
if not exists:
|
||||||
|
logger.warning(f"Cannot join game: connection not found {sid}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Leave current game if in one
|
||||||
|
current_game = await redis.hget(conn_key, "game_id")
|
||||||
|
if current_game and current_game != game_id:
|
||||||
|
old_game_conns_key = self._game_conns_key(current_game)
|
||||||
|
await redis.srem(old_game_conns_key, sid)
|
||||||
|
logger.debug(f"Connection {sid} left game {current_game}")
|
||||||
|
|
||||||
|
# Update connection's game_id
|
||||||
|
await redis.hset(conn_key, "game_id", game_id)
|
||||||
|
|
||||||
|
# Add to game's connection set
|
||||||
|
game_conns_key = self._game_conns_key(game_id)
|
||||||
|
await redis.sadd(game_conns_key, sid)
|
||||||
|
|
||||||
|
# Set TTL on game conns set (cleanup if game becomes inactive)
|
||||||
|
await redis.expire(game_conns_key, self.conn_ttl_seconds)
|
||||||
|
|
||||||
|
logger.debug(f"Connection {sid} joined game {game_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def leave_game(self, sid: str) -> str | None:
|
||||||
|
"""Remove a connection from its current game.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket.IO session ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The game_id that was left, or None if not in a game.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
game_id = await manager.leave_game("abc123")
|
||||||
|
"""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
conn_key = self._conn_key(sid)
|
||||||
|
|
||||||
|
# Get current game
|
||||||
|
game_id = await redis.hget(conn_key, "game_id")
|
||||||
|
if not game_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Remove from game's connection set
|
||||||
|
game_conns_key = self._game_conns_key(game_id)
|
||||||
|
await redis.srem(game_conns_key, sid)
|
||||||
|
|
||||||
|
# Clear game_id on connection
|
||||||
|
await redis.hset(conn_key, "game_id", "")
|
||||||
|
|
||||||
|
logger.debug(f"Connection {sid} left game {game_id}")
|
||||||
|
return game_id
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Heartbeat / Activity
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def update_heartbeat(self, sid: str) -> bool:
|
||||||
|
"""Update the last_seen timestamp for a connection.
|
||||||
|
|
||||||
|
Should be called on each heartbeat message from the client.
|
||||||
|
Also refreshes TTL on connection records.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket.IO session ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False if connection not found.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
await manager.update_heartbeat("abc123")
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
|
async with get_redis() as redis:
|
||||||
|
conn_key = self._conn_key(sid)
|
||||||
|
|
||||||
|
# Check exists and update atomically
|
||||||
|
exists = await redis.exists(conn_key)
|
||||||
|
if not exists:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Update last_seen
|
||||||
|
await redis.hset(conn_key, "last_seen", now)
|
||||||
|
|
||||||
|
# Refresh TTL
|
||||||
|
await redis.expire(conn_key, self.conn_ttl_seconds)
|
||||||
|
|
||||||
|
# Also refresh user mapping TTL
|
||||||
|
user_id = await redis.hget(conn_key, "user_id")
|
||||||
|
if user_id:
|
||||||
|
user_conn_key = self._user_conn_key(user_id)
|
||||||
|
await redis.expire(user_conn_key, self.conn_ttl_seconds)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Query Methods
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def get_connection(self, sid: str) -> ConnectionInfo | None:
|
||||||
|
"""Get connection info by session ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket.IO session ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConnectionInfo if found, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
info = await manager.get_connection("abc123")
|
||||||
|
if info:
|
||||||
|
print(f"User {info.user_id} connected at {info.connected_at}")
|
||||||
|
"""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
conn_key = self._conn_key(sid)
|
||||||
|
data = await redis.hgetall(conn_key)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ConnectionInfo(
|
||||||
|
sid=sid,
|
||||||
|
user_id=data.get("user_id", ""),
|
||||||
|
game_id=data.get("game_id") or None,
|
||||||
|
connected_at=datetime.fromisoformat(data["connected_at"]),
|
||||||
|
last_seen=datetime.fromisoformat(data["last_seen"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_connection(self, user_id: str | UUID) -> ConnectionInfo | None:
|
||||||
|
"""Get connection info for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User's UUID or string ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConnectionInfo if user is connected, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
info = await manager.get_user_connection(opponent_id)
|
||||||
|
if info:
|
||||||
|
# User is online
|
||||||
|
"""
|
||||||
|
user_id_str = str(user_id)
|
||||||
|
|
||||||
|
async with get_redis() as redis:
|
||||||
|
user_conn_key = self._user_conn_key(user_id_str)
|
||||||
|
sid = await redis.get(user_conn_key)
|
||||||
|
|
||||||
|
if not sid:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self.get_connection(sid)
|
||||||
|
|
||||||
|
async def is_user_online(self, user_id: str | UUID) -> bool:
|
||||||
|
"""Check if a user has an active connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User's UUID or string ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if user is connected, False otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
if await manager.is_user_online(opponent_id):
|
||||||
|
# Send real-time notification
|
||||||
|
"""
|
||||||
|
info = await self.get_user_connection(user_id)
|
||||||
|
if info is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if connection is stale
|
||||||
|
if info.is_stale():
|
||||||
|
logger.debug(f"User {user_id} connection is stale")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_game_connections(self, game_id: str) -> list[ConnectionInfo]:
|
||||||
|
"""Get all connections for a game.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: Game ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ConnectionInfo for all connected participants.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
connections = await manager.get_game_connections("game-456")
|
||||||
|
for conn in connections:
|
||||||
|
# Send state update to each participant
|
||||||
|
"""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
game_conns_key = self._game_conns_key(game_id)
|
||||||
|
sids = await redis.smembers(game_conns_key)
|
||||||
|
|
||||||
|
connections = []
|
||||||
|
for sid in sids:
|
||||||
|
info = await self.get_connection(sid)
|
||||||
|
if info is not None:
|
||||||
|
connections.append(info)
|
||||||
|
|
||||||
|
return connections
|
||||||
|
|
||||||
|
async def get_game_user_sids(self, game_id: str) -> dict[str, str]:
|
||||||
|
"""Get mapping of user_id to sid for a game.
|
||||||
|
|
||||||
|
Useful for targeted message delivery to specific players.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: Game ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping user_id to their sid.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user_sids = await manager.get_game_user_sids("game-456")
|
||||||
|
player1_sid = user_sids.get(player1_id)
|
||||||
|
"""
|
||||||
|
connections = await self.get_game_connections(game_id)
|
||||||
|
return {conn.user_id: conn.sid for conn in connections}
|
||||||
|
|
||||||
|
async def get_opponent_sid(
|
||||||
|
self,
|
||||||
|
game_id: str,
|
||||||
|
current_user_id: str | UUID,
|
||||||
|
) -> str | None:
|
||||||
|
"""Get the sid of the opponent in a 2-player game.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: Game ID.
|
||||||
|
current_user_id: The current user's ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Opponent's sid if found and connected, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
opponent_sid = await manager.get_opponent_sid("game-456", user_id)
|
||||||
|
if opponent_sid:
|
||||||
|
# Send message to opponent
|
||||||
|
"""
|
||||||
|
current_user_str = str(current_user_id)
|
||||||
|
connections = await self.get_game_connections(game_id)
|
||||||
|
|
||||||
|
for conn in connections:
|
||||||
|
if conn.user_id != current_user_str:
|
||||||
|
return conn.sid
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Maintenance
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def cleanup_stale_connections(
|
||||||
|
self,
|
||||||
|
threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3,
|
||||||
|
) -> int:
|
||||||
|
"""Clean up stale connections that haven't sent heartbeats.
|
||||||
|
|
||||||
|
Should be called periodically by a background task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
threshold_seconds: Seconds since last_seen to consider stale.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of connections cleaned up.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# In background task
|
||||||
|
count = await manager.cleanup_stale_connections()
|
||||||
|
logger.info(f"Cleaned up {count} stale connections")
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
async with get_redis() as redis:
|
||||||
|
# Scan for all connection keys
|
||||||
|
async for key in redis.scan_iter(match=f"{CONN_PREFIX}*"):
|
||||||
|
sid = key[len(CONN_PREFIX) :]
|
||||||
|
info = await self.get_connection(sid)
|
||||||
|
|
||||||
|
if info and info.is_stale(threshold_seconds):
|
||||||
|
await self._cleanup_connection(sid)
|
||||||
|
count += 1
|
||||||
|
logger.debug(f"Cleaned up stale connection: {sid}")
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logger.info(f"Cleaned up {count} stale connections")
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def get_connection_count(self) -> int:
|
||||||
|
"""Get the total number of active connections.
|
||||||
|
|
||||||
|
Useful for monitoring and admin dashboards.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of active connections.
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
async with get_redis() as redis:
|
||||||
|
async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"):
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def get_game_connection_count(self, game_id: str) -> int:
|
||||||
|
"""Get the number of connections for a specific game.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: Game ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of connected participants.
|
||||||
|
"""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
game_conns_key = self._game_conns_key(game_id)
|
||||||
|
return await redis.scard(game_conns_key)
|
||||||
|
|
||||||
|
|
||||||
|
# Global singleton instance
|
||||||
|
connection_manager = ConnectionManager()
|
||||||
558
backend/app/services/game_service.py
Normal file
558
backend/app/services/game_service.py
Normal file
@ -0,0 +1,558 @@
|
|||||||
|
"""Game service for orchestrating game lifecycle in Mantimon TCG.
|
||||||
|
|
||||||
|
This service is the bridge between WebSocket communication and the core
|
||||||
|
GameEngine. It handles:
|
||||||
|
- Game creation with deck loading and validation
|
||||||
|
- Action execution with persistence
|
||||||
|
- Game state retrieval with visibility filtering
|
||||||
|
- Game lifecycle (join, resign, end)
|
||||||
|
|
||||||
|
IMPORTANT: This service is stateless. All game-specific configuration
|
||||||
|
(RulesConfig) is stored in the GameState itself, not in this service.
|
||||||
|
Rules come from the frontend request at game creation time.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
WebSocket Layer -> GameService -> GameEngine + GameStateManager
|
||||||
|
-> DeckService + CardService
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from app.services.game_service import GameService, game_service
|
||||||
|
|
||||||
|
# Create a new game (rules from frontend)
|
||||||
|
result = await game_service.create_game(
|
||||||
|
player1_id=user1.id,
|
||||||
|
player2_id=user2.id,
|
||||||
|
deck1_id=deck1.id,
|
||||||
|
deck2_id=deck2.id,
|
||||||
|
rules_config=rules_from_request, # Frontend provides this
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute an action (uses rules stored in game state)
|
||||||
|
result = await game_service.execute_action(
|
||||||
|
game_id=game.id,
|
||||||
|
player_id=user1.id,
|
||||||
|
action=AttackAction(attack_index=0),
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from app.core.engine import ActionResult, GameEngine
|
||||||
|
from app.core.enums import GameEndReason
|
||||||
|
from app.core.models.actions import Action, ResignAction
|
||||||
|
from app.core.models.game_state import GameState
|
||||||
|
from app.core.rng import create_rng
|
||||||
|
from app.core.visibility import VisibleGameState, get_visible_state
|
||||||
|
from app.services.card_service import CardService, get_card_service
|
||||||
|
from app.services.game_state_manager import GameStateManager, game_state_manager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Exceptions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class GameServiceError(Exception):
|
||||||
|
"""Base exception for GameService errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GameNotFoundError(GameServiceError):
|
||||||
|
"""Raised when a game cannot be found."""
|
||||||
|
|
||||||
|
def __init__(self, game_id: str) -> None:
|
||||||
|
self.game_id = game_id
|
||||||
|
super().__init__(f"Game not found: {game_id}")
|
||||||
|
|
||||||
|
|
||||||
|
class NotPlayerTurnError(GameServiceError):
|
||||||
|
"""Raised when a player tries to act out of turn."""
|
||||||
|
|
||||||
|
def __init__(self, game_id: str, player_id: str, current_player_id: str) -> None:
|
||||||
|
self.game_id = game_id
|
||||||
|
self.player_id = player_id
|
||||||
|
self.current_player_id = current_player_id
|
||||||
|
super().__init__(
|
||||||
|
f"Not player's turn: {player_id} tried to act, but it's {current_player_id}'s turn"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidActionError(GameServiceError):
|
||||||
|
"""Raised when an action is invalid."""
|
||||||
|
|
||||||
|
def __init__(self, game_id: str, player_id: str, reason: str) -> None:
|
||||||
|
self.game_id = game_id
|
||||||
|
self.player_id = player_id
|
||||||
|
self.reason = reason
|
||||||
|
super().__init__(f"Invalid action in game {game_id}: {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class PlayerNotInGameError(GameServiceError):
|
||||||
|
"""Raised when a player is not a participant in the game."""
|
||||||
|
|
||||||
|
def __init__(self, game_id: str, player_id: str) -> None:
|
||||||
|
self.game_id = game_id
|
||||||
|
self.player_id = player_id
|
||||||
|
super().__init__(f"Player {player_id} is not in game {game_id}")
|
||||||
|
|
||||||
|
|
||||||
|
class GameAlreadyEndedError(GameServiceError):
|
||||||
|
"""Raised when trying to act on a game that has already ended."""
|
||||||
|
|
||||||
|
def __init__(self, game_id: str) -> None:
|
||||||
|
self.game_id = game_id
|
||||||
|
super().__init__(f"Game {game_id} has already ended")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Result Types
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GameActionResult:
|
||||||
|
"""Result of executing a game action.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
success: Whether the action succeeded.
|
||||||
|
game_id: The game ID.
|
||||||
|
action_type: The type of action executed.
|
||||||
|
message: Description of what happened.
|
||||||
|
state_changes: Dict of state changes for client updates.
|
||||||
|
game_over: Whether the game ended as a result.
|
||||||
|
winner_id: Winner's player ID if game ended with a winner.
|
||||||
|
end_reason: Reason the game ended, if applicable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
game_id: str
|
||||||
|
action_type: str
|
||||||
|
message: str = ""
|
||||||
|
state_changes: dict[str, Any] = field(default_factory=dict)
|
||||||
|
game_over: bool = False
|
||||||
|
winner_id: str | None = None
|
||||||
|
end_reason: GameEndReason | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GameJoinResult:
|
||||||
|
"""Result of joining a game.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
success: Whether the join succeeded.
|
||||||
|
game_id: The game ID.
|
||||||
|
player_id: The joining player's ID.
|
||||||
|
visible_state: The game state visible to the player.
|
||||||
|
is_your_turn: Whether it's this player's turn.
|
||||||
|
message: Additional information or error message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
game_id: str
|
||||||
|
player_id: str
|
||||||
|
visible_state: VisibleGameState | None = None
|
||||||
|
is_your_turn: bool = False
|
||||||
|
message: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# GameService
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class GameService:
|
||||||
|
"""Service for orchestrating game lifecycle operations.
|
||||||
|
|
||||||
|
This service coordinates between the WebSocket layer and the core
|
||||||
|
GameEngine, handling persistence and state management.
|
||||||
|
|
||||||
|
IMPORTANT: This service is STATELESS regarding game rules.
|
||||||
|
- Rules are stored in each GameState (set at creation time)
|
||||||
|
- The GameEngine is instantiated per-operation with the game's rules
|
||||||
|
- No RulesConfig is stored in this service
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_state_manager: GameStateManager for persistence.
|
||||||
|
_card_service: CardService for card definitions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
state_manager: GameStateManager | None = None,
|
||||||
|
card_service: CardService | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the GameService.
|
||||||
|
|
||||||
|
Note: No GameEngine or RulesConfig here - those are per-game,
|
||||||
|
not per-service. The engine is created as needed using the
|
||||||
|
rules stored in each game's state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_manager: GameStateManager instance. Uses global if not provided.
|
||||||
|
card_service: CardService instance. Uses global if not provided.
|
||||||
|
"""
|
||||||
|
self._state_manager = state_manager or game_state_manager
|
||||||
|
self._card_service = card_service or get_card_service()
|
||||||
|
|
||||||
|
def _create_engine_for_game(self, game: GameState) -> GameEngine:
|
||||||
|
"""Create a GameEngine configured for a specific game's rules.
|
||||||
|
|
||||||
|
The engine is created on-demand using the rules stored in the
|
||||||
|
game state. This ensures each game uses its own configuration.
|
||||||
|
|
||||||
|
For deterministic replay support, we derive a unique seed per action
|
||||||
|
by combining the game's base seed with the action count. This ensures:
|
||||||
|
- Same game + same action sequence = identical RNG results
|
||||||
|
- Each action gets a unique but reproducible random sequence
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game: The game state containing the rules to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A GameEngine configured with the game's rules and RNG.
|
||||||
|
"""
|
||||||
|
if game.rng_seed is not None:
|
||||||
|
# Derive unique seed per action for deterministic replay
|
||||||
|
# Action count ensures each action gets different but reproducible RNG
|
||||||
|
action_count = len(game.action_log)
|
||||||
|
action_seed = game.rng_seed + action_count
|
||||||
|
rng = create_rng(seed=action_seed)
|
||||||
|
else:
|
||||||
|
# No seed - use cryptographically secure RNG
|
||||||
|
rng = create_rng()
|
||||||
|
|
||||||
|
return GameEngine(rules=game.rules, rng=rng)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Game State Access
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def get_game_state(self, game_id: str) -> GameState:
|
||||||
|
"""Get the full game state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The game ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The GameState.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GameNotFoundError: If game doesn't exist.
|
||||||
|
"""
|
||||||
|
state = await self._state_manager.load_state(game_id)
|
||||||
|
if state is None:
|
||||||
|
raise GameNotFoundError(game_id)
|
||||||
|
return state
|
||||||
|
|
||||||
|
async def get_player_view(
|
||||||
|
self,
|
||||||
|
game_id: str,
|
||||||
|
player_id: str,
|
||||||
|
) -> VisibleGameState:
|
||||||
|
"""Get the game state filtered for a specific player's view.
|
||||||
|
|
||||||
|
This applies visibility rules to hide opponent's hidden information
|
||||||
|
(hand, deck, prizes).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The game ID.
|
||||||
|
player_id: The player to get the view for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VisibleGameState with appropriate filtering.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GameNotFoundError: If game doesn't exist.
|
||||||
|
PlayerNotInGameError: If player is not in the game.
|
||||||
|
"""
|
||||||
|
state = await self.get_game_state(game_id)
|
||||||
|
|
||||||
|
if player_id not in state.players:
|
||||||
|
raise PlayerNotInGameError(game_id, player_id)
|
||||||
|
|
||||||
|
return get_visible_state(state, player_id)
|
||||||
|
|
||||||
|
async def is_player_turn(self, game_id: str, player_id: str) -> bool:
|
||||||
|
"""Check if it's the specified player's turn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The game ID.
|
||||||
|
player_id: The player to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if it's the player's turn.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GameNotFoundError: If game doesn't exist.
|
||||||
|
"""
|
||||||
|
state = await self.get_game_state(game_id)
|
||||||
|
return state.current_player_id == player_id
|
||||||
|
|
||||||
|
async def game_exists(self, game_id: str) -> bool:
|
||||||
|
"""Check if a game exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The game ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the game exists in cache or database.
|
||||||
|
"""
|
||||||
|
return await self._state_manager.cache_exists(game_id)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Game Lifecycle
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def join_game(
|
||||||
|
self,
|
||||||
|
game_id: str,
|
||||||
|
player_id: str,
|
||||||
|
last_event_id: str | None = None,
|
||||||
|
) -> GameJoinResult:
|
||||||
|
"""Join or rejoin a game session.
|
||||||
|
|
||||||
|
Loads the game state and returns the player's visible view.
|
||||||
|
Used when a player connects or reconnects to a game.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The game to join.
|
||||||
|
player_id: The joining player's ID.
|
||||||
|
last_event_id: Last event ID for reconnection replay (future use).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GameJoinResult with the visible state.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
state = await self.get_game_state(game_id)
|
||||||
|
except GameNotFoundError:
|
||||||
|
return GameJoinResult(
|
||||||
|
success=False,
|
||||||
|
game_id=game_id,
|
||||||
|
player_id=player_id,
|
||||||
|
message="Game not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if player_id not in state.players:
|
||||||
|
return GameJoinResult(
|
||||||
|
success=False,
|
||||||
|
game_id=game_id,
|
||||||
|
player_id=player_id,
|
||||||
|
message="You are not a participant in this game",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if game already ended
|
||||||
|
if state.winner_id is not None or state.end_reason is not None:
|
||||||
|
visible = get_visible_state(state, player_id)
|
||||||
|
return GameJoinResult(
|
||||||
|
success=True,
|
||||||
|
game_id=game_id,
|
||||||
|
player_id=player_id,
|
||||||
|
visible_state=visible,
|
||||||
|
is_your_turn=False,
|
||||||
|
message="Game has ended",
|
||||||
|
)
|
||||||
|
|
||||||
|
visible = get_visible_state(state, player_id)
|
||||||
|
|
||||||
|
logger.info(f"Player {player_id} joined game {game_id}")
|
||||||
|
|
||||||
|
return GameJoinResult(
|
||||||
|
success=True,
|
||||||
|
game_id=game_id,
|
||||||
|
player_id=player_id,
|
||||||
|
visible_state=visible,
|
||||||
|
is_your_turn=state.current_player_id == player_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute_action(
|
||||||
|
self,
|
||||||
|
game_id: str,
|
||||||
|
player_id: str,
|
||||||
|
action: Action,
|
||||||
|
) -> GameActionResult:
|
||||||
|
"""Execute a player action in the game.
|
||||||
|
|
||||||
|
Validates the action, executes it through GameEngine, and
|
||||||
|
persists the updated state. The GameEngine is created using
|
||||||
|
the rules stored in the game state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The game ID.
|
||||||
|
player_id: The acting player's ID.
|
||||||
|
action: The action to execute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GameActionResult with success status and state changes.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GameNotFoundError: If game doesn't exist.
|
||||||
|
PlayerNotInGameError: If player is not in the game.
|
||||||
|
GameAlreadyEndedError: If game has already ended.
|
||||||
|
NotPlayerTurnError: If it's not the player's turn.
|
||||||
|
InvalidActionError: If the action is invalid.
|
||||||
|
"""
|
||||||
|
# Load game state
|
||||||
|
state = await self.get_game_state(game_id)
|
||||||
|
|
||||||
|
# Validate player is in game
|
||||||
|
if player_id not in state.players:
|
||||||
|
raise PlayerNotInGameError(game_id, player_id)
|
||||||
|
|
||||||
|
# Check game hasn't ended
|
||||||
|
if state.winner_id is not None or state.end_reason is not None:
|
||||||
|
raise GameAlreadyEndedError(game_id)
|
||||||
|
|
||||||
|
# Check it's player's turn (unless resignation, which can happen anytime)
|
||||||
|
if not isinstance(action, ResignAction) and state.current_player_id != player_id:
|
||||||
|
raise NotPlayerTurnError(game_id, player_id, state.current_player_id)
|
||||||
|
|
||||||
|
# Create engine with this game's rules
|
||||||
|
engine = self._create_engine_for_game(state)
|
||||||
|
|
||||||
|
# Execute the action
|
||||||
|
result: ActionResult = await engine.execute_action(state, player_id, action)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
raise InvalidActionError(game_id, player_id, result.message)
|
||||||
|
|
||||||
|
# Save state to cache (fast path)
|
||||||
|
await self._state_manager.save_to_cache(state)
|
||||||
|
|
||||||
|
# Check if turn ended - persist to DB at turn boundaries
|
||||||
|
# TODO: Implement turn boundary detection for DB persistence
|
||||||
|
|
||||||
|
# Build response
|
||||||
|
action_result = GameActionResult(
|
||||||
|
success=True,
|
||||||
|
game_id=game_id,
|
||||||
|
action_type=action.type,
|
||||||
|
message=result.message,
|
||||||
|
state_changes={
|
||||||
|
"changes": result.state_changes,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for game over
|
||||||
|
if result.win_result is not None:
|
||||||
|
action_result.game_over = True
|
||||||
|
action_result.winner_id = result.win_result.winner_id
|
||||||
|
action_result.end_reason = result.win_result.end_reason
|
||||||
|
|
||||||
|
# Persist final state to DB
|
||||||
|
await self._state_manager.persist_to_db(state)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Game {game_id} ended: winner={result.win_result.winner_id}, "
|
||||||
|
f"reason={result.win_result.end_reason}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Action executed: game={game_id}, player={player_id}, type={action.type}")
|
||||||
|
|
||||||
|
return action_result
|
||||||
|
|
||||||
|
async def resign_game(
|
||||||
|
self,
|
||||||
|
game_id: str,
|
||||||
|
player_id: str,
|
||||||
|
) -> GameActionResult:
|
||||||
|
"""Resign from a game.
|
||||||
|
|
||||||
|
Convenience method that executes a ResignAction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The game ID.
|
||||||
|
player_id: The resigning player's ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GameActionResult indicating game over.
|
||||||
|
"""
|
||||||
|
return await self.execute_action(
|
||||||
|
game_id=game_id,
|
||||||
|
player_id=player_id,
|
||||||
|
action=ResignAction(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def end_game(
|
||||||
|
self,
|
||||||
|
game_id: str,
|
||||||
|
winner_id: str | None,
|
||||||
|
end_reason: GameEndReason,
|
||||||
|
) -> None:
|
||||||
|
"""Forcibly end a game (e.g., due to timeout or disconnection).
|
||||||
|
|
||||||
|
This should be called by the timeout system or when a player
|
||||||
|
disconnects without reconnecting within the grace period.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The game ID.
|
||||||
|
winner_id: The winner's player ID, or None for a draw.
|
||||||
|
end_reason: Why the game ended.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GameNotFoundError: If game doesn't exist.
|
||||||
|
"""
|
||||||
|
state = await self.get_game_state(game_id)
|
||||||
|
|
||||||
|
# Set winner and end reason
|
||||||
|
state.winner_id = winner_id
|
||||||
|
state.end_reason = end_reason
|
||||||
|
|
||||||
|
# Persist to both cache and DB
|
||||||
|
await self._state_manager.save_to_cache(state)
|
||||||
|
await self._state_manager.persist_to_db(state)
|
||||||
|
|
||||||
|
logger.info(f"Game {game_id} forcibly ended: winner={winner_id}, reason={end_reason}")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Game Creation (Skeleton - Full implementation in GS-002)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def create_game(
|
||||||
|
self,
|
||||||
|
player1_id: str | UUID,
|
||||||
|
player2_id: str | UUID,
|
||||||
|
deck1_id: str | UUID | None = None,
|
||||||
|
deck2_id: str | UUID | None = None,
|
||||||
|
# Rules come from the frontend request - this is required, not optional
|
||||||
|
# Defaulting to None here only for the skeleton; GS-002 will make it required
|
||||||
|
) -> None:
|
||||||
|
"""Create a new game between two players.
|
||||||
|
|
||||||
|
This is a skeleton that will be fully implemented in GS-002.
|
||||||
|
|
||||||
|
IMPORTANT: rules_config will be a required parameter - it comes
|
||||||
|
from the frontend request, not from server-side defaults.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
player1_id: First player's ID.
|
||||||
|
player2_id: Second player's ID.
|
||||||
|
deck1_id: First player's deck ID.
|
||||||
|
deck2_id: Second player's deck ID.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: Until GS-002 is complete.
|
||||||
|
"""
|
||||||
|
# TODO (GS-002): Full implementation with:
|
||||||
|
# - rules_config: RulesConfig parameter (required, from frontend)
|
||||||
|
# - Load decks via DeckService
|
||||||
|
# - Load card registry from CardService
|
||||||
|
# - Convert to CardInstances with unique IDs
|
||||||
|
# - Create GameState with the provided rules_config
|
||||||
|
# - Persist to Redis and Postgres
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Game creation not yet implemented - see GS-002. "
|
||||||
|
"Rules will come from frontend request, not server defaults."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Global singleton instance
|
||||||
|
# Note: This is safe because GameService is stateless regarding game rules.
|
||||||
|
# Each game's rules are stored in its GameState, not in this service.
|
||||||
|
game_service = GameService()
|
||||||
46
backend/app/socketio/__init__.py
Normal file
46
backend/app/socketio/__init__.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""Socket.IO module for real-time game communication.
|
||||||
|
|
||||||
|
This module provides WebSocket-based real-time communication for:
|
||||||
|
- Active game sessions (actions, state updates, turn notifications)
|
||||||
|
- Connection management with session recovery
|
||||||
|
- Turn timeout handling
|
||||||
|
- JWT-based authentication
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
- Uses python-socketio with ASGI mode for FastAPI integration
|
||||||
|
- /game namespace handles all active game communication
|
||||||
|
- /lobby namespace (Phase 6) will handle matchmaking
|
||||||
|
- JWT authentication on connect (WS-004)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
The Socket.IO server is mounted alongside FastAPI in app/main.py.
|
||||||
|
Clients connect to ws://host/socket.io/ and join the /game namespace.
|
||||||
|
|
||||||
|
Client connection example:
|
||||||
|
socket = io("ws://host", {
|
||||||
|
auth: { token: "JWT_ACCESS_TOKEN" }
|
||||||
|
});
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.socketio.auth import (
|
||||||
|
AuthResult,
|
||||||
|
authenticate_connection,
|
||||||
|
cleanup_authenticated_session,
|
||||||
|
get_session_user_id,
|
||||||
|
require_auth,
|
||||||
|
setup_authenticated_session,
|
||||||
|
)
|
||||||
|
from app.socketio.server import create_socketio_app, sio
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Server
|
||||||
|
"create_socketio_app",
|
||||||
|
"sio",
|
||||||
|
# Auth
|
||||||
|
"AuthResult",
|
||||||
|
"authenticate_connection",
|
||||||
|
"cleanup_authenticated_session",
|
||||||
|
"get_session_user_id",
|
||||||
|
"require_auth",
|
||||||
|
"setup_authenticated_session",
|
||||||
|
]
|
||||||
283
backend/app/socketio/auth.py
Normal file
283
backend/app/socketio/auth.py
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
"""Socket.IO authentication middleware for WebSocket connections.
|
||||||
|
|
||||||
|
This module provides JWT-based authentication for Socket.IO connections.
|
||||||
|
It validates access tokens and attaches user information to the socket session.
|
||||||
|
|
||||||
|
Authentication Flow:
|
||||||
|
1. Client connects with `auth: { token: "JWT_ACCESS_TOKEN" }`
|
||||||
|
2. Server extracts and validates the JWT
|
||||||
|
3. If valid, user_id is stored in socket session
|
||||||
|
4. If invalid, connection is rejected with appropriate error
|
||||||
|
|
||||||
|
Session Data:
|
||||||
|
After successful authentication, the socket session contains:
|
||||||
|
- user_id: str (UUID as string)
|
||||||
|
- authenticated_at: str (ISO timestamp)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# In connect handler:
|
||||||
|
auth_result = await authenticate_connection(sid, auth)
|
||||||
|
if not auth_result.success:
|
||||||
|
return False # Reject connection
|
||||||
|
|
||||||
|
# Later, get user_id:
|
||||||
|
session = await sio.get_session(sid, namespace="/game")
|
||||||
|
user_id = session.get("user_id")
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from app.services.connection_manager import connection_manager
|
||||||
|
from app.services.jwt_service import verify_access_token
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AuthResult:
|
||||||
|
"""Result of authentication attempt.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
success: Whether authentication succeeded.
|
||||||
|
user_id: User's UUID if successful, None otherwise.
|
||||||
|
error_code: Error code for client if failed.
|
||||||
|
error_message: Human-readable error message if failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
user_id: UUID | None = None
|
||||||
|
error_code: str | None = None
|
||||||
|
error_message: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_token(auth: dict[str, object] | None) -> str | None:
|
||||||
|
"""Extract JWT token from Socket.IO auth data.
|
||||||
|
|
||||||
|
Clients should send the token in the auth dict:
|
||||||
|
socket.connect({ auth: { token: "JWT_TOKEN" } })
|
||||||
|
|
||||||
|
Also supports:
|
||||||
|
- auth.authorization: "Bearer TOKEN"
|
||||||
|
- auth.access_token: "TOKEN"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth: Authentication data from Socket.IO connect.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JWT token string if found, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
token = extract_token({"token": "eyJ..."})
|
||||||
|
token = extract_token({"authorization": "Bearer eyJ..."})
|
||||||
|
"""
|
||||||
|
if auth is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Primary: auth.token
|
||||||
|
token = auth.get("token")
|
||||||
|
if token and isinstance(token, str):
|
||||||
|
return token
|
||||||
|
|
||||||
|
# Alternative: auth.authorization (Bearer token)
|
||||||
|
authorization = auth.get("authorization")
|
||||||
|
if authorization and isinstance(authorization, str):
|
||||||
|
if authorization.lower().startswith("bearer "):
|
||||||
|
return authorization[7:]
|
||||||
|
return authorization
|
||||||
|
|
||||||
|
# Alternative: auth.access_token
|
||||||
|
access_token = auth.get("access_token")
|
||||||
|
if access_token and isinstance(access_token, str):
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def authenticate_connection(
|
||||||
|
sid: str,
|
||||||
|
auth: dict[str, object] | None,
|
||||||
|
) -> AuthResult:
|
||||||
|
"""Authenticate a Socket.IO connection using JWT.
|
||||||
|
|
||||||
|
Extracts the JWT from auth data, validates it, and returns the result.
|
||||||
|
Does NOT modify socket session - caller should handle that.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket session ID (for logging).
|
||||||
|
auth: Authentication data from connect event.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthResult with success status and user_id or error details.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
result = await authenticate_connection(sid, auth)
|
||||||
|
if result.success:
|
||||||
|
await sio.save_session(sid, {"user_id": str(result.user_id)})
|
||||||
|
else:
|
||||||
|
logger.warning(f"Auth failed: {result.error_message}")
|
||||||
|
return False # Reject connection
|
||||||
|
"""
|
||||||
|
# Extract token from auth data
|
||||||
|
token = extract_token(auth)
|
||||||
|
|
||||||
|
if token is None:
|
||||||
|
logger.debug(f"Connection {sid}: No token provided")
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_code="missing_token",
|
||||||
|
error_message="Authentication token required",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate the token
|
||||||
|
user_id = verify_access_token(token)
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
logger.debug(f"Connection {sid}: Invalid or expired token")
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_code="invalid_token",
|
||||||
|
error_message="Invalid or expired token",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Connection {sid}: Authenticated as user {user_id}")
|
||||||
|
return AuthResult(
|
||||||
|
success=True,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_authenticated_session(
|
||||||
|
sio: object,
|
||||||
|
sid: str,
|
||||||
|
user_id: UUID,
|
||||||
|
namespace: str = "/game",
|
||||||
|
) -> None:
|
||||||
|
"""Set up socket session with authenticated user data.
|
||||||
|
|
||||||
|
Saves user_id and authentication timestamp to the socket session,
|
||||||
|
and registers the connection with ConnectionManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sio: Socket.IO AsyncServer instance.
|
||||||
|
sid: Socket session ID.
|
||||||
|
user_id: Authenticated user's UUID.
|
||||||
|
namespace: Socket.IO namespace.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
if auth_result.success:
|
||||||
|
await setup_authenticated_session(sio, sid, auth_result.user_id)
|
||||||
|
"""
|
||||||
|
# Import here to avoid circular dependency
|
||||||
|
from app.socketio.server import sio as server_sio
|
||||||
|
|
||||||
|
# Use provided sio or fall back to server sio
|
||||||
|
socket_server = sio if sio is not None else server_sio
|
||||||
|
|
||||||
|
# Save to socket session
|
||||||
|
session_data = {
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"authenticated_at": datetime.now(UTC).isoformat(),
|
||||||
|
}
|
||||||
|
await socket_server.save_session(sid, session_data, namespace=namespace)
|
||||||
|
|
||||||
|
# Register with ConnectionManager
|
||||||
|
await connection_manager.register_connection(sid, user_id)
|
||||||
|
|
||||||
|
logger.info(f"Session established: sid={sid}, user_id={user_id}")
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_authenticated_session(
|
||||||
|
sid: str,
|
||||||
|
namespace: str = "/game",
|
||||||
|
) -> str | None:
|
||||||
|
"""Clean up session data on disconnect.
|
||||||
|
|
||||||
|
Unregisters the connection from ConnectionManager and returns
|
||||||
|
the user_id for any additional cleanup needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket session ID.
|
||||||
|
namespace: Socket.IO namespace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
user_id if session was authenticated, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user_id = await cleanup_authenticated_session(sid)
|
||||||
|
if user_id:
|
||||||
|
# Notify opponent, etc.
|
||||||
|
"""
|
||||||
|
# Unregister from ConnectionManager
|
||||||
|
conn_info = await connection_manager.unregister_connection(sid)
|
||||||
|
|
||||||
|
if conn_info:
|
||||||
|
logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}")
|
||||||
|
return conn_info.user_id
|
||||||
|
|
||||||
|
logger.debug(f"No session to clean up for {sid}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session_user_id(
|
||||||
|
sio: object,
|
||||||
|
sid: str,
|
||||||
|
namespace: str = "/game",
|
||||||
|
) -> str | None:
|
||||||
|
"""Get the authenticated user_id from a socket session.
|
||||||
|
|
||||||
|
Convenience function to extract user_id from session data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sio: Socket.IO AsyncServer instance.
|
||||||
|
sid: Socket session ID.
|
||||||
|
namespace: Socket.IO namespace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
user_id string if authenticated, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user_id = await get_session_user_id(sio, sid)
|
||||||
|
if not user_id:
|
||||||
|
await sio.emit("error", {"message": "Not authenticated"}, to=sid)
|
||||||
|
return
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session = await sio.get_session(sid, namespace=namespace)
|
||||||
|
return session.get("user_id") if session else None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def require_auth(
|
||||||
|
sio: object,
|
||||||
|
sid: str,
|
||||||
|
namespace: str = "/game",
|
||||||
|
) -> str | None:
|
||||||
|
"""Require authentication for an event handler.
|
||||||
|
|
||||||
|
Returns the user_id if authenticated, None if not.
|
||||||
|
Logs a warning if authentication is missing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sio: Socket.IO AsyncServer instance.
|
||||||
|
sid: Socket session ID.
|
||||||
|
namespace: Socket.IO namespace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
user_id string if authenticated, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@sio.on("game:action", namespace="/game")
|
||||||
|
async def on_action(sid, data):
|
||||||
|
user_id = await require_auth(sio, sid)
|
||||||
|
if not user_id:
|
||||||
|
return {"error": "Not authenticated"}
|
||||||
|
# ... handle action
|
||||||
|
"""
|
||||||
|
user_id = await get_session_user_id(sio, sid, namespace)
|
||||||
|
if user_id is None:
|
||||||
|
logger.warning(f"Unauthenticated event from {sid}")
|
||||||
|
return user_id
|
||||||
240
backend/app/socketio/server.py
Normal file
240
backend/app/socketio/server.py
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
"""Socket.IO server configuration and ASGI app creation.
|
||||||
|
|
||||||
|
This module sets up the python-socketio AsyncServer and creates the combined
|
||||||
|
ASGI application that mounts Socket.IO alongside FastAPI.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
- AsyncServer handles WebSocket connections with async_mode='asgi'
|
||||||
|
- Socket.IO app is mounted at /socket.io path
|
||||||
|
- CORS settings match FastAPI configuration
|
||||||
|
- JWT authentication on connect via auth.py
|
||||||
|
- Namespaces are registered for different communication domains
|
||||||
|
|
||||||
|
Namespaces:
|
||||||
|
/game - Active game communication (actions, state updates)
|
||||||
|
/lobby - Pre-game lobby (matchmaking, invites) - Phase 6
|
||||||
|
|
||||||
|
Authentication:
|
||||||
|
Clients must provide a JWT access token in the auth parameter:
|
||||||
|
socket.connect({ auth: { token: "JWT_ACCESS_TOKEN" } })
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from app.socketio import create_socketio_app
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
fastapi_app = FastAPI()
|
||||||
|
combined_app = create_socketio_app(fastapi_app)
|
||||||
|
# Run with: uvicorn app.main:app
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import socketio
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.socketio.auth import (
|
||||||
|
authenticate_connection,
|
||||||
|
cleanup_authenticated_session,
|
||||||
|
require_auth,
|
||||||
|
setup_authenticated_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create the AsyncServer instance
|
||||||
|
# - async_mode='asgi' for ASGI compatibility with uvicorn
|
||||||
|
# - cors_allowed_origins matches FastAPI CORS settings
|
||||||
|
# - logger enables Socket.IO internal logging in debug mode
|
||||||
|
sio = socketio.AsyncServer(
|
||||||
|
async_mode="asgi",
|
||||||
|
cors_allowed_origins=settings.cors_origins if settings.cors_origins else [],
|
||||||
|
logger=settings.debug,
|
||||||
|
engineio_logger=settings.debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# /game Namespace - Active Game Communication
|
||||||
|
# =============================================================================
|
||||||
|
# These are skeleton handlers that will be fully implemented in WS-005.
|
||||||
|
# For now, they provide basic connection lifecycle handling.
|
||||||
|
|
||||||
|
|
||||||
|
@sio.event(namespace="/game")
|
||||||
|
async def connect(
|
||||||
|
sid: str, environ: dict[str, object], auth: dict[str, object] | None = None
|
||||||
|
) -> bool | None:
|
||||||
|
"""Handle client connection to /game namespace.
|
||||||
|
|
||||||
|
Authenticates the connection using JWT from auth data.
|
||||||
|
Rejects connections without valid authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket session ID assigned by Socket.IO.
|
||||||
|
environ: WSGI/ASGI environ dict with request info.
|
||||||
|
auth: Authentication data sent by client (JWT token).
|
||||||
|
Expected format: { token: "JWT_ACCESS_TOKEN" }
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True to accept connection, False to reject.
|
||||||
|
None is treated as True (accept).
|
||||||
|
"""
|
||||||
|
# Authenticate the connection
|
||||||
|
auth_result = await authenticate_connection(sid, auth)
|
||||||
|
|
||||||
|
if not auth_result.success:
|
||||||
|
logger.warning(
|
||||||
|
f"Connection rejected for {sid}: {auth_result.error_code} - {auth_result.error_message}"
|
||||||
|
)
|
||||||
|
# Emit error before rejecting (client may receive this)
|
||||||
|
await sio.emit(
|
||||||
|
"auth_error",
|
||||||
|
{
|
||||||
|
"code": auth_result.error_code,
|
||||||
|
"message": auth_result.error_message,
|
||||||
|
},
|
||||||
|
to=sid,
|
||||||
|
namespace="/game",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Set up authenticated session and register connection
|
||||||
|
await setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game")
|
||||||
|
|
||||||
|
logger.info(f"Client authenticated to /game: sid={sid}, user_id={auth_result.user_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@sio.event(namespace="/game")
|
||||||
|
async def disconnect(sid: str) -> None:
|
||||||
|
"""Handle client disconnection from /game namespace.
|
||||||
|
|
||||||
|
Cleans up connection state and notifies other game participants.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket session ID of disconnecting client.
|
||||||
|
"""
|
||||||
|
# Clean up session and get user info
|
||||||
|
user_id = await cleanup_authenticated_session(sid, namespace="/game")
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
logger.info(f"Client disconnected from /game: sid={sid}, user_id={user_id}")
|
||||||
|
# TODO (WS-005): Notify opponent of disconnection if in game
|
||||||
|
else:
|
||||||
|
logger.debug(f"Unauthenticated client disconnected: {sid}")
|
||||||
|
|
||||||
|
|
||||||
|
@sio.on("game:join", namespace="/game")
|
||||||
|
async def on_game_join(sid: str, data: dict[str, object]) -> dict[str, object]:
|
||||||
|
"""Handle request to join/rejoin a game session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket session ID.
|
||||||
|
data: Message containing game_id and optional last_event_id for resume.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response with game state or error.
|
||||||
|
"""
|
||||||
|
logger.debug(f"game:join from {sid}: {data}")
|
||||||
|
# TODO (WS-005): Implement with GameService
|
||||||
|
return {"error": "Not implemented yet"}
|
||||||
|
|
||||||
|
|
||||||
|
@sio.on("game:action", namespace="/game")
|
||||||
|
async def on_game_action(sid: str, data: dict[str, object]) -> dict[str, object]:
|
||||||
|
"""Handle game action from player.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket session ID.
|
||||||
|
data: Action message with type and parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action result or error.
|
||||||
|
"""
|
||||||
|
logger.debug(f"game:action from {sid}: {data}")
|
||||||
|
# TODO (WS-005): Implement with GameService
|
||||||
|
return {"error": "Not implemented yet"}
|
||||||
|
|
||||||
|
|
||||||
|
@sio.on("game:resign", namespace="/game")
|
||||||
|
async def on_game_resign(sid: str, data: dict[str, object]) -> dict[str, object]:
|
||||||
|
"""Handle player resignation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket session ID.
|
||||||
|
data: Resignation message (may be empty).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Confirmation or error.
|
||||||
|
"""
|
||||||
|
logger.debug(f"game:resign from {sid}: {data}")
|
||||||
|
# TODO (WS-005): Implement with GameService
|
||||||
|
return {"error": "Not implemented yet"}
|
||||||
|
|
||||||
|
|
||||||
|
@sio.on("game:heartbeat", namespace="/game")
|
||||||
|
async def on_game_heartbeat(sid: str, data: dict[str, object] | None = None) -> dict[str, object]:
|
||||||
|
"""Handle heartbeat to keep connection alive.
|
||||||
|
|
||||||
|
Updates last_seen timestamp in ConnectionManager to prevent
|
||||||
|
the connection from being marked as stale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid: Socket session ID.
|
||||||
|
data: Optional heartbeat data (message_id for tracking).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Heartbeat acknowledgment with server timestamp.
|
||||||
|
"""
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from app.services.connection_manager import connection_manager
|
||||||
|
|
||||||
|
# Require authentication
|
||||||
|
user_id = await require_auth(sio, sid)
|
||||||
|
if not user_id:
|
||||||
|
return {"error": "Not authenticated", "code": "unauthenticated"}
|
||||||
|
|
||||||
|
# Update last_seen in ConnectionManager
|
||||||
|
await connection_manager.update_heartbeat(sid)
|
||||||
|
|
||||||
|
# Return acknowledgment with timestamp
|
||||||
|
return {
|
||||||
|
"type": "heartbeat_ack",
|
||||||
|
"timestamp": datetime.now(UTC).isoformat(),
|
||||||
|
"message_id": data.get("message_id") if data else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# ASGI App Creation
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def create_socketio_app(fastapi_app: "FastAPI") -> socketio.ASGIApp:
|
||||||
|
"""Create combined ASGI app with Socket.IO mounted alongside FastAPI.
|
||||||
|
|
||||||
|
This wraps the FastAPI app with Socket.IO, handling:
|
||||||
|
- WebSocket connections at /socket.io/
|
||||||
|
- HTTP requests passed through to FastAPI
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fastapi_app: The FastAPI application instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined ASGI application to use with uvicorn.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
app = FastAPI()
|
||||||
|
combined = create_socketio_app(app)
|
||||||
|
# uvicorn will use 'combined' as the ASGI app
|
||||||
|
"""
|
||||||
|
return socketio.ASGIApp(
|
||||||
|
sio,
|
||||||
|
other_asgi_app=fastapi_app,
|
||||||
|
socketio_path="socket.io",
|
||||||
|
)
|
||||||
@ -9,8 +9,8 @@
|
|||||||
"description": "Real-time gameplay infrastructure - WebSocket communication, game lifecycle management, reconnection handling, and turn timeout system",
|
"description": "Real-time gameplay infrastructure - WebSocket communication, game lifecycle management, reconnection handling, and turn timeout system",
|
||||||
"totalEstimatedHours": 45,
|
"totalEstimatedHours": 45,
|
||||||
"totalTasks": 18,
|
"totalTasks": 18,
|
||||||
"completedTasks": 0,
|
"completedTasks": 5,
|
||||||
"status": "not_started",
|
"status": "in_progress",
|
||||||
"masterPlan": "../PROJECT_PLAN_MASTER.json"
|
"masterPlan": "../PROJECT_PLAN_MASTER.json"
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -105,8 +105,8 @@
|
|||||||
"description": "Install and configure python-socketio ASGI server, mount alongside FastAPI app",
|
"description": "Install and configure python-socketio ASGI server, mount alongside FastAPI app",
|
||||||
"category": "infrastructure",
|
"category": "infrastructure",
|
||||||
"priority": 1,
|
"priority": 1,
|
||||||
"completed": false,
|
"completed": true,
|
||||||
"tested": false,
|
"tested": true,
|
||||||
"dependencies": [],
|
"dependencies": [],
|
||||||
"files": [
|
"files": [
|
||||||
{"path": "app/socketio/__init__.py", "status": "create"},
|
{"path": "app/socketio/__init__.py", "status": "create"},
|
||||||
@ -130,8 +130,8 @@
|
|||||||
"description": "Define Pydantic models for all WebSocket message types",
|
"description": "Define Pydantic models for all WebSocket message types",
|
||||||
"category": "schemas",
|
"category": "schemas",
|
||||||
"priority": 2,
|
"priority": 2,
|
||||||
"completed": false,
|
"completed": true,
|
||||||
"tested": false,
|
"tested": true,
|
||||||
"dependencies": ["WS-001"],
|
"dependencies": ["WS-001"],
|
||||||
"files": [
|
"files": [
|
||||||
{"path": "app/schemas/ws_messages.py", "status": "create"}
|
{"path": "app/schemas/ws_messages.py", "status": "create"}
|
||||||
@ -153,8 +153,8 @@
|
|||||||
"description": "Manage WebSocket connections with Redis-backed session tracking",
|
"description": "Manage WebSocket connections with Redis-backed session tracking",
|
||||||
"category": "services",
|
"category": "services",
|
||||||
"priority": 3,
|
"priority": 3,
|
||||||
"completed": false,
|
"completed": true,
|
||||||
"tested": false,
|
"tested": true,
|
||||||
"dependencies": ["WS-001"],
|
"dependencies": ["WS-001"],
|
||||||
"files": [
|
"files": [
|
||||||
{"path": "app/services/connection_manager.py", "status": "create"}
|
{"path": "app/services/connection_manager.py", "status": "create"}
|
||||||
@ -177,8 +177,8 @@
|
|||||||
"description": "Authenticate WebSocket connections using JWT tokens",
|
"description": "Authenticate WebSocket connections using JWT tokens",
|
||||||
"category": "auth",
|
"category": "auth",
|
||||||
"priority": 4,
|
"priority": 4,
|
||||||
"completed": false,
|
"completed": true,
|
||||||
"tested": false,
|
"tested": true,
|
||||||
"dependencies": ["WS-001", "WS-003"],
|
"dependencies": ["WS-001", "WS-003"],
|
||||||
"files": [
|
"files": [
|
||||||
{"path": "app/socketio/auth.py", "status": "create"},
|
{"path": "app/socketio/auth.py", "status": "create"},
|
||||||
@ -614,6 +614,13 @@
|
|||||||
"risk": "Turn timeout drift due to server restart",
|
"risk": "Turn timeout drift due to server restart",
|
||||||
"mitigation": "Store absolute deadline in Redis/Postgres, recalculate on startup",
|
"mitigation": "Store absolute deadline in Redis/Postgres, recalculate on startup",
|
||||||
"priority": "medium"
|
"priority": "medium"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"risk": "ConnectionManager race condition on rapid reconnects",
|
||||||
|
"mitigation": "Consider Redis Lua script for atomic old-connection cleanup in register_connection. Low probability in practice but could cause connection tracking issues during rapid connect/disconnect cycles.",
|
||||||
|
"priority": "low",
|
||||||
|
"status": "identified-in-review",
|
||||||
|
"notes": "Also consider: periodic cleanup of game_conns sets for long-running games, rate limiting in auth layer"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
||||||
|
|||||||
@ -107,6 +107,11 @@ module = [
|
|||||||
]
|
]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
# Socket.IO handlers use untyped decorators from the library
|
||||||
|
module = "app.socketio.*"
|
||||||
|
disallow_untyped_decorators = false
|
||||||
|
|
||||||
# Coverage configuration
|
# Coverage configuration
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
source = ["app"]
|
source = ["app"]
|
||||||
|
|||||||
1
backend/tests/socketio/__init__.py
Normal file
1
backend/tests/socketio/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Socket.IO integration tests
|
||||||
384
backend/tests/socketio/test_auth.py
Normal file
384
backend/tests/socketio/test_auth.py
Normal file
@ -0,0 +1,384 @@
|
|||||||
|
"""Tests for Socket.IO authentication middleware.
|
||||||
|
|
||||||
|
This module tests JWT-based authentication for WebSocket connections,
|
||||||
|
including token extraction, validation, and session management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.socketio.auth import (
|
||||||
|
AuthResult,
|
||||||
|
authenticate_connection,
|
||||||
|
cleanup_authenticated_session,
|
||||||
|
extract_token,
|
||||||
|
get_session_user_id,
|
||||||
|
require_auth,
|
||||||
|
setup_authenticated_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractToken:
|
||||||
|
"""Tests for token extraction from Socket.IO auth data."""
|
||||||
|
|
||||||
|
def test_extract_token_from_token_field(self) -> None:
|
||||||
|
"""Test extracting token from the primary 'token' field.
|
||||||
|
|
||||||
|
The standard way for clients to pass the JWT is via auth.token.
|
||||||
|
This is the recommended format for Socket.IO clients.
|
||||||
|
"""
|
||||||
|
auth = {"token": "my-jwt-token"}
|
||||||
|
result = extract_token(auth)
|
||||||
|
assert result == "my-jwt-token"
|
||||||
|
|
||||||
|
def test_extract_token_from_authorization_bearer(self) -> None:
|
||||||
|
"""Test extracting token from Bearer authorization header format.
|
||||||
|
|
||||||
|
Some clients may pass the token as a Bearer token for consistency
|
||||||
|
with HTTP API authentication patterns.
|
||||||
|
"""
|
||||||
|
auth = {"authorization": "Bearer my-jwt-token"}
|
||||||
|
result = extract_token(auth)
|
||||||
|
assert result == "my-jwt-token"
|
||||||
|
|
||||||
|
def test_extract_token_from_authorization_without_bearer(self) -> None:
|
||||||
|
"""Test extracting token from authorization without Bearer prefix.
|
||||||
|
|
||||||
|
If the client provides just the token in authorization field,
|
||||||
|
we should still accept it.
|
||||||
|
"""
|
||||||
|
auth = {"authorization": "my-jwt-token"}
|
||||||
|
result = extract_token(auth)
|
||||||
|
assert result == "my-jwt-token"
|
||||||
|
|
||||||
|
def test_extract_token_from_access_token_field(self) -> None:
|
||||||
|
"""Test extracting token from access_token field.
|
||||||
|
|
||||||
|
Alternative field name for OAuth-style clients.
|
||||||
|
"""
|
||||||
|
auth = {"access_token": "my-jwt-token"}
|
||||||
|
result = extract_token(auth)
|
||||||
|
assert result == "my-jwt-token"
|
||||||
|
|
||||||
|
def test_extract_token_returns_none_for_none_auth(self) -> None:
|
||||||
|
"""Test that None auth data returns None token.
|
||||||
|
|
||||||
|
Clients that don't provide any auth should get None,
|
||||||
|
triggering the authentication failure path.
|
||||||
|
"""
|
||||||
|
result = extract_token(None)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_extract_token_returns_none_for_empty_auth(self) -> None:
|
||||||
|
"""Test that empty auth dict returns None token.
|
||||||
|
|
||||||
|
An empty auth object should be treated as unauthenticated.
|
||||||
|
"""
|
||||||
|
result = extract_token({})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_extract_token_returns_none_for_non_string_token(self) -> None:
|
||||||
|
"""Test that non-string token values are rejected.
|
||||||
|
|
||||||
|
Only string tokens are valid - reject numbers, objects, etc.
|
||||||
|
"""
|
||||||
|
result = extract_token({"token": 12345})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
result = extract_token({"token": {"nested": "value"}})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_extract_token_prefers_token_field(self) -> None:
|
||||||
|
"""Test that 'token' field takes precedence over alternatives.
|
||||||
|
|
||||||
|
If multiple token fields are present, we should use the
|
||||||
|
primary 'token' field.
|
||||||
|
"""
|
||||||
|
auth = {
|
||||||
|
"token": "primary-token",
|
||||||
|
"authorization": "Bearer secondary-token",
|
||||||
|
"access_token": "tertiary-token",
|
||||||
|
}
|
||||||
|
result = extract_token(auth)
|
||||||
|
assert result == "primary-token"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthenticateConnection:
|
||||||
|
"""Tests for connection authentication."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_success_with_valid_token(self) -> None:
|
||||||
|
"""Test successful authentication with a valid JWT.
|
||||||
|
|
||||||
|
A valid access token should result in AuthResult with
|
||||||
|
success=True and the user_id from the token.
|
||||||
|
"""
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
with patch("app.socketio.auth.verify_access_token") as mock_verify:
|
||||||
|
mock_verify.return_value = user_id
|
||||||
|
|
||||||
|
result = await authenticate_connection("test-sid", {"token": "valid-token"})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.user_id == user_id
|
||||||
|
assert result.error_code is None
|
||||||
|
mock_verify.assert_called_once_with("valid-token")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_fails_with_missing_token(self) -> None:
|
||||||
|
"""Test authentication failure when no token is provided.
|
||||||
|
|
||||||
|
Connections without any auth data should fail with
|
||||||
|
a 'missing_token' error code.
|
||||||
|
"""
|
||||||
|
result = await authenticate_connection("test-sid", None)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.user_id is None
|
||||||
|
assert result.error_code == "missing_token"
|
||||||
|
assert "required" in result.error_message.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_fails_with_empty_auth(self) -> None:
|
||||||
|
"""Test authentication failure with empty auth object.
|
||||||
|
|
||||||
|
An auth object without any token fields should fail.
|
||||||
|
"""
|
||||||
|
result = await authenticate_connection("test-sid", {})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.error_code == "missing_token"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_fails_with_invalid_token(self) -> None:
|
||||||
|
"""Test authentication failure with invalid/expired JWT.
|
||||||
|
|
||||||
|
When verify_access_token returns None (invalid token),
|
||||||
|
we should fail with 'invalid_token' error.
|
||||||
|
"""
|
||||||
|
with patch("app.socketio.auth.verify_access_token") as mock_verify:
|
||||||
|
mock_verify.return_value = None # Token validation failed
|
||||||
|
|
||||||
|
result = await authenticate_connection("test-sid", {"token": "invalid-token"})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.user_id is None
|
||||||
|
assert result.error_code == "invalid_token"
|
||||||
|
assert (
|
||||||
|
"invalid" in result.error_message.lower()
|
||||||
|
or "expired" in result.error_message.lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_extracts_token_from_bearer(self) -> None:
|
||||||
|
"""Test that authentication works with Bearer format.
|
||||||
|
|
||||||
|
The authenticate function should handle Bearer token format
|
||||||
|
through the extract_token function.
|
||||||
|
"""
|
||||||
|
user_id = uuid4()
|
||||||
|
|
||||||
|
with patch("app.socketio.auth.verify_access_token") as mock_verify:
|
||||||
|
mock_verify.return_value = user_id
|
||||||
|
|
||||||
|
result = await authenticate_connection("test-sid", {"authorization": "Bearer my-token"})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
mock_verify.assert_called_once_with("my-token")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetupAuthenticatedSession:
|
||||||
|
"""Tests for session setup after authentication."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_saves_session_data(self) -> None:
|
||||||
|
"""Test that session setup saves user_id and timestamp.
|
||||||
|
|
||||||
|
After authentication, the socket session should contain
|
||||||
|
the user_id and authentication timestamp.
|
||||||
|
"""
|
||||||
|
user_id = uuid4()
|
||||||
|
mock_sio = AsyncMock()
|
||||||
|
|
||||||
|
with patch("app.socketio.auth.connection_manager") as mock_cm:
|
||||||
|
mock_cm.register_connection = AsyncMock()
|
||||||
|
|
||||||
|
await setup_authenticated_session(mock_sio, "test-sid", user_id)
|
||||||
|
|
||||||
|
# Verify session was saved
|
||||||
|
mock_sio.save_session.assert_called_once()
|
||||||
|
call_args = mock_sio.save_session.call_args
|
||||||
|
assert call_args.args[0] == "test-sid"
|
||||||
|
|
||||||
|
session_data = call_args.args[1]
|
||||||
|
assert session_data["user_id"] == str(user_id)
|
||||||
|
assert "authenticated_at" in session_data
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_registers_with_connection_manager(self) -> None:
|
||||||
|
"""Test that session setup registers with ConnectionManager.
|
||||||
|
|
||||||
|
The connection should be tracked in ConnectionManager for
|
||||||
|
presence detection and game association.
|
||||||
|
"""
|
||||||
|
user_id = uuid4()
|
||||||
|
mock_sio = AsyncMock()
|
||||||
|
|
||||||
|
with patch("app.socketio.auth.connection_manager") as mock_cm:
|
||||||
|
mock_cm.register_connection = AsyncMock()
|
||||||
|
|
||||||
|
await setup_authenticated_session(mock_sio, "test-sid", user_id)
|
||||||
|
|
||||||
|
mock_cm.register_connection.assert_called_once_with("test-sid", user_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupAuthenticatedSession:
|
||||||
|
"""Tests for session cleanup on disconnect."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_unregisters_connection(self) -> None:
|
||||||
|
"""Test that cleanup unregisters from ConnectionManager.
|
||||||
|
|
||||||
|
On disconnect, the connection should be removed from
|
||||||
|
ConnectionManager to update presence tracking.
|
||||||
|
"""
|
||||||
|
with patch("app.socketio.auth.connection_manager") as mock_cm:
|
||||||
|
mock_conn_info = MagicMock()
|
||||||
|
mock_conn_info.user_id = "user-123"
|
||||||
|
mock_cm.unregister_connection = AsyncMock(return_value=mock_conn_info)
|
||||||
|
|
||||||
|
result = await cleanup_authenticated_session("test-sid")
|
||||||
|
|
||||||
|
assert result == "user-123"
|
||||||
|
mock_cm.unregister_connection.assert_called_once_with("test-sid")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_returns_none_for_unknown_session(self) -> None:
|
||||||
|
"""Test cleanup returns None for non-existent sessions.
|
||||||
|
|
||||||
|
If the connection wasn't registered (e.g., auth failed),
|
||||||
|
cleanup should return None gracefully.
|
||||||
|
"""
|
||||||
|
with patch("app.socketio.auth.connection_manager") as mock_cm:
|
||||||
|
mock_cm.unregister_connection = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
result = await cleanup_authenticated_session("unknown-sid")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSessionUserId:
|
||||||
|
"""Tests for session user_id retrieval."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_session_user_id_returns_id(self) -> None:
|
||||||
|
"""Test retrieving user_id from authenticated session.
|
||||||
|
|
||||||
|
For authenticated sessions, get_session_user_id should
|
||||||
|
return the stored user_id string.
|
||||||
|
"""
|
||||||
|
mock_sio = AsyncMock()
|
||||||
|
mock_sio.get_session = AsyncMock(
|
||||||
|
return_value={"user_id": "user-123", "authenticated_at": "2024-01-01"}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await get_session_user_id(mock_sio, "test-sid")
|
||||||
|
|
||||||
|
assert result == "user-123"
|
||||||
|
mock_sio.get_session.assert_called_once_with("test-sid", namespace="/game")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_session_user_id_returns_none_for_missing(self) -> None:
|
||||||
|
"""Test that missing session returns None.
|
||||||
|
|
||||||
|
If no session exists for the sid, we should return None
|
||||||
|
rather than raising an error.
|
||||||
|
"""
|
||||||
|
mock_sio = AsyncMock()
|
||||||
|
mock_sio.get_session = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
result = await get_session_user_id(mock_sio, "test-sid")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_session_user_id_handles_exception(self) -> None:
|
||||||
|
"""Test that exceptions are caught and return None.
|
||||||
|
|
||||||
|
If get_session raises an exception, we should catch it
|
||||||
|
and return None to avoid breaking the event handler.
|
||||||
|
"""
|
||||||
|
mock_sio = AsyncMock()
|
||||||
|
mock_sio.get_session = AsyncMock(side_effect=Exception("Session error"))
|
||||||
|
|
||||||
|
result = await get_session_user_id(mock_sio, "test-sid")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequireAuth:
|
||||||
|
"""Tests for the require_auth helper."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_auth_returns_user_id_for_authenticated(self) -> None:
|
||||||
|
"""Test that require_auth returns user_id for valid sessions.
|
||||||
|
|
||||||
|
Authenticated sessions should return the user_id for use
|
||||||
|
in event handlers.
|
||||||
|
"""
|
||||||
|
mock_sio = AsyncMock()
|
||||||
|
mock_sio.get_session = AsyncMock(return_value={"user_id": "user-123"})
|
||||||
|
|
||||||
|
result = await require_auth(mock_sio, "test-sid")
|
||||||
|
|
||||||
|
assert result == "user-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_auth_returns_none_for_unauthenticated(self) -> None:
|
||||||
|
"""Test that require_auth returns None for unauthenticated sessions.
|
||||||
|
|
||||||
|
Unauthenticated events should get None, allowing handlers
|
||||||
|
to return an error response.
|
||||||
|
"""
|
||||||
|
mock_sio = AsyncMock()
|
||||||
|
mock_sio.get_session = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
result = await require_auth(mock_sio, "test-sid")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthResultDataclass:
|
||||||
|
"""Tests for the AuthResult dataclass."""
|
||||||
|
|
||||||
|
def test_auth_result_success(self) -> None:
|
||||||
|
"""Test creating a successful AuthResult.
|
||||||
|
|
||||||
|
Success results should have user_id and no error fields.
|
||||||
|
"""
|
||||||
|
user_id = uuid4()
|
||||||
|
result = AuthResult(success=True, user_id=user_id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.user_id == user_id
|
||||||
|
assert result.error_code is None
|
||||||
|
assert result.error_message is None
|
||||||
|
|
||||||
|
def test_auth_result_failure(self) -> None:
|
||||||
|
"""Test creating a failed AuthResult.
|
||||||
|
|
||||||
|
Failure results should have error code/message and no user_id.
|
||||||
|
"""
|
||||||
|
result = AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_code="invalid_token",
|
||||||
|
error_message="Token expired",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.user_id is None
|
||||||
|
assert result.error_code == "invalid_token"
|
||||||
|
assert result.error_message == "Token expired"
|
||||||
113
backend/tests/socketio/test_server_setup.py
Normal file
113
backend/tests/socketio/test_server_setup.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
"""Tests for Socket.IO server setup and configuration.
|
||||||
|
|
||||||
|
These tests verify that the Socket.IO server is correctly configured
|
||||||
|
and can be mounted alongside FastAPI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TestSocketIOSetup:
|
||||||
|
"""Tests for Socket.IO server initialization."""
|
||||||
|
|
||||||
|
def test_sio_server_exists(self) -> None:
|
||||||
|
"""
|
||||||
|
Verify that the Socket.IO AsyncServer is created.
|
||||||
|
|
||||||
|
The server instance should be available as a module-level export
|
||||||
|
and configured for ASGI mode.
|
||||||
|
"""
|
||||||
|
from app.socketio import sio
|
||||||
|
|
||||||
|
assert sio is not None
|
||||||
|
assert sio.async_mode == "asgi"
|
||||||
|
|
||||||
|
def test_game_namespace_handlers_registered(self) -> None:
|
||||||
|
"""
|
||||||
|
Verify that /game namespace event handlers are registered.
|
||||||
|
|
||||||
|
All required event handlers should be available on the sio instance
|
||||||
|
before any connections are made.
|
||||||
|
"""
|
||||||
|
from app.socketio import sio
|
||||||
|
|
||||||
|
# Check that handlers are registered for /game namespace
|
||||||
|
assert "/game" in sio.handlers
|
||||||
|
|
||||||
|
game_handlers = sio.handlers["/game"]
|
||||||
|
expected_events = [
|
||||||
|
"connect",
|
||||||
|
"disconnect",
|
||||||
|
"game:join",
|
||||||
|
"game:action",
|
||||||
|
"game:resign",
|
||||||
|
"game:heartbeat",
|
||||||
|
]
|
||||||
|
|
||||||
|
for event in expected_events:
|
||||||
|
assert event in game_handlers, f"Missing handler for {event}"
|
||||||
|
|
||||||
|
def test_create_socketio_app_returns_asgi_app(self) -> None:
|
||||||
|
"""
|
||||||
|
Verify that create_socketio_app returns a valid ASGI application.
|
||||||
|
|
||||||
|
The combined app should wrap the FastAPI app and handle both
|
||||||
|
HTTP and WebSocket connections.
|
||||||
|
"""
|
||||||
|
import socketio
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from app.socketio import create_socketio_app
|
||||||
|
|
||||||
|
# Create a minimal FastAPI app for testing
|
||||||
|
test_app = FastAPI()
|
||||||
|
|
||||||
|
# Create combined ASGI app
|
||||||
|
combined = create_socketio_app(test_app)
|
||||||
|
|
||||||
|
# Verify it's a Socket.IO ASGI app
|
||||||
|
assert isinstance(combined, socketio.ASGIApp)
|
||||||
|
|
||||||
|
def test_cors_configured_from_settings(self) -> None:
|
||||||
|
"""
|
||||||
|
Verify that Socket.IO CORS settings match application settings.
|
||||||
|
|
||||||
|
The Socket.IO server should allow the same origins as FastAPI
|
||||||
|
to ensure consistent cross-origin behavior.
|
||||||
|
"""
|
||||||
|
from app.config import settings
|
||||||
|
from app.socketio import sio
|
||||||
|
|
||||||
|
# Socket.IO stores CORS origins in eio (Engine.IO) settings
|
||||||
|
# The cors_allowed_origins should match settings
|
||||||
|
assert sio.eio.cors_allowed_origins == settings.cors_origins
|
||||||
|
|
||||||
|
|
||||||
|
class TestMainAppIntegration:
|
||||||
|
"""Tests for Socket.IO integration with main app."""
|
||||||
|
|
||||||
|
def test_main_app_is_combined_asgi_app(self) -> None:
|
||||||
|
"""
|
||||||
|
Verify that the main app module exports the combined ASGI app.
|
||||||
|
|
||||||
|
The 'app' variable in main.py should be the Socket.IO wrapped
|
||||||
|
FastAPI application, not the raw FastAPI app.
|
||||||
|
"""
|
||||||
|
import socketio
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
# The exported 'app' should be a Socket.IO ASGI app
|
||||||
|
assert isinstance(app, socketio.ASGIApp)
|
||||||
|
|
||||||
|
def test_fastapi_app_accessible(self) -> None:
|
||||||
|
"""
|
||||||
|
Verify that the underlying FastAPI app is still accessible.
|
||||||
|
|
||||||
|
The FastAPI app should be available as _fastapi_app in main
|
||||||
|
for testing and direct access when needed.
|
||||||
|
"""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from app.main import _fastapi_app
|
||||||
|
|
||||||
|
assert isinstance(_fastapi_app, FastAPI)
|
||||||
|
assert _fastapi_app.title == "Mantimon TCG"
|
||||||
1
backend/tests/unit/schemas/__init__.py
Normal file
1
backend/tests/unit/schemas/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Unit tests for schemas
|
||||||
701
backend/tests/unit/schemas/test_ws_messages.py
Normal file
701
backend/tests/unit/schemas/test_ws_messages.py
Normal file
@ -0,0 +1,701 @@
|
|||||||
|
"""Tests for WebSocket message schemas.
|
||||||
|
|
||||||
|
This module tests the Pydantic models for WebSocket communication, verifying
|
||||||
|
discriminated union parsing, field validation, and serialization behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
|
from app.core.enums import GameEndReason, TurnPhase
|
||||||
|
from app.core.models.actions import AttackAction, PassAction
|
||||||
|
from app.core.visibility import VisibleGameState, VisiblePlayerState
|
||||||
|
from app.schemas.ws_messages import (
|
||||||
|
ActionMessage,
|
||||||
|
ActionResultMessage,
|
||||||
|
ClientMessage,
|
||||||
|
ConnectionStatus,
|
||||||
|
ErrorMessage,
|
||||||
|
GameOverMessage,
|
||||||
|
GameStateMessage,
|
||||||
|
HeartbeatAckMessage,
|
||||||
|
HeartbeatMessage,
|
||||||
|
JoinGameMessage,
|
||||||
|
OpponentStatusMessage,
|
||||||
|
ResignMessage,
|
||||||
|
ServerMessage,
|
||||||
|
TurnStartMessage,
|
||||||
|
TurnTimeoutMessage,
|
||||||
|
WSErrorCode,
|
||||||
|
parse_client_message,
|
||||||
|
parse_server_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestClientMessageDiscriminator:
|
||||||
|
"""Tests for client message discriminated union parsing.
|
||||||
|
|
||||||
|
The discriminated union pattern allows automatic type resolution based on
|
||||||
|
the 'type' field, enabling type-safe message handling without explicit
|
||||||
|
type checking logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_parse_join_game_message(self) -> None:
|
||||||
|
"""Test parsing JoinGameMessage from dictionary.
|
||||||
|
|
||||||
|
JoinGameMessage is the first message clients send to establish their
|
||||||
|
presence in a game session. The discriminator should correctly resolve
|
||||||
|
the type from the raw dictionary.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "join_game",
|
||||||
|
"message_id": "test-123",
|
||||||
|
"game_id": "game-456",
|
||||||
|
}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
|
||||||
|
assert isinstance(message, JoinGameMessage)
|
||||||
|
assert message.game_id == "game-456"
|
||||||
|
assert message.message_id == "test-123"
|
||||||
|
assert message.last_event_id is None
|
||||||
|
|
||||||
|
def test_parse_join_game_with_last_event_id(self) -> None:
|
||||||
|
"""Test JoinGameMessage with reconnection event ID.
|
||||||
|
|
||||||
|
When reconnecting after a disconnect, clients provide last_event_id
|
||||||
|
to receive any missed events since that point, enabling seamless
|
||||||
|
session recovery.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "join_game",
|
||||||
|
"message_id": "test-123",
|
||||||
|
"game_id": "game-456",
|
||||||
|
"last_event_id": "event-789",
|
||||||
|
}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
|
||||||
|
assert isinstance(message, JoinGameMessage)
|
||||||
|
assert message.last_event_id == "event-789"
|
||||||
|
|
||||||
|
def test_parse_action_message(self) -> None:
|
||||||
|
"""Test parsing ActionMessage with embedded game action.
|
||||||
|
|
||||||
|
ActionMessage wraps the Action union type, so the discriminator must
|
||||||
|
handle both the message type and the nested action type correctly.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "action",
|
||||||
|
"message_id": "test-123",
|
||||||
|
"game_id": "game-456",
|
||||||
|
"action": {"type": "attack", "attack_index": 0},
|
||||||
|
}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
|
||||||
|
assert isinstance(message, ActionMessage)
|
||||||
|
assert message.game_id == "game-456"
|
||||||
|
assert isinstance(message.action, AttackAction)
|
||||||
|
assert message.action.attack_index == 0
|
||||||
|
|
||||||
|
def test_parse_action_message_with_pass(self) -> None:
|
||||||
|
"""Test parsing ActionMessage with PassAction.
|
||||||
|
|
||||||
|
PassAction is the simplest action type - verifies that minimal
|
||||||
|
action payloads are handled correctly.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "action",
|
||||||
|
"message_id": "test-123",
|
||||||
|
"game_id": "game-456",
|
||||||
|
"action": {"type": "pass"},
|
||||||
|
}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
|
||||||
|
assert isinstance(message, ActionMessage)
|
||||||
|
assert isinstance(message.action, PassAction)
|
||||||
|
|
||||||
|
def test_parse_resign_message(self) -> None:
|
||||||
|
"""Test parsing ResignMessage.
|
||||||
|
|
||||||
|
ResignMessage allows resignation at the WebSocket layer, separate from
|
||||||
|
the game engine's ResignAction, for cases where engine interaction
|
||||||
|
may not be possible.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "resign",
|
||||||
|
"message_id": "test-123",
|
||||||
|
"game_id": "game-456",
|
||||||
|
}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
|
||||||
|
assert isinstance(message, ResignMessage)
|
||||||
|
assert message.game_id == "game-456"
|
||||||
|
|
||||||
|
def test_parse_heartbeat_message(self) -> None:
|
||||||
|
"""Test parsing HeartbeatMessage.
|
||||||
|
|
||||||
|
Heartbeat messages keep the WebSocket connection alive and should
|
||||||
|
be the most minimal message type with just the type discriminator.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "heartbeat",
|
||||||
|
"message_id": "test-123",
|
||||||
|
}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
|
||||||
|
assert isinstance(message, HeartbeatMessage)
|
||||||
|
|
||||||
|
def test_unknown_message_type_raises_error(self) -> None:
|
||||||
|
"""Test that unknown message types raise ValidationError.
|
||||||
|
|
||||||
|
The discriminated union should fail fast on unknown types rather
|
||||||
|
than silently accepting invalid messages.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "unknown_type",
|
||||||
|
"message_id": "test-123",
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
parse_client_message(data)
|
||||||
|
|
||||||
|
def test_message_id_auto_generated(self) -> None:
|
||||||
|
"""Test that message_id is auto-generated if not provided.
|
||||||
|
|
||||||
|
Clients may omit message_id in which case a UUID should be generated,
|
||||||
|
ensuring all messages have a unique identifier for tracking.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "heartbeat",
|
||||||
|
}
|
||||||
|
message = parse_client_message(data)
|
||||||
|
|
||||||
|
assert message.message_id is not None
|
||||||
|
assert len(message.message_id) > 0
|
||||||
|
|
||||||
|
def test_empty_message_id_rejected(self) -> None:
|
||||||
|
"""Test that empty message_id is rejected.
|
||||||
|
|
||||||
|
An explicitly empty message_id would break idempotency tracking,
|
||||||
|
so it should be rejected by validation.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "heartbeat",
|
||||||
|
"message_id": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
parse_client_message(data)
|
||||||
|
|
||||||
|
def test_type_adapter_client_message(self) -> None:
|
||||||
|
"""Test ClientMessage union with TypeAdapter.
|
||||||
|
|
||||||
|
TypeAdapter is the recommended way to work with discriminated unions
|
||||||
|
in Pydantic v2 - verifies it works for both parsing and validation.
|
||||||
|
"""
|
||||||
|
adapter = TypeAdapter(ClientMessage)
|
||||||
|
|
||||||
|
data = {"type": "heartbeat", "message_id": "test"}
|
||||||
|
message = adapter.validate_python(data)
|
||||||
|
|
||||||
|
assert isinstance(message, HeartbeatMessage)
|
||||||
|
|
||||||
|
|
||||||
|
class TestServerMessageDiscriminator:
|
||||||
|
"""Tests for server message discriminated union parsing.
|
||||||
|
|
||||||
|
Server messages are more complex than client messages, often containing
|
||||||
|
nested game state. The discriminator must handle all message types and
|
||||||
|
their embedded data correctly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_parse_game_state_message(self) -> None:
|
||||||
|
"""Test parsing GameStateMessage with full visible state.
|
||||||
|
|
||||||
|
GameStateMessage carries the complete game state from a player's
|
||||||
|
perspective, used for initial sync and full state updates.
|
||||||
|
"""
|
||||||
|
visible_state = VisibleGameState(
|
||||||
|
game_id="game-456",
|
||||||
|
viewer_id="player-1",
|
||||||
|
players={
|
||||||
|
"player-1": VisiblePlayerState(player_id="player-1"),
|
||||||
|
"player-2": VisiblePlayerState(player_id="player-2"),
|
||||||
|
},
|
||||||
|
current_player_id="player-1",
|
||||||
|
turn_number=1,
|
||||||
|
phase=TurnPhase.MAIN,
|
||||||
|
is_my_turn=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = GameStateMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
state=visible_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, GameStateMessage)
|
||||||
|
assert parsed.game_id == "game-456"
|
||||||
|
assert parsed.state.viewer_id == "player-1"
|
||||||
|
assert parsed.event_id is not None
|
||||||
|
|
||||||
|
def test_parse_action_result_success(self) -> None:
|
||||||
|
"""Test parsing successful ActionResultMessage.
|
||||||
|
|
||||||
|
Successful action results should include the changes that occurred
|
||||||
|
as a result of the action for client-side state updates.
|
||||||
|
"""
|
||||||
|
message = ActionResultMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
request_message_id="action-123",
|
||||||
|
success=True,
|
||||||
|
action_type="attack",
|
||||||
|
changes={"damage_dealt": 30, "defender_hp": 70},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, ActionResultMessage)
|
||||||
|
assert parsed.success is True
|
||||||
|
assert parsed.changes["damage_dealt"] == 30
|
||||||
|
assert parsed.error_code is None
|
||||||
|
|
||||||
|
def test_parse_action_result_failure(self) -> None:
|
||||||
|
"""Test parsing failed ActionResultMessage.
|
||||||
|
|
||||||
|
Failed action results should include error code and message
|
||||||
|
for client-side error handling and user feedback.
|
||||||
|
"""
|
||||||
|
message = ActionResultMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
request_message_id="action-123",
|
||||||
|
success=False,
|
||||||
|
action_type="attack",
|
||||||
|
error_code=WSErrorCode.NOT_YOUR_TURN,
|
||||||
|
error_message="It's not your turn",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, ActionResultMessage)
|
||||||
|
assert parsed.success is False
|
||||||
|
assert parsed.error_code == WSErrorCode.NOT_YOUR_TURN
|
||||||
|
|
||||||
|
def test_parse_error_message(self) -> None:
|
||||||
|
"""Test parsing ErrorMessage.
|
||||||
|
|
||||||
|
Error messages are for protocol-level errors not tied to a specific
|
||||||
|
action, like authentication failures or invalid message format.
|
||||||
|
"""
|
||||||
|
message = ErrorMessage(
|
||||||
|
code=WSErrorCode.AUTHENTICATION_FAILED,
|
||||||
|
message="Invalid token",
|
||||||
|
details={"reason": "expired"},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, ErrorMessage)
|
||||||
|
assert parsed.code == WSErrorCode.AUTHENTICATION_FAILED
|
||||||
|
assert parsed.details["reason"] == "expired"
|
||||||
|
|
||||||
|
def test_parse_turn_start_message(self) -> None:
|
||||||
|
"""Test parsing TurnStartMessage.
|
||||||
|
|
||||||
|
TurnStartMessage notifies both players when a turn begins,
|
||||||
|
enabling UI updates like starting a turn timer.
|
||||||
|
"""
|
||||||
|
message = TurnStartMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
player_id="player-1",
|
||||||
|
turn_number=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, TurnStartMessage)
|
||||||
|
assert parsed.turn_number == 5
|
||||||
|
assert parsed.player_id == "player-1"
|
||||||
|
|
||||||
|
def test_parse_turn_timeout_warning(self) -> None:
|
||||||
|
"""Test parsing TurnTimeoutMessage as warning.
|
||||||
|
|
||||||
|
Timeout warnings give players a chance to complete their action
|
||||||
|
before the turn is forcibly ended.
|
||||||
|
"""
|
||||||
|
message = TurnTimeoutMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
remaining_seconds=30,
|
||||||
|
is_warning=True,
|
||||||
|
player_id="player-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, TurnTimeoutMessage)
|
||||||
|
assert parsed.is_warning is True
|
||||||
|
assert parsed.remaining_seconds == 30
|
||||||
|
|
||||||
|
def test_parse_turn_timeout_expired(self) -> None:
|
||||||
|
"""Test parsing TurnTimeoutMessage as expired.
|
||||||
|
|
||||||
|
When is_warning is False, the timeout has occurred and the game
|
||||||
|
engine will force the turn to end.
|
||||||
|
"""
|
||||||
|
message = TurnTimeoutMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
remaining_seconds=0,
|
||||||
|
is_warning=False,
|
||||||
|
player_id="player-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, TurnTimeoutMessage)
|
||||||
|
assert parsed.is_warning is False
|
||||||
|
assert parsed.remaining_seconds == 0
|
||||||
|
|
||||||
|
def test_parse_game_over_with_winner(self) -> None:
|
||||||
|
"""Test parsing GameOverMessage with a winner.
|
||||||
|
|
||||||
|
Most games end with a winner - the final state is included so
|
||||||
|
clients can display the final board position.
|
||||||
|
"""
|
||||||
|
final_state = VisibleGameState(
|
||||||
|
game_id="game-456",
|
||||||
|
viewer_id="player-1",
|
||||||
|
winner_id="player-1",
|
||||||
|
end_reason=GameEndReason.PRIZES_TAKEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = GameOverMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
winner_id="player-1",
|
||||||
|
end_reason=GameEndReason.PRIZES_TAKEN,
|
||||||
|
final_state=final_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, GameOverMessage)
|
||||||
|
assert parsed.winner_id == "player-1"
|
||||||
|
assert parsed.end_reason == GameEndReason.PRIZES_TAKEN
|
||||||
|
|
||||||
|
def test_parse_game_over_draw(self) -> None:
|
||||||
|
"""Test parsing GameOverMessage as a draw.
|
||||||
|
|
||||||
|
Games can end in a draw (e.g., timeout with equal scores),
|
||||||
|
in which case winner_id is None.
|
||||||
|
"""
|
||||||
|
final_state = VisibleGameState(
|
||||||
|
game_id="game-456",
|
||||||
|
viewer_id="player-1",
|
||||||
|
end_reason=GameEndReason.DRAW,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = GameOverMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
winner_id=None,
|
||||||
|
end_reason=GameEndReason.DRAW,
|
||||||
|
final_state=final_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, GameOverMessage)
|
||||||
|
assert parsed.winner_id is None
|
||||||
|
assert parsed.end_reason == GameEndReason.DRAW
|
||||||
|
|
||||||
|
def test_parse_opponent_status_connected(self) -> None:
|
||||||
|
"""Test parsing OpponentStatusMessage for connection.
|
||||||
|
|
||||||
|
Notifies players when their opponent connects, enabling
|
||||||
|
connection status display in the UI.
|
||||||
|
"""
|
||||||
|
message = OpponentStatusMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
opponent_id="player-2",
|
||||||
|
status=ConnectionStatus.CONNECTED,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, OpponentStatusMessage)
|
||||||
|
assert parsed.status == ConnectionStatus.CONNECTED
|
||||||
|
|
||||||
|
def test_parse_opponent_status_disconnected(self) -> None:
|
||||||
|
"""Test parsing OpponentStatusMessage for disconnection.
|
||||||
|
|
||||||
|
When an opponent disconnects, the UI can show a waiting indicator
|
||||||
|
and possibly pause the game timer.
|
||||||
|
"""
|
||||||
|
message = OpponentStatusMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
opponent_id="player-2",
|
||||||
|
status=ConnectionStatus.DISCONNECTED,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, OpponentStatusMessage)
|
||||||
|
assert parsed.status == ConnectionStatus.DISCONNECTED
|
||||||
|
|
||||||
|
def test_parse_heartbeat_ack(self) -> None:
|
||||||
|
"""Test parsing HeartbeatAckMessage.
|
||||||
|
|
||||||
|
HeartbeatAck confirms the connection is alive, allowing clients
|
||||||
|
to measure round-trip latency.
|
||||||
|
"""
|
||||||
|
message = HeartbeatAckMessage()
|
||||||
|
|
||||||
|
data = message.model_dump()
|
||||||
|
parsed = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(parsed, HeartbeatAckMessage)
|
||||||
|
|
||||||
|
def test_server_message_has_timestamp(self) -> None:
|
||||||
|
"""Test that all server messages have a timestamp.
|
||||||
|
|
||||||
|
Timestamps enable client-side latency tracking and event ordering
|
||||||
|
in case of out-of-order message delivery.
|
||||||
|
"""
|
||||||
|
message = HeartbeatAckMessage()
|
||||||
|
|
||||||
|
assert message.timestamp is not None
|
||||||
|
assert isinstance(message.timestamp, datetime)
|
||||||
|
# Should be UTC
|
||||||
|
assert message.timestamp.tzinfo == UTC
|
||||||
|
|
||||||
|
def test_type_adapter_server_message(self) -> None:
|
||||||
|
"""Test ServerMessage union with TypeAdapter.
|
||||||
|
|
||||||
|
Verifies TypeAdapter works correctly for the more complex
|
||||||
|
server message union with nested types.
|
||||||
|
"""
|
||||||
|
adapter = TypeAdapter(ServerMessage)
|
||||||
|
|
||||||
|
error_data = {
|
||||||
|
"type": "error",
|
||||||
|
"message_id": "test",
|
||||||
|
"timestamp": datetime.now(UTC).isoformat(),
|
||||||
|
"code": "game_not_found",
|
||||||
|
"message": "Game not found",
|
||||||
|
}
|
||||||
|
message = adapter.validate_python(error_data)
|
||||||
|
|
||||||
|
assert isinstance(message, ErrorMessage)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageSerialization:
|
||||||
|
"""Tests for message JSON serialization and round-trip behavior.
|
||||||
|
|
||||||
|
Messages must serialize cleanly to JSON for WebSocket transmission
|
||||||
|
and deserialize back to the same logical structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_client_message_round_trip(self) -> None:
|
||||||
|
"""Test client message JSON round-trip.
|
||||||
|
|
||||||
|
Messages should survive serialization to JSON and back without
|
||||||
|
losing data or changing types.
|
||||||
|
"""
|
||||||
|
original = ActionMessage(
|
||||||
|
message_id="test-123",
|
||||||
|
game_id="game-456",
|
||||||
|
action=AttackAction(attack_index=1, targets=["target-1"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
json_str = original.model_dump_json()
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = json.loads(json_str)
|
||||||
|
restored = parse_client_message(data)
|
||||||
|
|
||||||
|
assert isinstance(restored, ActionMessage)
|
||||||
|
assert restored.message_id == original.message_id
|
||||||
|
assert isinstance(restored.action, AttackAction)
|
||||||
|
assert restored.action.attack_index == 1
|
||||||
|
assert restored.action.targets == ["target-1"]
|
||||||
|
|
||||||
|
def test_server_message_round_trip(self) -> None:
|
||||||
|
"""Test server message JSON round-trip.
|
||||||
|
|
||||||
|
Server messages with nested state should serialize and deserialize
|
||||||
|
correctly, preserving all game state information.
|
||||||
|
"""
|
||||||
|
visible_state = VisibleGameState(
|
||||||
|
game_id="game-456",
|
||||||
|
viewer_id="player-1",
|
||||||
|
turn_number=3,
|
||||||
|
phase=TurnPhase.ATTACK,
|
||||||
|
)
|
||||||
|
|
||||||
|
original = GameStateMessage(
|
||||||
|
message_id="test-123",
|
||||||
|
game_id="game-456",
|
||||||
|
state=visible_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
json_str = original.model_dump_json()
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = json.loads(json_str)
|
||||||
|
restored = parse_server_message(data)
|
||||||
|
|
||||||
|
assert isinstance(restored, GameStateMessage)
|
||||||
|
assert restored.state.turn_number == 3
|
||||||
|
assert restored.state.phase == TurnPhase.ATTACK
|
||||||
|
|
||||||
|
def test_error_code_serializes_as_string(self) -> None:
|
||||||
|
"""Test that WSErrorCode serializes as string value.
|
||||||
|
|
||||||
|
Error codes should serialize to their string values for clean
|
||||||
|
JSON output that clients can easily handle.
|
||||||
|
"""
|
||||||
|
message = ErrorMessage(
|
||||||
|
code=WSErrorCode.GAME_NOT_FOUND,
|
||||||
|
message="Game not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
json_str = message.model_dump_json()
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = json.loads(json_str)
|
||||||
|
|
||||||
|
assert data["code"] == "game_not_found"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageValidation:
|
||||||
|
"""Tests for message field validation rules.
|
||||||
|
|
||||||
|
Validation ensures messages contain valid data before processing,
|
||||||
|
preventing invalid state from propagating through the system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_join_game_requires_game_id(self) -> None:
|
||||||
|
"""Test that JoinGameMessage requires game_id.
|
||||||
|
|
||||||
|
game_id is mandatory - clients must specify which game to join.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "join_game",
|
||||||
|
"message_id": "test-123",
|
||||||
|
# missing game_id
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
parse_client_message(data)
|
||||||
|
|
||||||
|
def test_action_message_requires_action(self) -> None:
|
||||||
|
"""Test that ActionMessage requires action field.
|
||||||
|
|
||||||
|
The action field is mandatory - messages without an action
|
||||||
|
cannot be processed.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"type": "action",
|
||||||
|
"message_id": "test-123",
|
||||||
|
"game_id": "game-456",
|
||||||
|
# missing action
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
parse_client_message(data)
|
||||||
|
|
||||||
|
def test_turn_number_must_be_positive(self) -> None:
|
||||||
|
"""Test that turn_number must be >= 1.
|
||||||
|
|
||||||
|
Turn numbers start at 1 - zero or negative values are invalid.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TurnStartMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
player_id="player-1",
|
||||||
|
turn_number=0, # Invalid
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_remaining_seconds_non_negative(self) -> None:
|
||||||
|
"""Test that remaining_seconds must be >= 0.
|
||||||
|
|
||||||
|
Negative timeout values are meaningless and should be rejected.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
TurnTimeoutMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
remaining_seconds=-1, # Invalid
|
||||||
|
is_warning=True,
|
||||||
|
player_id="player-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_action_result_requires_request_id(self) -> None:
|
||||||
|
"""Test that ActionResultMessage requires request_message_id.
|
||||||
|
|
||||||
|
Results must link back to the original request for client-side
|
||||||
|
correlation of actions and responses.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
ActionResultMessage(
|
||||||
|
game_id="game-456",
|
||||||
|
# missing request_message_id
|
||||||
|
success=True,
|
||||||
|
action_type="attack",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWSErrorCode:
|
||||||
|
"""Tests for WebSocket error code enumeration.
|
||||||
|
|
||||||
|
Error codes provide machine-readable error classification for
|
||||||
|
programmatic error handling on the client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_all_error_codes_are_strings(self) -> None:
|
||||||
|
"""Test that all error codes are string values.
|
||||||
|
|
||||||
|
StrEnum ensures all values serialize as strings for JSON
|
||||||
|
compatibility.
|
||||||
|
"""
|
||||||
|
for code in WSErrorCode:
|
||||||
|
assert isinstance(code.value, str)
|
||||||
|
|
||||||
|
def test_error_codes_are_lowercase_snake_case(self) -> None:
|
||||||
|
"""Test error code naming convention.
|
||||||
|
|
||||||
|
Error codes follow lowercase_snake_case for consistency with
|
||||||
|
other API identifiers and JSON conventions.
|
||||||
|
"""
|
||||||
|
for code in WSErrorCode:
|
||||||
|
assert code.value == code.value.lower()
|
||||||
|
assert " " not in code.value
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionStatus:
|
||||||
|
"""Tests for connection status enumeration."""
|
||||||
|
|
||||||
|
def test_all_statuses_are_strings(self) -> None:
|
||||||
|
"""Test that all connection statuses are string values."""
|
||||||
|
for status in ConnectionStatus:
|
||||||
|
assert isinstance(status.value, str)
|
||||||
|
|
||||||
|
def test_expected_statuses_exist(self) -> None:
|
||||||
|
"""Test that expected connection statuses are defined.
|
||||||
|
|
||||||
|
The three key states (connected, disconnected, reconnecting)
|
||||||
|
cover all connection lifecycle phases.
|
||||||
|
"""
|
||||||
|
assert ConnectionStatus.CONNECTED
|
||||||
|
assert ConnectionStatus.DISCONNECTED
|
||||||
|
assert ConnectionStatus.RECONNECTING
|
||||||
665
backend/tests/unit/services/test_connection_manager.py
Normal file
665
backend/tests/unit/services/test_connection_manager.py
Normal file
@ -0,0 +1,665 @@
|
|||||||
|
"""Tests for ConnectionManager service.
|
||||||
|
|
||||||
|
This module tests WebSocket connection tracking with Redis. Since these are
|
||||||
|
unit tests, we mock the Redis operations to test the ConnectionManager logic
|
||||||
|
without requiring a real Redis instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.connection_manager import (
|
||||||
|
CONN_PREFIX,
|
||||||
|
GAME_CONNS_PREFIX,
|
||||||
|
HEARTBEAT_INTERVAL_SECONDS,
|
||||||
|
USER_CONN_PREFIX,
|
||||||
|
ConnectionInfo,
|
||||||
|
ConnectionManager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager() -> ConnectionManager:
|
||||||
|
"""Create a ConnectionManager instance for testing."""
|
||||||
|
return ConnectionManager(conn_ttl_seconds=3600)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis() -> AsyncMock:
|
||||||
|
"""Create a mock Redis client."""
|
||||||
|
redis = AsyncMock()
|
||||||
|
redis.hset = AsyncMock()
|
||||||
|
redis.hget = AsyncMock()
|
||||||
|
redis.hgetall = AsyncMock(return_value={})
|
||||||
|
redis.set = AsyncMock()
|
||||||
|
redis.get = AsyncMock(return_value=None)
|
||||||
|
redis.delete = AsyncMock()
|
||||||
|
redis.exists = AsyncMock(return_value=False)
|
||||||
|
redis.expire = AsyncMock()
|
||||||
|
redis.sadd = AsyncMock()
|
||||||
|
redis.srem = AsyncMock()
|
||||||
|
redis.smembers = AsyncMock(return_value=set())
|
||||||
|
redis.scard = AsyncMock(return_value=0)
|
||||||
|
return redis
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionInfoDataclass:
|
||||||
|
"""Tests for the ConnectionInfo dataclass."""
|
||||||
|
|
||||||
|
def test_is_stale_returns_false_for_recent_connection(self) -> None:
|
||||||
|
"""Test that recent connections are not marked as stale.
|
||||||
|
|
||||||
|
Connections that have been seen within the threshold should be
|
||||||
|
considered active, not stale.
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
info = ConnectionInfo(
|
||||||
|
sid="test-sid",
|
||||||
|
user_id="user-123",
|
||||||
|
game_id=None,
|
||||||
|
connected_at=now,
|
||||||
|
last_seen=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert info.is_stale() is False
|
||||||
|
|
||||||
|
def test_is_stale_returns_true_for_old_connection(self) -> None:
|
||||||
|
"""Test that old connections are marked as stale.
|
||||||
|
|
||||||
|
Connections that haven't been seen for longer than the threshold
|
||||||
|
should be considered stale and eligible for cleanup.
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
old_time = now - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4)
|
||||||
|
info = ConnectionInfo(
|
||||||
|
sid="test-sid",
|
||||||
|
user_id="user-123",
|
||||||
|
game_id=None,
|
||||||
|
connected_at=old_time,
|
||||||
|
last_seen=old_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert info.is_stale() is True
|
||||||
|
|
||||||
|
def test_is_stale_with_custom_threshold(self) -> None:
|
||||||
|
"""Test is_stale with a custom threshold.
|
||||||
|
|
||||||
|
The threshold can be adjusted for different use cases like
|
||||||
|
more aggressive cleanup or more lenient timeout.
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
last_seen = now - timedelta(seconds=60)
|
||||||
|
info = ConnectionInfo(
|
||||||
|
sid="test-sid",
|
||||||
|
user_id="user-123",
|
||||||
|
game_id=None,
|
||||||
|
connected_at=now,
|
||||||
|
last_seen=last_seen,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 60 seconds old, with 30 second threshold = stale
|
||||||
|
assert info.is_stale(threshold_seconds=30) is True
|
||||||
|
|
||||||
|
# 60 seconds old, with 120 second threshold = not stale
|
||||||
|
assert info.is_stale(threshold_seconds=120) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegisterConnection:
|
||||||
|
"""Tests for connection registration."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_connection_creates_records(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that registering a connection creates all necessary Redis records.
|
||||||
|
|
||||||
|
Registration should create:
|
||||||
|
1. Connection hash with user_id, game_id, timestamps
|
||||||
|
2. User-to-connection mapping
|
||||||
|
"""
|
||||||
|
sid = "test-sid-123"
|
||||||
|
user_id = str(uuid4())
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
await manager.register_connection(sid, user_id)
|
||||||
|
|
||||||
|
# Verify connection hash was created
|
||||||
|
conn_key = f"{CONN_PREFIX}{sid}"
|
||||||
|
mock_redis.hset.assert_called()
|
||||||
|
hset_call = mock_redis.hset.call_args
|
||||||
|
assert hset_call.args[0] == conn_key
|
||||||
|
mapping = hset_call.kwargs["mapping"]
|
||||||
|
assert mapping["user_id"] == user_id
|
||||||
|
assert mapping["game_id"] == ""
|
||||||
|
|
||||||
|
# Verify user-to-connection mapping was created
|
||||||
|
user_conn_key = f"{USER_CONN_PREFIX}{user_id}"
|
||||||
|
mock_redis.set.assert_called_with(user_conn_key, sid)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_connection_replaces_old_connection(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that registering a new connection replaces the old one.
|
||||||
|
|
||||||
|
When a user reconnects, their old connection should be cleaned up
|
||||||
|
to prevent stale connection data from lingering.
|
||||||
|
"""
|
||||||
|
old_sid = "old-sid"
|
||||||
|
new_sid = "new-sid"
|
||||||
|
user_id = str(uuid4())
|
||||||
|
|
||||||
|
# Mock: user has existing connection
|
||||||
|
mock_redis.get.return_value = old_sid
|
||||||
|
mock_redis.hgetall.return_value = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"game_id": "",
|
||||||
|
"connected_at": datetime.now(UTC).isoformat(),
|
||||||
|
"last_seen": datetime.now(UTC).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
await manager.register_connection(new_sid, user_id)
|
||||||
|
|
||||||
|
# Verify old connection was cleaned up (delete called for old conn key)
|
||||||
|
old_conn_key = f"{CONN_PREFIX}{old_sid}"
|
||||||
|
delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list]
|
||||||
|
assert old_conn_key in delete_calls
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_connection_accepts_uuid(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that register_connection accepts UUID objects.
|
||||||
|
|
||||||
|
The user_id can be passed as either a string or UUID object
|
||||||
|
for convenience.
|
||||||
|
"""
|
||||||
|
sid = "test-sid"
|
||||||
|
user_uuid = uuid4()
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
await manager.register_connection(sid, user_uuid)
|
||||||
|
|
||||||
|
# Verify user_id was converted to string
|
||||||
|
hset_call = mock_redis.hset.call_args
|
||||||
|
mapping = hset_call.kwargs["mapping"]
|
||||||
|
assert mapping["user_id"] == str(user_uuid)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnregisterConnection:
|
||||||
|
"""Tests for connection unregistration."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unregister_returns_none_for_unknown_sid(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that unregistering unknown connection returns None.
|
||||||
|
|
||||||
|
If the connection doesn't exist, we should return None rather than
|
||||||
|
raising an error.
|
||||||
|
"""
|
||||||
|
mock_redis.hgetall.return_value = {}
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.unregister_connection("unknown-sid")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unregister_cleans_up_all_data(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that unregistering cleans up all related Redis data.
|
||||||
|
|
||||||
|
Cleanup should remove:
|
||||||
|
1. Connection hash
|
||||||
|
2. User-to-connection mapping (if it still points to this sid)
|
||||||
|
3. Game connection set membership
|
||||||
|
"""
|
||||||
|
sid = "test-sid"
|
||||||
|
user_id = "user-123"
|
||||||
|
game_id = "game-456"
|
||||||
|
|
||||||
|
# Mock: connection exists with game
|
||||||
|
mock_redis.hgetall.return_value = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"game_id": game_id,
|
||||||
|
"connected_at": datetime.now(UTC).isoformat(),
|
||||||
|
"last_seen": datetime.now(UTC).isoformat(),
|
||||||
|
}
|
||||||
|
mock_redis.get.return_value = sid # user mapping points to this sid
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.unregister_connection(sid)
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert result is not None
|
||||||
|
assert result.sid == sid
|
||||||
|
assert result.user_id == user_id
|
||||||
|
assert result.game_id == game_id
|
||||||
|
|
||||||
|
# Verify cleanup
|
||||||
|
mock_redis.srem.assert_called() # Removed from game set
|
||||||
|
mock_redis.delete.assert_called() # Connection deleted
|
||||||
|
|
||||||
|
|
||||||
|
class TestGameAssociation:
|
||||||
|
"""Tests for game join/leave operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_game_adds_to_game_set(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that joining a game adds the connection to the game's set.
|
||||||
|
|
||||||
|
When a connection joins a game, it should:
|
||||||
|
1. Update the connection's game_id
|
||||||
|
2. Add the sid to the game's connection set
|
||||||
|
"""
|
||||||
|
sid = "test-sid"
|
||||||
|
game_id = "game-123"
|
||||||
|
|
||||||
|
mock_redis.exists.return_value = True
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.join_game(sid, game_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id)
|
||||||
|
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_game_returns_false_for_unknown_connection(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that joining a game fails for unknown connections.
|
||||||
|
|
||||||
|
If the connection doesn't exist, we should return False rather
|
||||||
|
than creating orphan game association data.
|
||||||
|
"""
|
||||||
|
mock_redis.exists.return_value = False
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.join_game("unknown-sid", "game-123")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
mock_redis.sadd.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_game_leaves_previous_game(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that joining a new game leaves the previous game.
|
||||||
|
|
||||||
|
When switching games, the connection should be removed from the
|
||||||
|
old game's connection set before joining the new one.
|
||||||
|
"""
|
||||||
|
sid = "test-sid"
|
||||||
|
old_game = "game-old"
|
||||||
|
new_game = "game-new"
|
||||||
|
|
||||||
|
mock_redis.exists.return_value = True
|
||||||
|
mock_redis.hget.return_value = old_game # Currently in old game
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
await manager.join_game(sid, new_game)
|
||||||
|
|
||||||
|
# Verify left old game
|
||||||
|
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid)
|
||||||
|
# Verify joined new game
|
||||||
|
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_leave_game_removes_from_set(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that leaving a game removes the connection from the set.
|
||||||
|
|
||||||
|
Leave should:
|
||||||
|
1. Remove sid from game's connection set
|
||||||
|
2. Clear game_id on connection record
|
||||||
|
"""
|
||||||
|
sid = "test-sid"
|
||||||
|
game_id = "game-123"
|
||||||
|
|
||||||
|
mock_redis.hget.return_value = game_id
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.leave_game(sid)
|
||||||
|
|
||||||
|
assert result == game_id
|
||||||
|
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
|
||||||
|
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_leave_game_returns_none_when_not_in_game(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that leave_game returns None when not in a game.
|
||||||
|
|
||||||
|
If the connection isn't associated with any game, we should
|
||||||
|
return None without making unnecessary Redis calls.
|
||||||
|
"""
|
||||||
|
mock_redis.hget.return_value = "" # No game_id
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.leave_game("test-sid")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
mock_redis.srem.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeartbeat:
|
||||||
|
"""Tests for heartbeat/activity tracking."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_heartbeat_refreshes_last_seen(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that heartbeat updates the last_seen timestamp.
|
||||||
|
|
||||||
|
Heartbeats keep the connection alive by updating the timestamp
|
||||||
|
and refreshing TTLs on Redis records.
|
||||||
|
"""
|
||||||
|
sid = "test-sid"
|
||||||
|
mock_redis.exists.return_value = True
|
||||||
|
mock_redis.hget.return_value = "user-123"
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.update_heartbeat(sid)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
# Verify last_seen was updated
|
||||||
|
hset_call = mock_redis.hset.call_args
|
||||||
|
assert hset_call.args[0] == f"{CONN_PREFIX}{sid}"
|
||||||
|
assert hset_call.args[1] == "last_seen"
|
||||||
|
# Verify TTL was refreshed
|
||||||
|
mock_redis.expire.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_heartbeat_returns_false_for_unknown(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that heartbeat returns False for unknown connections.
|
||||||
|
|
||||||
|
If the connection doesn't exist, we shouldn't try to update it.
|
||||||
|
"""
|
||||||
|
mock_redis.exists.return_value = False
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.update_heartbeat("unknown-sid")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryMethods:
|
||||||
|
"""Tests for connection query methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_connection_returns_info(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_connection returns ConnectionInfo for valid sid.
|
||||||
|
|
||||||
|
The returned ConnectionInfo should have all fields populated
|
||||||
|
from the Redis hash.
|
||||||
|
"""
|
||||||
|
sid = "test-sid"
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
mock_redis.hgetall.return_value = {
|
||||||
|
"user_id": "user-123",
|
||||||
|
"game_id": "game-456",
|
||||||
|
"connected_at": now.isoformat(),
|
||||||
|
"last_seen": now.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.get_connection(sid)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.sid == sid
|
||||||
|
assert result.user_id == "user-123"
|
||||||
|
assert result.game_id == "game-456"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_connection_returns_none_for_unknown(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_connection returns None for unknown sid."""
|
||||||
|
mock_redis.hgetall.return_value = {}
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.get_connection("unknown-sid")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_user_online_returns_true_for_connected_user(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that is_user_online returns True for connected users.
|
||||||
|
|
||||||
|
A user is online if they have an active connection that isn't stale.
|
||||||
|
"""
|
||||||
|
user_id = "user-123"
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
mock_redis.get.return_value = "test-sid"
|
||||||
|
mock_redis.hgetall.return_value = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"game_id": "",
|
||||||
|
"connected_at": now.isoformat(),
|
||||||
|
"last_seen": now.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.is_user_online(user_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_user_online_returns_false_for_stale_connection(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that is_user_online returns False for stale connections.
|
||||||
|
|
||||||
|
Even if a connection record exists, if it's stale (no recent heartbeat),
|
||||||
|
the user should be considered offline.
|
||||||
|
"""
|
||||||
|
user_id = "user-123"
|
||||||
|
old_time = datetime.now(UTC) - timedelta(minutes=5)
|
||||||
|
|
||||||
|
mock_redis.get.return_value = "test-sid"
|
||||||
|
mock_redis.hgetall.return_value = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"game_id": "",
|
||||||
|
"connected_at": old_time.isoformat(),
|
||||||
|
"last_seen": old_time.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.is_user_online(user_id)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_game_connections_returns_all_participants(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_game_connections returns all game participants.
|
||||||
|
|
||||||
|
Should return ConnectionInfo for each sid in the game's connection set.
|
||||||
|
"""
|
||||||
|
game_id = "game-123"
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
mock_redis.smembers.return_value = {"sid-1", "sid-2"}
|
||||||
|
mock_redis.hgetall.side_effect = [
|
||||||
|
{
|
||||||
|
"user_id": "user-1",
|
||||||
|
"game_id": game_id,
|
||||||
|
"connected_at": now.isoformat(),
|
||||||
|
"last_seen": now.isoformat(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": "user-2",
|
||||||
|
"game_id": game_id,
|
||||||
|
"connected_at": now.isoformat(),
|
||||||
|
"last_seen": now.isoformat(),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.get_game_connections(game_id)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
user_ids = {conn.user_id for conn in result}
|
||||||
|
assert user_ids == {"user-1", "user-2"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_opponent_sid_returns_other_player(
|
||||||
|
self,
|
||||||
|
manager: ConnectionManager,
|
||||||
|
mock_redis: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_opponent_sid returns the other player's sid.
|
||||||
|
|
||||||
|
In a 2-player game, this should return the sid of the player
|
||||||
|
who is not the current user.
|
||||||
|
"""
|
||||||
|
game_id = "game-123"
|
||||||
|
current_user = "user-1"
|
||||||
|
opponent_user = "user-2"
|
||||||
|
opponent_sid = "sid-2"
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
mock_redis.smembers.return_value = {"sid-1", "sid-2"}
|
||||||
|
|
||||||
|
# Use a function to return the right data based on the key queried
|
||||||
|
# This avoids dependency on set iteration order
|
||||||
|
def hgetall_by_key(key: str) -> dict[str, str]:
|
||||||
|
if key == f"{CONN_PREFIX}sid-1":
|
||||||
|
return {
|
||||||
|
"user_id": current_user,
|
||||||
|
"game_id": game_id,
|
||||||
|
"connected_at": now.isoformat(),
|
||||||
|
"last_seen": now.isoformat(),
|
||||||
|
}
|
||||||
|
elif key == f"{CONN_PREFIX}sid-2":
|
||||||
|
return {
|
||||||
|
"user_id": opponent_user,
|
||||||
|
"game_id": game_id,
|
||||||
|
"connected_at": now.isoformat(),
|
||||||
|
"last_seen": now.isoformat(),
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
mock_redis.hgetall.side_effect = hgetall_by_key
|
||||||
|
|
||||||
|
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||||
|
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||||
|
|
||||||
|
result = await manager.get_opponent_sid(game_id, current_user)
|
||||||
|
|
||||||
|
assert result == opponent_sid
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeyGeneration:
|
||||||
|
"""Tests for Redis key generation methods."""
|
||||||
|
|
||||||
|
def test_conn_key_format(self, manager: ConnectionManager) -> None:
|
||||||
|
"""Test that connection keys have correct format.
|
||||||
|
|
||||||
|
Keys should follow the pattern conn:{sid} for easy identification
|
||||||
|
and pattern matching.
|
||||||
|
"""
|
||||||
|
key = manager._conn_key("test-sid-123")
|
||||||
|
assert key == "conn:test-sid-123"
|
||||||
|
|
||||||
|
def test_user_conn_key_format(self, manager: ConnectionManager) -> None:
|
||||||
|
"""Test that user connection keys have correct format.
|
||||||
|
|
||||||
|
Keys should follow the pattern user_conn:{user_id}.
|
||||||
|
"""
|
||||||
|
key = manager._user_conn_key("user-456")
|
||||||
|
assert key == "user_conn:user-456"
|
||||||
|
|
||||||
|
def test_game_conns_key_format(self, manager: ConnectionManager) -> None:
|
||||||
|
"""Test that game connections keys have correct format.
|
||||||
|
|
||||||
|
Keys should follow the pattern game_conns:{game_id}.
|
||||||
|
"""
|
||||||
|
key = manager._game_conns_key("game-789")
|
||||||
|
assert key == "game_conns:game-789"
|
||||||
809
backend/tests/unit/services/test_game_service.py
Normal file
809
backend/tests/unit/services/test_game_service.py
Normal file
@ -0,0 +1,809 @@
|
|||||||
|
"""Tests for GameService.
|
||||||
|
|
||||||
|
This module tests the game service layer that orchestrates between
|
||||||
|
WebSocket communication and the core GameEngine.
|
||||||
|
|
||||||
|
The GameService is STATELESS regarding game rules:
|
||||||
|
- No GameEngine is stored in the service
|
||||||
|
- Engine is created per-operation using rules from GameState
|
||||||
|
- Rules come from frontend at game creation, stored in GameState
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.engine import ActionResult
|
||||||
|
from app.core.enums import GameEndReason, TurnPhase
|
||||||
|
from app.core.models.actions import AttackAction, PassAction, ResignAction
|
||||||
|
from app.core.models.game_state import GameState, PlayerState
|
||||||
|
from app.core.win_conditions import WinResult
|
||||||
|
from app.services.game_service import (
|
||||||
|
GameAlreadyEndedError,
|
||||||
|
GameNotFoundError,
|
||||||
|
GameService,
|
||||||
|
InvalidActionError,
|
||||||
|
NotPlayerTurnError,
|
||||||
|
PlayerNotInGameError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_state_manager() -> AsyncMock:
|
||||||
|
"""Create a mock GameStateManager.
|
||||||
|
|
||||||
|
The state manager handles persistence to Redis (cache) and
|
||||||
|
Postgres (durable storage).
|
||||||
|
"""
|
||||||
|
manager = AsyncMock()
|
||||||
|
manager.load_state = AsyncMock(return_value=None)
|
||||||
|
manager.save_to_cache = AsyncMock()
|
||||||
|
manager.persist_to_db = AsyncMock()
|
||||||
|
manager.cache_exists = AsyncMock(return_value=False)
|
||||||
|
return manager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_card_service() -> MagicMock:
|
||||||
|
"""Create a mock CardService.
|
||||||
|
|
||||||
|
CardService provides card definitions for game creation.
|
||||||
|
"""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def game_service(
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
mock_card_service: MagicMock,
|
||||||
|
) -> GameService:
|
||||||
|
"""Create a GameService with mocked dependencies.
|
||||||
|
|
||||||
|
Note: No engine is passed - GameService creates engines per-operation
|
||||||
|
using rules stored in each game's GameState.
|
||||||
|
"""
|
||||||
|
return GameService(
|
||||||
|
state_manager=mock_state_manager,
|
||||||
|
card_service=mock_card_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_game_state() -> GameState:
|
||||||
|
"""Create a sample game state for testing.
|
||||||
|
|
||||||
|
The game state includes two players and basic turn tracking.
|
||||||
|
The rules are stored in the state itself (default RulesConfig).
|
||||||
|
"""
|
||||||
|
player1 = PlayerState(player_id="player-1")
|
||||||
|
player2 = PlayerState(player_id="player-2")
|
||||||
|
|
||||||
|
return GameState(
|
||||||
|
game_id="game-123",
|
||||||
|
players={"player-1": player1, "player-2": player2},
|
||||||
|
current_player_id="player-1",
|
||||||
|
turn_number=1,
|
||||||
|
phase=TurnPhase.MAIN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGameStateAccess:
|
||||||
|
"""Tests for game state access methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_game_state_returns_state(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_game_state returns the game state when found.
|
||||||
|
|
||||||
|
The state manager should be called to load the state, and the
|
||||||
|
result should be returned to the caller.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
result = await game_service.get_game_state("game-123")
|
||||||
|
|
||||||
|
assert result == sample_game_state
|
||||||
|
mock_state_manager.load_state.assert_called_once_with("game-123")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_game_state_raises_not_found(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_game_state raises GameNotFoundError when not found.
|
||||||
|
|
||||||
|
When the state manager returns None, we should raise a specific
|
||||||
|
exception rather than returning None.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(GameNotFoundError) as exc_info:
|
||||||
|
await game_service.get_game_state("nonexistent")
|
||||||
|
|
||||||
|
assert exc_info.value.game_id == "nonexistent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_player_view_returns_visible_state(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_player_view returns visibility-filtered state.
|
||||||
|
|
||||||
|
The returned state should be filtered for the requesting player,
|
||||||
|
hiding the opponent's private information.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
result = await game_service.get_player_view("game-123", "player-1")
|
||||||
|
|
||||||
|
assert result.game_id == "game-123"
|
||||||
|
assert result.viewer_id == "player-1"
|
||||||
|
assert result.is_my_turn is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_player_view_raises_not_in_game(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_player_view raises error for non-participants.
|
||||||
|
|
||||||
|
Only players in the game should be able to view its state.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
with pytest.raises(PlayerNotInGameError) as exc_info:
|
||||||
|
await game_service.get_player_view("game-123", "stranger")
|
||||||
|
|
||||||
|
assert exc_info.value.player_id == "stranger"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_player_turn_returns_true_for_current_player(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that is_player_turn returns True for the current player.
|
||||||
|
|
||||||
|
The current player is determined by the game state's current_player_id.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
result = await game_service.is_player_turn("game-123", "player-1")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_player_turn_returns_false_for_other_player(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that is_player_turn returns False for non-current player.
|
||||||
|
|
||||||
|
Players who are not the current player should get False.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
result = await game_service.is_player_turn("game-123", "player-2")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_game_exists_returns_true_when_cached(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that game_exists returns True when game is in cache.
|
||||||
|
|
||||||
|
This is a quick check that doesn't need to load the full state.
|
||||||
|
"""
|
||||||
|
mock_state_manager.cache_exists.return_value = True
|
||||||
|
|
||||||
|
result = await game_service.game_exists("game-123")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_state_manager.cache_exists.assert_called_once_with("game-123")
|
||||||
|
|
||||||
|
|
||||||
|
class TestJoinGame:
|
||||||
|
"""Tests for the join_game method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_game_success(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful game join returns visible state.
|
||||||
|
|
||||||
|
When a player joins their game, they should receive the
|
||||||
|
visibility-filtered state and know if it's their turn.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
result = await game_service.join_game("game-123", "player-1")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.game_id == "game-123"
|
||||||
|
assert result.player_id == "player-1"
|
||||||
|
assert result.visible_state is not None
|
||||||
|
assert result.is_your_turn is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_game_not_your_turn(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that is_your_turn is False when it's opponent's turn.
|
||||||
|
|
||||||
|
The second player should see is_your_turn=False when the first
|
||||||
|
player is the current player.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
result = await game_service.join_game("game-123", "player-2")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.is_your_turn is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_game_not_found(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test join_game returns failure for non-existent game.
|
||||||
|
|
||||||
|
Rather than raising an exception, join_game returns a failed
|
||||||
|
result for better WebSocket error handling.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = None
|
||||||
|
|
||||||
|
result = await game_service.join_game("nonexistent", "player-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "not found" in result.message.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_game_not_participant(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test join_game returns failure for non-participants.
|
||||||
|
|
||||||
|
Players who are not in the game should not be able to join.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
result = await game_service.join_game("game-123", "stranger")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "not a participant" in result.message.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_join_ended_game(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test joining an ended game still succeeds but indicates game over.
|
||||||
|
|
||||||
|
Players should be able to rejoin ended games to see the final
|
||||||
|
state, but is_your_turn should be False.
|
||||||
|
"""
|
||||||
|
sample_game_state.winner_id = "player-1"
|
||||||
|
sample_game_state.end_reason = GameEndReason.PRIZES_TAKEN
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
result = await game_service.join_game("game-123", "player-2")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.is_your_turn is False
|
||||||
|
assert "ended" in result.message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteAction:
|
||||||
|
"""Tests for the execute_action method.
|
||||||
|
|
||||||
|
These tests verify action execution through GameService. Since GameService
|
||||||
|
creates engines per-operation, we patch _create_engine_for_game to return
|
||||||
|
a mock engine with controlled behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_success(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test successful action execution.
|
||||||
|
|
||||||
|
A valid action by the current player should be executed and
|
||||||
|
the state should be saved to cache.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.execute_action = AsyncMock(
|
||||||
|
return_value=ActionResult(
|
||||||
|
success=True,
|
||||||
|
message="Attack executed",
|
||||||
|
state_changes=[{"type": "damage", "amount": 30}],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
|
||||||
|
action = AttackAction(attack_index=0)
|
||||||
|
result = await game_service.execute_action("game-123", "player-1", action)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.action_type == "attack"
|
||||||
|
mock_state_manager.save_to_cache.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_game_not_found(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test execute_action raises error when game not found.
|
||||||
|
|
||||||
|
Missing games should raise GameNotFoundError for proper
|
||||||
|
error handling in the WebSocket layer.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(GameNotFoundError):
|
||||||
|
await game_service.execute_action("nonexistent", "player-1", PassAction())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_not_in_game(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test execute_action raises error for non-participants.
|
||||||
|
|
||||||
|
Only players in the game can execute actions.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
with pytest.raises(PlayerNotInGameError):
|
||||||
|
await game_service.execute_action("game-123", "stranger", PassAction())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_game_ended(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test execute_action raises error on ended games.
|
||||||
|
|
||||||
|
No actions should be allowed once a game has ended.
|
||||||
|
"""
|
||||||
|
sample_game_state.winner_id = "player-1"
|
||||||
|
sample_game_state.end_reason = GameEndReason.RESIGNATION
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
with pytest.raises(GameAlreadyEndedError):
|
||||||
|
await game_service.execute_action("game-123", "player-1", PassAction())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_not_your_turn(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test execute_action raises error when not player's turn.
|
||||||
|
|
||||||
|
Only the current player can execute actions.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
with pytest.raises(NotPlayerTurnError) as exc_info:
|
||||||
|
await game_service.execute_action("game-123", "player-2", PassAction())
|
||||||
|
|
||||||
|
assert exc_info.value.player_id == "player-2"
|
||||||
|
assert exc_info.value.current_player_id == "player-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_resign_allowed_out_of_turn(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that resignation is allowed even when not your turn.
|
||||||
|
|
||||||
|
Resignation is a special action that can be executed anytime
|
||||||
|
by either player.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.execute_action = AsyncMock(
|
||||||
|
return_value=ActionResult(
|
||||||
|
success=True,
|
||||||
|
message="Player resigned",
|
||||||
|
win_result=WinResult(
|
||||||
|
winner_id="player-1",
|
||||||
|
loser_id="player-2",
|
||||||
|
end_reason=GameEndReason.RESIGNATION,
|
||||||
|
reason="Player resigned",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
|
||||||
|
# player-2 resigns even though it's player-1's turn
|
||||||
|
result = await game_service.execute_action("game-123", "player-2", ResignAction())
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.game_over is True
|
||||||
|
assert result.winner_id == "player-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_invalid_action(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test execute_action raises error for invalid actions.
|
||||||
|
|
||||||
|
When the GameEngine rejects an action, we should raise
|
||||||
|
InvalidActionError with the reason.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.execute_action = AsyncMock(
|
||||||
|
return_value=ActionResult(
|
||||||
|
success=False,
|
||||||
|
message="Not enough energy to attack",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(game_service, "_create_engine_for_game", return_value=mock_engine),
|
||||||
|
pytest.raises(InvalidActionError) as exc_info,
|
||||||
|
):
|
||||||
|
await game_service.execute_action("game-123", "player-1", AttackAction(attack_index=0))
|
||||||
|
|
||||||
|
assert "Not enough energy" in exc_info.value.reason
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_game_over(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test execute_action detects game over and persists to DB.
|
||||||
|
|
||||||
|
When an action results in a win, the state should be persisted
|
||||||
|
to the database for durability.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.execute_action = AsyncMock(
|
||||||
|
return_value=ActionResult(
|
||||||
|
success=True,
|
||||||
|
message="Final prize taken!",
|
||||||
|
win_result=WinResult(
|
||||||
|
winner_id="player-1",
|
||||||
|
loser_id="player-2",
|
||||||
|
end_reason=GameEndReason.PRIZES_TAKEN,
|
||||||
|
reason="All prizes taken",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
|
||||||
|
result = await game_service.execute_action(
|
||||||
|
"game-123", "player-1", AttackAction(attack_index=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.game_over is True
|
||||||
|
assert result.winner_id == "player-1"
|
||||||
|
assert result.end_reason == GameEndReason.PRIZES_TAKEN
|
||||||
|
mock_state_manager.persist_to_db.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_action_uses_game_rules(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that execute_action creates engine with game's rules.
|
||||||
|
|
||||||
|
The engine should be created using the rules stored in the game
|
||||||
|
state, not any service-level defaults.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.execute_action = AsyncMock(
|
||||||
|
return_value=ActionResult(success=True, message="OK")
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
game_service, "_create_engine_for_game", return_value=mock_engine
|
||||||
|
) as mock_create:
|
||||||
|
await game_service.execute_action("game-123", "player-1", PassAction())
|
||||||
|
|
||||||
|
# Verify engine was created with the game state
|
||||||
|
mock_create.assert_called_once_with(sample_game_state)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResignGame:
|
||||||
|
"""Tests for the resign_game convenience method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resign_game_executes_resign_action(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that resign_game is a convenience wrapper for execute_action.
|
||||||
|
|
||||||
|
The resign_game method should internally create a ResignAction
|
||||||
|
and call execute_action.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.execute_action = AsyncMock(
|
||||||
|
return_value=ActionResult(
|
||||||
|
success=True,
|
||||||
|
message="Player resigned",
|
||||||
|
win_result=WinResult(
|
||||||
|
winner_id="player-2",
|
||||||
|
loser_id="player-1",
|
||||||
|
end_reason=GameEndReason.RESIGNATION,
|
||||||
|
reason="Player resigned",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
|
||||||
|
result = await game_service.resign_game("game-123", "player-1")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.action_type == "resign"
|
||||||
|
assert result.game_over is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndGame:
|
||||||
|
"""Tests for the end_game method (forced ending)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_game_sets_winner_and_reason(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that end_game sets winner and end reason.
|
||||||
|
|
||||||
|
Used for timeout or disconnection scenarios where the game
|
||||||
|
needs to be forcibly ended.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
await game_service.end_game(
|
||||||
|
"game-123",
|
||||||
|
winner_id="player-1",
|
||||||
|
end_reason=GameEndReason.TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify state was modified
|
||||||
|
assert sample_game_state.winner_id == "player-1"
|
||||||
|
assert sample_game_state.end_reason == GameEndReason.TIMEOUT
|
||||||
|
|
||||||
|
# Verify persistence
|
||||||
|
mock_state_manager.save_to_cache.assert_called_once()
|
||||||
|
mock_state_manager.persist_to_db.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_game_draw(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test ending a game as a draw (no winner).
|
||||||
|
|
||||||
|
Some scenarios (timeout with equal scores) can result in a draw.
|
||||||
|
"""
|
||||||
|
mock_state_manager.load_state.return_value = sample_game_state
|
||||||
|
|
||||||
|
await game_service.end_game(
|
||||||
|
"game-123",
|
||||||
|
winner_id=None,
|
||||||
|
end_reason=GameEndReason.DRAW,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sample_game_state.winner_id is None
|
||||||
|
assert sample_game_state.end_reason == GameEndReason.DRAW
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_game_not_found(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
mock_state_manager: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test end_game raises error when game not found."""
|
||||||
|
mock_state_manager.load_state.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(GameNotFoundError):
|
||||||
|
await game_service.end_game(
|
||||||
|
"nonexistent",
|
||||||
|
winner_id="player-1",
|
||||||
|
end_reason=GameEndReason.TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateGame:
|
||||||
|
"""Tests for the create_game method (skeleton)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_game_raises_not_implemented(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
) -> None:
|
||||||
|
"""Test that create_game raises NotImplementedError.
|
||||||
|
|
||||||
|
The full implementation will be done in GS-002. For now,
|
||||||
|
it should raise NotImplementedError with a clear message.
|
||||||
|
"""
|
||||||
|
with pytest.raises(NotImplementedError) as exc_info:
|
||||||
|
await game_service.create_game(
|
||||||
|
player1_id=str(uuid4()),
|
||||||
|
player2_id=str(uuid4()),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "GS-002" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateEngineForGame:
|
||||||
|
"""Tests for the _create_engine_for_game method.
|
||||||
|
|
||||||
|
This method is responsible for creating a GameEngine configured
|
||||||
|
with the rules from a specific game's state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_create_engine_uses_game_rules(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that engine is created with the game's rules.
|
||||||
|
|
||||||
|
The engine should use the RulesConfig stored in the game state,
|
||||||
|
not any default configuration.
|
||||||
|
"""
|
||||||
|
engine = game_service._create_engine_for_game(sample_game_state)
|
||||||
|
|
||||||
|
# Engine should have the game's rules
|
||||||
|
assert engine.rules == sample_game_state.rules
|
||||||
|
|
||||||
|
def test_create_engine_with_rng_seed(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that engine uses seeded RNG when game has rng_seed.
|
||||||
|
|
||||||
|
When a game has an rng_seed set, the engine should use a
|
||||||
|
deterministic RNG for replay support.
|
||||||
|
"""
|
||||||
|
sample_game_state.rng_seed = 12345
|
||||||
|
|
||||||
|
engine = game_service._create_engine_for_game(sample_game_state)
|
||||||
|
|
||||||
|
# Engine should have been created (we can't easily verify seed,
|
||||||
|
# but we can verify it doesn't error)
|
||||||
|
assert engine is not None
|
||||||
|
|
||||||
|
def test_create_engine_without_rng_seed(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that engine uses secure RNG when no seed is set.
|
||||||
|
|
||||||
|
Without an rng_seed, the engine should use cryptographically
|
||||||
|
secure random number generation.
|
||||||
|
"""
|
||||||
|
sample_game_state.rng_seed = None
|
||||||
|
|
||||||
|
engine = game_service._create_engine_for_game(sample_game_state)
|
||||||
|
|
||||||
|
assert engine is not None
|
||||||
|
|
||||||
|
def test_create_engine_derives_unique_seed_per_action(
|
||||||
|
self,
|
||||||
|
game_service: GameService,
|
||||||
|
sample_game_state: GameState,
|
||||||
|
) -> None:
|
||||||
|
"""Test that different action counts produce different RNG sequences.
|
||||||
|
|
||||||
|
For deterministic replay, each action needs a unique but
|
||||||
|
reproducible RNG seed based on game seed + action count.
|
||||||
|
"""
|
||||||
|
sample_game_state.rng_seed = 12345
|
||||||
|
|
||||||
|
# Simulate first action (action_log is empty)
|
||||||
|
sample_game_state.action_log = []
|
||||||
|
engine1 = game_service._create_engine_for_game(sample_game_state)
|
||||||
|
|
||||||
|
# Simulate second action (one action in log)
|
||||||
|
sample_game_state.action_log = [{"type": "pass"}]
|
||||||
|
engine2 = game_service._create_engine_for_game(sample_game_state)
|
||||||
|
|
||||||
|
# Both engines should be created successfully
|
||||||
|
# (They will have different seeds due to action count)
|
||||||
|
assert engine1 is not None
|
||||||
|
assert engine2 is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExceptionMessages:
|
||||||
|
"""Tests for exception message formatting."""
|
||||||
|
|
||||||
|
def test_game_not_found_error_message(self) -> None:
|
||||||
|
"""Test GameNotFoundError has descriptive message."""
|
||||||
|
error = GameNotFoundError("game-123")
|
||||||
|
assert "game-123" in str(error)
|
||||||
|
assert error.game_id == "game-123"
|
||||||
|
|
||||||
|
def test_not_player_turn_error_message(self) -> None:
|
||||||
|
"""Test NotPlayerTurnError has descriptive message."""
|
||||||
|
error = NotPlayerTurnError("game-123", "player-2", "player-1")
|
||||||
|
assert "player-2" in str(error)
|
||||||
|
assert "player-1" in str(error)
|
||||||
|
assert error.game_id == "game-123"
|
||||||
|
|
||||||
|
def test_invalid_action_error_message(self) -> None:
|
||||||
|
"""Test InvalidActionError has descriptive message."""
|
||||||
|
error = InvalidActionError("game-123", "player-1", "Not enough energy")
|
||||||
|
assert "Not enough energy" in str(error)
|
||||||
|
assert error.reason == "Not enough energy"
|
||||||
|
|
||||||
|
def test_player_not_in_game_error_message(self) -> None:
|
||||||
|
"""Test PlayerNotInGameError has descriptive message."""
|
||||||
|
error = PlayerNotInGameError("game-123", "stranger")
|
||||||
|
assert "stranger" in str(error)
|
||||||
|
assert "game-123" in str(error)
|
||||||
|
|
||||||
|
def test_game_already_ended_error_message(self) -> None:
|
||||||
|
"""Test GameAlreadyEndedError has descriptive message."""
|
||||||
|
error = GameAlreadyEndedError("game-123")
|
||||||
|
assert "game-123" in str(error)
|
||||||
|
assert "ended" in str(error).lower()
|
||||||
Loading…
Reference in New Issue
Block a user