Add Phase 4 WebSocket infrastructure (WS-001 through GS-001)

WebSocket Message Schemas (WS-002):
- Add Pydantic models for all client/server WebSocket messages
- Implement discriminated unions for message type parsing
- Include JoinGame, Action, Resign, Heartbeat client messages
- Include GameState, ActionResult, Error, TurnStart server messages

Connection Manager (WS-003):
- Add Redis-backed WebSocket connection tracking
- Implement user-to-sid mapping with TTL management
- Support game room association and opponent lookup
- Add heartbeat tracking for connection health

Socket.IO Authentication (WS-004):
- Add JWT-based authentication middleware
- Support token extraction from multiple formats
- Implement session setup with ConnectionManager integration
- Add require_auth helper for event handlers

Socket.IO Server Setup (WS-001):
- Configure AsyncServer with ASGI mode
- Register /game namespace with event handlers
- Integrate with FastAPI via ASGIApp wrapper
- Configure CORS from application settings

Game Service (GS-001):
- Add stateless GameService for game lifecycle orchestration
- Create engine per-operation using rules from GameState
- Implement action-based RNG seeding for deterministic replay
- Add rng_seed field to GameState for replay support

Architecture verified:
- Core module independence (no forbidden imports)
- Config from request pattern (rules in GameState)
- Dependency injection (constructor deps, method config)
- All 1090 tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-01-28 22:21:20 -06:00
parent c00ee87f25
commit 0c810e5b30
17 changed files with 4910 additions and 10 deletions

View File

@ -379,6 +379,7 @@ class GameState(BaseModel):
forced_actions: Queue of ForcedAction items that must be completed before game proceeds. forced_actions: Queue of ForcedAction items that must be completed before game proceeds.
Actions are processed in FIFO order (first added = first to resolve). Actions are processed in FIFO order (first added = first to resolve).
action_log: Log of actions taken (for replays/debugging). action_log: Log of actions taken (for replays/debugging).
rng_seed: Optional seed for deterministic RNG. When set, enables replay capability.
""" """
game_id: str game_id: str
@ -412,6 +413,9 @@ class GameState(BaseModel):
# Optional action log for replays # Optional action log for replays
action_log: list[dict[str, Any]] = Field(default_factory=list) action_log: list[dict[str, Any]] = Field(default_factory=list)
# Optional RNG seed for deterministic replays
rng_seed: int | None = None
def get_current_player(self) -> PlayerState: def get_current_player(self) -> PlayerState:
"""Get the PlayerState for the current player. """Get the PlayerState for the current player.

View File

@ -26,6 +26,7 @@ from app.config import settings
from app.db import close_db, init_db from app.db import close_db, init_db
from app.db.redis import close_redis, init_redis from app.db.redis import close_redis, init_redis
from app.services import get_card_service from app.services import get_card_service
from app.socketio import create_socketio_app
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -173,3 +174,16 @@ app.include_router(decks_router, prefix="/api")
# app.include_router(cards.router, prefix="/api/cards", tags=["cards"]) # app.include_router(cards.router, prefix="/api/cards", tags=["cards"])
# app.include_router(games.router, prefix="/api/games", tags=["games"]) # app.include_router(games.router, prefix="/api/games", tags=["games"])
# app.include_router(campaign.router, prefix="/api/campaign", tags=["campaign"]) # app.include_router(campaign.router, prefix="/api/campaign", tags=["campaign"])
# === Socket.IO Integration ===
# Wrap FastAPI with Socket.IO ASGI app for WebSocket support.
# The combined app handles:
# - WebSocket connections at /socket.io/
# - HTTP requests passed through to FastAPI
#
# Note: The module-level 'app' variable is now the combined ASGI app.
# This is intentional - uvicorn imports 'app' and will serve both
# FastAPI HTTP routes and Socket.IO WebSocket connections.
_fastapi_app = app
app = create_socketio_app(_fastapi_app)

View File

@ -0,0 +1,491 @@
"""WebSocket message schemas for Mantimon TCG real-time communication.
This module defines Pydantic models for all WebSocket message types, following
the discriminated union pattern from app/core/models/actions.py. Messages are
categorized into client-to-server and server-to-client types.
Message Design Principles:
1. All messages have a 'type' discriminator field for automatic parsing
2. All messages have a 'message_id' for idempotency and event ordering
3. Server messages include a 'timestamp' for client-side latency tracking
4. Game-scoped messages include 'game_id' for multi-game support
Example:
# Parse incoming client message
data = {"type": "join_game", "message_id": "abc-123", "game_id": "game-456"}
message = parse_client_message(data)
assert isinstance(message, JoinGameMessage)
# Create server response
response = GameStateMessage(
message_id=str(uuid4()),
game_id="game-456",
state=visible_state,
)
await websocket.send_text(response.model_dump_json())
"""
from datetime import UTC, datetime
from enum import StrEnum
from typing import Annotated, Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
from app.core.enums import GameEndReason
from app.core.models.actions import Action
from app.core.visibility import VisibleGameState
class WSErrorCode(StrEnum):
"""Error codes for WebSocket error messages.
These codes help clients handle errors programmatically without
parsing error message text.
"""
# Connection errors
AUTHENTICATION_FAILED = "authentication_failed"
CONNECTION_CLOSED = "connection_closed"
RATE_LIMITED = "rate_limited"
# Game errors
GAME_NOT_FOUND = "game_not_found"
NOT_IN_GAME = "not_in_game"
ALREADY_IN_GAME = "already_in_game"
GAME_FULL = "game_full"
GAME_ENDED = "game_ended"
# Action errors
INVALID_ACTION = "invalid_action"
NOT_YOUR_TURN = "not_your_turn"
ACTION_NOT_ALLOWED = "action_not_allowed"
# Protocol errors
INVALID_MESSAGE = "invalid_message"
UNKNOWN_MESSAGE_TYPE = "unknown_message_type"
# Server errors
INTERNAL_ERROR = "internal_error"
class ConnectionStatus(StrEnum):
"""Connection status for opponent status messages."""
CONNECTED = "connected"
DISCONNECTED = "disconnected"
RECONNECTING = "reconnecting"
# =============================================================================
# Base Message Classes
# =============================================================================
def _generate_message_id() -> str:
"""Generate a unique message ID."""
return str(uuid4())
def _utc_now() -> datetime:
"""Get current UTC timestamp."""
return datetime.now(UTC)
class BaseClientMessage(BaseModel):
"""Base class for all client-to-server messages.
All client messages must include a message_id for idempotency. The client
generates this ID to allow detecting and handling duplicate messages.
Attributes:
message_id: Client-generated UUID for idempotency and tracking.
"""
message_id: str = Field(
default_factory=_generate_message_id,
description="Client-generated UUID for idempotency",
)
@field_validator("message_id")
@classmethod
def validate_message_id(cls, v: str) -> str:
"""Ensure message_id is not empty."""
if not v or not v.strip():
raise ValueError("message_id cannot be empty")
return v
class BaseServerMessage(BaseModel):
"""Base class for all server-to-client messages.
All server messages include a message_id and timestamp. The timestamp
enables client-side latency calculation and event ordering.
Attributes:
message_id: Server-generated UUID for tracking.
timestamp: UTC timestamp when the message was created.
"""
message_id: str = Field(
default_factory=_generate_message_id,
description="Server-generated UUID for tracking",
)
timestamp: datetime = Field(
default_factory=_utc_now,
description="UTC timestamp when message was created",
)
# =============================================================================
# Client -> Server Messages
# =============================================================================
class JoinGameMessage(BaseClientMessage):
"""Request to join or rejoin a game session.
When rejoining after a disconnect, the client can provide last_event_id
to receive any missed events since that point.
Attributes:
type: Discriminator field, always "join_game".
game_id: The game to join.
last_event_id: For reconnection - ID of last received event for replay.
"""
type: Literal["join_game"] = "join_game"
game_id: str = Field(..., description="ID of the game to join")
last_event_id: str | None = Field(
default=None,
description="Last event ID received (for reconnection replay)",
)
class ActionMessage(BaseClientMessage):
"""Submit a game action for execution.
Wraps the Action union type from app/core/models/actions.py to provide
consistent message envelope with message_id and game context.
Attributes:
type: Discriminator field, always "action".
game_id: The game this action is for.
action: The game action to execute.
"""
type: Literal["action"] = "action"
game_id: str = Field(..., description="ID of the game")
action: Action = Field(..., description="The game action to execute")
class ResignMessage(BaseClientMessage):
"""Request to resign from a game.
This is separate from ResignAction in the game engine to allow
resignation handling at the WebSocket layer (e.g., when the game
engine is unavailable).
Attributes:
type: Discriminator field, always "resign".
game_id: The game to resign from.
"""
type: Literal["resign"] = "resign"
game_id: str = Field(..., description="ID of the game to resign from")
class HeartbeatMessage(BaseClientMessage):
"""Keep-alive message to maintain connection.
Clients should send heartbeats periodically (e.g., every 30 seconds)
to prevent connection timeout. The server responds with a HeartbeatAck.
Attributes:
type: Discriminator field, always "heartbeat".
"""
type: Literal["heartbeat"] = "heartbeat"
# Union type for client messages
ClientMessage = Annotated[
JoinGameMessage | ActionMessage | ResignMessage | HeartbeatMessage,
Field(discriminator="type"),
]
# =============================================================================
# Server -> Client Messages
# =============================================================================
class GameStateMessage(BaseServerMessage):
"""Full game state update.
Sent when a player joins a game or when a full state sync is needed.
Contains the complete visible game state from that player's perspective.
Attributes:
type: Discriminator field, always "game_state".
game_id: The game this state is for.
state: The full visible game state.
event_id: Monotonic event ID for reconnection replay.
"""
type: Literal["game_state"] = "game_state"
game_id: str = Field(..., description="ID of the game")
state: VisibleGameState = Field(..., description="Full visible game state")
event_id: str = Field(
default_factory=_generate_message_id,
description="Event ID for reconnection replay",
)
class ActionResultMessage(BaseServerMessage):
"""Result of a player action.
Sent after processing an ActionMessage to confirm success or failure.
On success, includes changes that resulted from the action.
Attributes:
type: Discriminator field, always "action_result".
game_id: The game this result is for.
request_message_id: The message_id of the original ActionMessage.
success: Whether the action succeeded.
action_type: The type of action that was attempted.
changes: Description of state changes (on success).
error_code: Error code if action failed.
error_message: Human-readable error message if failed.
"""
type: Literal["action_result"] = "action_result"
game_id: str = Field(..., description="ID of the game")
request_message_id: str = Field(..., description="message_id of the original action request")
success: bool = Field(..., description="Whether the action succeeded")
action_type: str = Field(..., description="Type of action that was attempted")
changes: dict[str, Any] = Field(
default_factory=dict,
description="State changes resulting from the action",
)
error_code: WSErrorCode | None = Field(default=None, description="Error code if action failed")
error_message: str | None = Field(default=None, description="Human-readable error message")
class ErrorMessage(BaseServerMessage):
"""Error notification for protocol or connection errors.
Used for errors not associated with a specific action, such as
invalid message format, authentication failures, or server errors.
Attributes:
type: Discriminator field, always "error".
code: Machine-readable error code.
message: Human-readable error description.
details: Additional error context.
request_message_id: ID of the message that caused the error, if applicable.
"""
type: Literal["error"] = "error"
code: WSErrorCode = Field(..., description="Error code")
message: str = Field(..., description="Human-readable error description")
details: dict[str, Any] = Field(default_factory=dict, description="Additional error context")
request_message_id: str | None = Field(
default=None, description="ID of message that caused error"
)
class TurnStartMessage(BaseServerMessage):
"""Notification that a new turn has started.
Sent to both players when a turn begins. Indicates whose turn it is
and the current turn number.
Attributes:
type: Discriminator field, always "turn_start".
game_id: The game this notification is for.
player_id: The player whose turn is starting.
turn_number: The current turn number.
event_id: Event ID for reconnection replay.
"""
type: Literal["turn_start"] = "turn_start"
game_id: str = Field(..., description="ID of the game")
player_id: str = Field(..., description="Player whose turn is starting")
turn_number: int = Field(..., description="Current turn number", ge=1)
event_id: str = Field(
default_factory=_generate_message_id,
description="Event ID for reconnection replay",
)
class TurnTimeoutMessage(BaseServerMessage):
"""Timeout warning or expiration notification.
Sent when a player's turn is approaching timeout (warning) or has
expired. Warnings give players time to complete their action.
Attributes:
type: Discriminator field, always "turn_timeout".
game_id: The game this notification is for.
remaining_seconds: Seconds remaining before timeout.
is_warning: True if this is a warning, False if timeout has occurred.
player_id: The player whose turn is timing out.
"""
type: Literal["turn_timeout"] = "turn_timeout"
game_id: str = Field(..., description="ID of the game")
remaining_seconds: int = Field(..., description="Seconds remaining before timeout", ge=0)
is_warning: bool = Field(..., description="True if warning, False if timeout occurred")
player_id: str = Field(..., description="Player whose turn is timing out")
class GameOverMessage(BaseServerMessage):
"""Notification that the game has ended.
Sent to all players when the game concludes, regardless of the reason.
Attributes:
type: Discriminator field, always "game_over".
game_id: The game that ended.
winner_id: The winning player, or None for a draw.
end_reason: Why the game ended.
final_state: The final visible game state.
event_id: Event ID for reconnection replay.
"""
type: Literal["game_over"] = "game_over"
game_id: str = Field(..., description="ID of the game that ended")
winner_id: str | None = Field(default=None, description="Winner player ID, or None for draw")
end_reason: GameEndReason = Field(..., description="Reason the game ended")
final_state: VisibleGameState = Field(..., description="Final visible game state")
event_id: str = Field(
default_factory=_generate_message_id,
description="Event ID for reconnection replay",
)
class OpponentStatusMessage(BaseServerMessage):
"""Notification of opponent connection status change.
Sent when the opponent connects, disconnects, or is reconnecting.
Allows the UI to show connection status to the player.
Attributes:
type: Discriminator field, always "opponent_status".
game_id: The game this status is for.
opponent_id: The opponent's player ID.
status: The opponent's current connection status.
"""
type: Literal["opponent_status"] = "opponent_status"
game_id: str = Field(..., description="ID of the game")
opponent_id: str = Field(..., description="Opponent's player ID")
status: ConnectionStatus = Field(..., description="Opponent's connection status")
class HeartbeatAckMessage(BaseServerMessage):
"""Acknowledgment of a client heartbeat.
Sent in response to HeartbeatMessage to confirm the connection is alive.
Attributes:
type: Discriminator field, always "heartbeat_ack".
"""
type: Literal["heartbeat_ack"] = "heartbeat_ack"
# Union type for server messages
ServerMessage = Annotated[
GameStateMessage
| ActionResultMessage
| ErrorMessage
| TurnStartMessage
| TurnTimeoutMessage
| GameOverMessage
| OpponentStatusMessage
| HeartbeatAckMessage,
Field(discriminator="type"),
]
# =============================================================================
# Parsing Functions
# =============================================================================
def parse_client_message(data: dict[str, Any]) -> ClientMessage:
"""Parse an incoming client message from a dictionary.
Uses the 'type' field to determine which message model to use.
Args:
data: Dictionary containing message data with a 'type' field.
Returns:
The appropriate ClientMessage subtype.
Raises:
ValueError: If the message type is unknown.
ValidationError: If the data doesn't match the message schema.
Example:
data = {"type": "join_game", "message_id": "abc", "game_id": "123"}
message = parse_client_message(data)
assert isinstance(message, JoinGameMessage)
"""
from pydantic import TypeAdapter
adapter: TypeAdapter[ClientMessage] = TypeAdapter(ClientMessage)
return adapter.validate_python(data)
def parse_server_message(data: dict[str, Any]) -> ServerMessage:
"""Parse a server message from a dictionary.
Primarily used for testing - clients receive JSON and parse directly.
Args:
data: Dictionary containing message data with a 'type' field.
Returns:
The appropriate ServerMessage subtype.
Raises:
ValueError: If the message type is unknown.
ValidationError: If the data doesn't match the message schema.
"""
from pydantic import TypeAdapter
adapter: TypeAdapter[ServerMessage] = TypeAdapter(ServerMessage)
return adapter.validate_python(data)
__all__ = [
# Enums
"ConnectionStatus",
"WSErrorCode",
# Base classes
"BaseClientMessage",
"BaseServerMessage",
# Client messages
"ActionMessage",
"ClientMessage",
"HeartbeatMessage",
"JoinGameMessage",
"ResignMessage",
# Server messages
"ActionResultMessage",
"ErrorMessage",
"GameOverMessage",
"GameStateMessage",
"HeartbeatAckMessage",
"OpponentStatusMessage",
"ServerMessage",
"TurnStartMessage",
"TurnTimeoutMessage",
# Parsing functions
"parse_client_message",
"parse_server_message",
]

View File

@ -0,0 +1,578 @@
"""WebSocket connection management with Redis-backed session tracking.
This module manages WebSocket connection state using Redis for persistence
and cross-instance coordination. It tracks:
- Individual connection sessions (sid -> user_id, game_id, timestamps)
- User to connection mapping (user_id -> sid)
- Game participants (game_id -> set of sids)
Key Patterns:
conn:{sid} - Hash with connection details (user_id, game_id, connected_at, last_seen)
user_conn:{user_id} - String with active sid
game_conns:{game_id} - Set of sids for game participants
Lifecycle:
1. on_connect: Register connection, map user to sid
2. on_join_game: Add sid to game's connection set
3. on_heartbeat: Update last_seen timestamp
4. on_disconnect: Remove from all tracking structures
Example:
manager = ConnectionManager()
# On WebSocket connect
await manager.register_connection(sid, user_id)
# On joining a game
await manager.join_game(sid, game_id)
# Check if opponent is online
if await manager.is_user_online(opponent_id):
# Send real-time update
# On disconnect
await manager.unregister_connection(sid)
"""
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID
from app.db.redis import get_redis
logger = logging.getLogger(__name__)
# Redis key patterns
CONN_PREFIX = "conn:"
USER_CONN_PREFIX = "user_conn:"
GAME_CONNS_PREFIX = "game_conns:"
# Connection TTL (auto-expire stale connections)
DEFAULT_CONN_TTL_SECONDS = 3600 # 1 hour
HEARTBEAT_INTERVAL_SECONDS = 30
@dataclass
class ConnectionInfo:
"""Information about a WebSocket connection.
Attributes:
sid: Socket.IO session ID.
user_id: Authenticated user's UUID.
game_id: Current game ID if in a game, None otherwise.
connected_at: When the connection was established.
last_seen: Last heartbeat or activity timestamp.
"""
sid: str
user_id: str
game_id: str | None
connected_at: datetime
last_seen: datetime
def is_stale(self, threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3) -> bool:
"""Check if connection is stale (no recent activity).
Args:
threshold_seconds: Seconds since last_seen to consider stale.
Returns:
True if connection hasn't been seen recently.
"""
elapsed = (datetime.now(UTC) - self.last_seen).total_seconds()
return elapsed > threshold_seconds
class ConnectionManager:
"""Manages WebSocket connections with Redis-backed session tracking.
Provides methods to:
- Register/unregister connections
- Track which game a connection is participating in
- Check user online status
- Get all connections for a game
- Handle heartbeats and stale connection detection
Attributes:
conn_ttl_seconds: TTL for connection records in Redis.
"""
def __init__(self, conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS) -> None:
"""Initialize the ConnectionManager.
Args:
conn_ttl_seconds: How long to keep connection records in Redis.
"""
self.conn_ttl_seconds = conn_ttl_seconds
def _conn_key(self, sid: str) -> str:
"""Generate Redis key for a connection."""
return f"{CONN_PREFIX}{sid}"
def _user_conn_key(self, user_id: str) -> str:
"""Generate Redis key for user-to-connection mapping."""
return f"{USER_CONN_PREFIX}{user_id}"
def _game_conns_key(self, game_id: str) -> str:
"""Generate Redis key for game connection set."""
return f"{GAME_CONNS_PREFIX}{game_id}"
# =========================================================================
# Connection Lifecycle
# =========================================================================
async def register_connection(
self,
sid: str,
user_id: str | UUID,
) -> None:
"""Register a new WebSocket connection.
Creates connection record in Redis and maps user to this connection.
If user had a previous connection, it will be replaced.
Args:
sid: Socket.IO session ID.
user_id: Authenticated user's ID (UUID or string).
Example:
await manager.register_connection("abc123", user.id)
"""
user_id_str = str(user_id)
now = datetime.now(UTC).isoformat()
async with get_redis() as redis:
# Check for existing connection and clean it up
old_sid = await redis.get(self._user_conn_key(user_id_str))
if old_sid and old_sid != sid:
logger.info(f"Replacing old connection {old_sid} for user {user_id_str}")
await self._cleanup_connection(old_sid)
# Create connection record (hash)
conn_key = self._conn_key(sid)
await redis.hset(
conn_key,
mapping={
"user_id": user_id_str,
"game_id": "",
"connected_at": now,
"last_seen": now,
},
)
await redis.expire(conn_key, self.conn_ttl_seconds)
# Map user to this connection
user_conn_key = self._user_conn_key(user_id_str)
await redis.set(user_conn_key, sid)
await redis.expire(user_conn_key, self.conn_ttl_seconds)
logger.debug(f"Registered connection: sid={sid}, user_id={user_id_str}")
async def unregister_connection(self, sid: str) -> ConnectionInfo | None:
"""Unregister a WebSocket connection and clean up all related data.
Removes connection from:
- Connection record
- User-to-connection mapping
- Any game connection sets
Args:
sid: Socket.IO session ID.
Returns:
ConnectionInfo of the removed connection, or None if not found.
Example:
info = await manager.unregister_connection("abc123")
if info and info.game_id:
# Notify opponent of disconnect
"""
# Get connection info before cleanup
info = await self.get_connection(sid)
if info is None:
logger.debug(f"Connection not found for unregister: {sid}")
return None
await self._cleanup_connection(sid)
logger.debug(f"Unregistered connection: sid={sid}, user_id={info.user_id}")
return info
async def _cleanup_connection(self, sid: str) -> None:
"""Internal cleanup of a connection's Redis data.
Args:
sid: Socket.IO session ID.
"""
async with get_redis() as redis:
conn_key = self._conn_key(sid)
# Get connection data for cleanup
data = await redis.hgetall(conn_key)
if not data:
return
user_id = data.get("user_id", "")
game_id = data.get("game_id", "")
# Remove from game connection set if in a game
if game_id:
game_conns_key = self._game_conns_key(game_id)
await redis.srem(game_conns_key, sid)
# Remove user-to-connection mapping if it points to this sid
if user_id:
user_conn_key = self._user_conn_key(user_id)
current_sid = await redis.get(user_conn_key)
if current_sid == sid:
await redis.delete(user_conn_key)
# Delete connection record
await redis.delete(conn_key)
# =========================================================================
# Game Association
# =========================================================================
async def join_game(self, sid: str, game_id: str) -> bool:
"""Associate a connection with a game.
Updates the connection's game_id and adds it to the game's connection set.
If connection was in a different game, leaves that game first.
Args:
sid: Socket.IO session ID.
game_id: Game to join.
Returns:
True if successful, False if connection not found.
Example:
await manager.join_game("abc123", "game-456")
"""
async with get_redis() as redis:
conn_key = self._conn_key(sid)
# Check connection exists
exists = await redis.exists(conn_key)
if not exists:
logger.warning(f"Cannot join game: connection not found {sid}")
return False
# Leave current game if in one
current_game = await redis.hget(conn_key, "game_id")
if current_game and current_game != game_id:
old_game_conns_key = self._game_conns_key(current_game)
await redis.srem(old_game_conns_key, sid)
logger.debug(f"Connection {sid} left game {current_game}")
# Update connection's game_id
await redis.hset(conn_key, "game_id", game_id)
# Add to game's connection set
game_conns_key = self._game_conns_key(game_id)
await redis.sadd(game_conns_key, sid)
# Set TTL on game conns set (cleanup if game becomes inactive)
await redis.expire(game_conns_key, self.conn_ttl_seconds)
logger.debug(f"Connection {sid} joined game {game_id}")
return True
async def leave_game(self, sid: str) -> str | None:
"""Remove a connection from its current game.
Args:
sid: Socket.IO session ID.
Returns:
The game_id that was left, or None if not in a game.
Example:
game_id = await manager.leave_game("abc123")
"""
async with get_redis() as redis:
conn_key = self._conn_key(sid)
# Get current game
game_id = await redis.hget(conn_key, "game_id")
if not game_id:
return None
# Remove from game's connection set
game_conns_key = self._game_conns_key(game_id)
await redis.srem(game_conns_key, sid)
# Clear game_id on connection
await redis.hset(conn_key, "game_id", "")
logger.debug(f"Connection {sid} left game {game_id}")
return game_id
# =========================================================================
# Heartbeat / Activity
# =========================================================================
async def update_heartbeat(self, sid: str) -> bool:
"""Update the last_seen timestamp for a connection.
Should be called on each heartbeat message from the client.
Also refreshes TTL on connection records.
Args:
sid: Socket.IO session ID.
Returns:
True if successful, False if connection not found.
Example:
await manager.update_heartbeat("abc123")
"""
now = datetime.now(UTC).isoformat()
async with get_redis() as redis:
conn_key = self._conn_key(sid)
# Check exists and update atomically
exists = await redis.exists(conn_key)
if not exists:
return False
# Update last_seen
await redis.hset(conn_key, "last_seen", now)
# Refresh TTL
await redis.expire(conn_key, self.conn_ttl_seconds)
# Also refresh user mapping TTL
user_id = await redis.hget(conn_key, "user_id")
if user_id:
user_conn_key = self._user_conn_key(user_id)
await redis.expire(user_conn_key, self.conn_ttl_seconds)
return True
# =========================================================================
# Query Methods
# =========================================================================
async def get_connection(self, sid: str) -> ConnectionInfo | None:
"""Get connection info by session ID.
Args:
sid: Socket.IO session ID.
Returns:
ConnectionInfo if found, None otherwise.
Example:
info = await manager.get_connection("abc123")
if info:
print(f"User {info.user_id} connected at {info.connected_at}")
"""
async with get_redis() as redis:
conn_key = self._conn_key(sid)
data = await redis.hgetall(conn_key)
if not data:
return None
return ConnectionInfo(
sid=sid,
user_id=data.get("user_id", ""),
game_id=data.get("game_id") or None,
connected_at=datetime.fromisoformat(data["connected_at"]),
last_seen=datetime.fromisoformat(data["last_seen"]),
)
async def get_user_connection(self, user_id: str | UUID) -> ConnectionInfo | None:
"""Get connection info for a user.
Args:
user_id: User's UUID or string ID.
Returns:
ConnectionInfo if user is connected, None otherwise.
Example:
info = await manager.get_user_connection(opponent_id)
if info:
# User is online
"""
user_id_str = str(user_id)
async with get_redis() as redis:
user_conn_key = self._user_conn_key(user_id_str)
sid = await redis.get(user_conn_key)
if not sid:
return None
return await self.get_connection(sid)
async def is_user_online(self, user_id: str | UUID) -> bool:
"""Check if a user has an active connection.
Args:
user_id: User's UUID or string ID.
Returns:
True if user is connected, False otherwise.
Example:
if await manager.is_user_online(opponent_id):
# Send real-time notification
"""
info = await self.get_user_connection(user_id)
if info is None:
return False
# Check if connection is stale
if info.is_stale():
logger.debug(f"User {user_id} connection is stale")
return False
return True
async def get_game_connections(self, game_id: str) -> list[ConnectionInfo]:
"""Get all connections for a game.
Args:
game_id: Game ID.
Returns:
List of ConnectionInfo for all connected participants.
Example:
connections = await manager.get_game_connections("game-456")
for conn in connections:
# Send state update to each participant
"""
async with get_redis() as redis:
game_conns_key = self._game_conns_key(game_id)
sids = await redis.smembers(game_conns_key)
connections = []
for sid in sids:
info = await self.get_connection(sid)
if info is not None:
connections.append(info)
return connections
async def get_game_user_sids(self, game_id: str) -> dict[str, str]:
"""Get mapping of user_id to sid for a game.
Useful for targeted message delivery to specific players.
Args:
game_id: Game ID.
Returns:
Dict mapping user_id to their sid.
Example:
user_sids = await manager.get_game_user_sids("game-456")
player1_sid = user_sids.get(player1_id)
"""
connections = await self.get_game_connections(game_id)
return {conn.user_id: conn.sid for conn in connections}
async def get_opponent_sid(
self,
game_id: str,
current_user_id: str | UUID,
) -> str | None:
"""Get the sid of the opponent in a 2-player game.
Args:
game_id: Game ID.
current_user_id: The current user's ID.
Returns:
Opponent's sid if found and connected, None otherwise.
Example:
opponent_sid = await manager.get_opponent_sid("game-456", user_id)
if opponent_sid:
# Send message to opponent
"""
current_user_str = str(current_user_id)
connections = await self.get_game_connections(game_id)
for conn in connections:
if conn.user_id != current_user_str:
return conn.sid
return None
# =========================================================================
# Maintenance
# =========================================================================
async def cleanup_stale_connections(
self,
threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3,
) -> int:
"""Clean up stale connections that haven't sent heartbeats.
Should be called periodically by a background task.
Args:
threshold_seconds: Seconds since last_seen to consider stale.
Returns:
Number of connections cleaned up.
Example:
# In background task
count = await manager.cleanup_stale_connections()
logger.info(f"Cleaned up {count} stale connections")
"""
count = 0
async with get_redis() as redis:
# Scan for all connection keys
async for key in redis.scan_iter(match=f"{CONN_PREFIX}*"):
sid = key[len(CONN_PREFIX) :]
info = await self.get_connection(sid)
if info and info.is_stale(threshold_seconds):
await self._cleanup_connection(sid)
count += 1
logger.debug(f"Cleaned up stale connection: {sid}")
if count > 0:
logger.info(f"Cleaned up {count} stale connections")
return count
async def get_connection_count(self) -> int:
"""Get the total number of active connections.
Useful for monitoring and admin dashboards.
Returns:
Number of active connections.
"""
count = 0
async with get_redis() as redis:
async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"):
count += 1
return count
async def get_game_connection_count(self, game_id: str) -> int:
"""Get the number of connections for a specific game.
Args:
game_id: Game ID.
Returns:
Number of connected participants.
"""
async with get_redis() as redis:
game_conns_key = self._game_conns_key(game_id)
return await redis.scard(game_conns_key)
# Global singleton instance
connection_manager = ConnectionManager()

View File

@ -0,0 +1,558 @@
"""Game service for orchestrating game lifecycle in Mantimon TCG.
This service is the bridge between WebSocket communication and the core
GameEngine. It handles:
- Game creation with deck loading and validation
- Action execution with persistence
- Game state retrieval with visibility filtering
- Game lifecycle (join, resign, end)
IMPORTANT: This service is stateless. All game-specific configuration
(RulesConfig) is stored in the GameState itself, not in this service.
Rules come from the frontend request at game creation time.
Architecture:
WebSocket Layer -> GameService -> GameEngine + GameStateManager
-> DeckService + CardService
Example:
from app.services.game_service import GameService, game_service
# Create a new game (rules from frontend)
result = await game_service.create_game(
player1_id=user1.id,
player2_id=user2.id,
deck1_id=deck1.id,
deck2_id=deck2.id,
rules_config=rules_from_request, # Frontend provides this
)
# Execute an action (uses rules stored in game state)
result = await game_service.execute_action(
game_id=game.id,
player_id=user1.id,
action=AttackAction(attack_index=0),
)
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
from app.core.engine import ActionResult, GameEngine
from app.core.enums import GameEndReason
from app.core.models.actions import Action, ResignAction
from app.core.models.game_state import GameState
from app.core.rng import create_rng
from app.core.visibility import VisibleGameState, get_visible_state
from app.services.card_service import CardService, get_card_service
from app.services.game_state_manager import GameStateManager, game_state_manager
logger = logging.getLogger(__name__)
# =============================================================================
# Exceptions
# =============================================================================
class GameServiceError(Exception):
"""Base exception for GameService errors."""
pass
class GameNotFoundError(GameServiceError):
"""Raised when a game cannot be found."""
def __init__(self, game_id: str) -> None:
self.game_id = game_id
super().__init__(f"Game not found: {game_id}")
class NotPlayerTurnError(GameServiceError):
"""Raised when a player tries to act out of turn."""
def __init__(self, game_id: str, player_id: str, current_player_id: str) -> None:
self.game_id = game_id
self.player_id = player_id
self.current_player_id = current_player_id
super().__init__(
f"Not player's turn: {player_id} tried to act, but it's {current_player_id}'s turn"
)
class InvalidActionError(GameServiceError):
"""Raised when an action is invalid."""
def __init__(self, game_id: str, player_id: str, reason: str) -> None:
self.game_id = game_id
self.player_id = player_id
self.reason = reason
super().__init__(f"Invalid action in game {game_id}: {reason}")
class PlayerNotInGameError(GameServiceError):
"""Raised when a player is not a participant in the game."""
def __init__(self, game_id: str, player_id: str) -> None:
self.game_id = game_id
self.player_id = player_id
super().__init__(f"Player {player_id} is not in game {game_id}")
class GameAlreadyEndedError(GameServiceError):
"""Raised when trying to act on a game that has already ended."""
def __init__(self, game_id: str) -> None:
self.game_id = game_id
super().__init__(f"Game {game_id} has already ended")
# =============================================================================
# Result Types
# =============================================================================
@dataclass
class GameActionResult:
"""Result of executing a game action.
Attributes:
success: Whether the action succeeded.
game_id: The game ID.
action_type: The type of action executed.
message: Description of what happened.
state_changes: Dict of state changes for client updates.
game_over: Whether the game ended as a result.
winner_id: Winner's player ID if game ended with a winner.
end_reason: Reason the game ended, if applicable.
"""
success: bool
game_id: str
action_type: str
message: str = ""
state_changes: dict[str, Any] = field(default_factory=dict)
game_over: bool = False
winner_id: str | None = None
end_reason: GameEndReason | None = None
@dataclass
class GameJoinResult:
"""Result of joining a game.
Attributes:
success: Whether the join succeeded.
game_id: The game ID.
player_id: The joining player's ID.
visible_state: The game state visible to the player.
is_your_turn: Whether it's this player's turn.
message: Additional information or error message.
"""
success: bool
game_id: str
player_id: str
visible_state: VisibleGameState | None = None
is_your_turn: bool = False
message: str = ""
# =============================================================================
# GameService
# =============================================================================
class GameService:
"""Service for orchestrating game lifecycle operations.
This service coordinates between the WebSocket layer and the core
GameEngine, handling persistence and state management.
IMPORTANT: This service is STATELESS regarding game rules.
- Rules are stored in each GameState (set at creation time)
- The GameEngine is instantiated per-operation with the game's rules
- No RulesConfig is stored in this service
Attributes:
_state_manager: GameStateManager for persistence.
_card_service: CardService for card definitions.
"""
def __init__(
self,
state_manager: GameStateManager | None = None,
card_service: CardService | None = None,
) -> None:
"""Initialize the GameService.
Note: No GameEngine or RulesConfig here - those are per-game,
not per-service. The engine is created as needed using the
rules stored in each game's state.
Args:
state_manager: GameStateManager instance. Uses global if not provided.
card_service: CardService instance. Uses global if not provided.
"""
self._state_manager = state_manager or game_state_manager
self._card_service = card_service or get_card_service()
def _create_engine_for_game(self, game: GameState) -> GameEngine:
"""Create a GameEngine configured for a specific game's rules.
The engine is created on-demand using the rules stored in the
game state. This ensures each game uses its own configuration.
For deterministic replay support, we derive a unique seed per action
by combining the game's base seed with the action count. This ensures:
- Same game + same action sequence = identical RNG results
- Each action gets a unique but reproducible random sequence
Args:
game: The game state containing the rules to use.
Returns:
A GameEngine configured with the game's rules and RNG.
"""
if game.rng_seed is not None:
# Derive unique seed per action for deterministic replay
# Action count ensures each action gets different but reproducible RNG
action_count = len(game.action_log)
action_seed = game.rng_seed + action_count
rng = create_rng(seed=action_seed)
else:
# No seed - use cryptographically secure RNG
rng = create_rng()
return GameEngine(rules=game.rules, rng=rng)
# =========================================================================
# Game State Access
# =========================================================================
async def get_game_state(self, game_id: str) -> GameState:
"""Get the full game state.
Args:
game_id: The game ID.
Returns:
The GameState.
Raises:
GameNotFoundError: If game doesn't exist.
"""
state = await self._state_manager.load_state(game_id)
if state is None:
raise GameNotFoundError(game_id)
return state
async def get_player_view(
self,
game_id: str,
player_id: str,
) -> VisibleGameState:
"""Get the game state filtered for a specific player's view.
This applies visibility rules to hide opponent's hidden information
(hand, deck, prizes).
Args:
game_id: The game ID.
player_id: The player to get the view for.
Returns:
VisibleGameState with appropriate filtering.
Raises:
GameNotFoundError: If game doesn't exist.
PlayerNotInGameError: If player is not in the game.
"""
state = await self.get_game_state(game_id)
if player_id not in state.players:
raise PlayerNotInGameError(game_id, player_id)
return get_visible_state(state, player_id)
async def is_player_turn(self, game_id: str, player_id: str) -> bool:
"""Check if it's the specified player's turn.
Args:
game_id: The game ID.
player_id: The player to check.
Returns:
True if it's the player's turn.
Raises:
GameNotFoundError: If game doesn't exist.
"""
state = await self.get_game_state(game_id)
return state.current_player_id == player_id
async def game_exists(self, game_id: str) -> bool:
"""Check if a game exists.
Args:
game_id: The game ID.
Returns:
True if the game exists in cache or database.
"""
return await self._state_manager.cache_exists(game_id)
# =========================================================================
# Game Lifecycle
# =========================================================================
async def join_game(
self,
game_id: str,
player_id: str,
last_event_id: str | None = None,
) -> GameJoinResult:
"""Join or rejoin a game session.
Loads the game state and returns the player's visible view.
Used when a player connects or reconnects to a game.
Args:
game_id: The game to join.
player_id: The joining player's ID.
last_event_id: Last event ID for reconnection replay (future use).
Returns:
GameJoinResult with the visible state.
"""
try:
state = await self.get_game_state(game_id)
except GameNotFoundError:
return GameJoinResult(
success=False,
game_id=game_id,
player_id=player_id,
message="Game not found",
)
if player_id not in state.players:
return GameJoinResult(
success=False,
game_id=game_id,
player_id=player_id,
message="You are not a participant in this game",
)
# Check if game already ended
if state.winner_id is not None or state.end_reason is not None:
visible = get_visible_state(state, player_id)
return GameJoinResult(
success=True,
game_id=game_id,
player_id=player_id,
visible_state=visible,
is_your_turn=False,
message="Game has ended",
)
visible = get_visible_state(state, player_id)
logger.info(f"Player {player_id} joined game {game_id}")
return GameJoinResult(
success=True,
game_id=game_id,
player_id=player_id,
visible_state=visible,
is_your_turn=state.current_player_id == player_id,
)
async def execute_action(
self,
game_id: str,
player_id: str,
action: Action,
) -> GameActionResult:
"""Execute a player action in the game.
Validates the action, executes it through GameEngine, and
persists the updated state. The GameEngine is created using
the rules stored in the game state.
Args:
game_id: The game ID.
player_id: The acting player's ID.
action: The action to execute.
Returns:
GameActionResult with success status and state changes.
Raises:
GameNotFoundError: If game doesn't exist.
PlayerNotInGameError: If player is not in the game.
GameAlreadyEndedError: If game has already ended.
NotPlayerTurnError: If it's not the player's turn.
InvalidActionError: If the action is invalid.
"""
# Load game state
state = await self.get_game_state(game_id)
# Validate player is in game
if player_id not in state.players:
raise PlayerNotInGameError(game_id, player_id)
# Check game hasn't ended
if state.winner_id is not None or state.end_reason is not None:
raise GameAlreadyEndedError(game_id)
# Check it's player's turn (unless resignation, which can happen anytime)
if not isinstance(action, ResignAction) and state.current_player_id != player_id:
raise NotPlayerTurnError(game_id, player_id, state.current_player_id)
# Create engine with this game's rules
engine = self._create_engine_for_game(state)
# Execute the action
result: ActionResult = await engine.execute_action(state, player_id, action)
if not result.success:
raise InvalidActionError(game_id, player_id, result.message)
# Save state to cache (fast path)
await self._state_manager.save_to_cache(state)
# Check if turn ended - persist to DB at turn boundaries
# TODO: Implement turn boundary detection for DB persistence
# Build response
action_result = GameActionResult(
success=True,
game_id=game_id,
action_type=action.type,
message=result.message,
state_changes={
"changes": result.state_changes,
},
)
# Check for game over
if result.win_result is not None:
action_result.game_over = True
action_result.winner_id = result.win_result.winner_id
action_result.end_reason = result.win_result.end_reason
# Persist final state to DB
await self._state_manager.persist_to_db(state)
logger.info(
f"Game {game_id} ended: winner={result.win_result.winner_id}, "
f"reason={result.win_result.end_reason}"
)
logger.debug(f"Action executed: game={game_id}, player={player_id}, type={action.type}")
return action_result
async def resign_game(
self,
game_id: str,
player_id: str,
) -> GameActionResult:
"""Resign from a game.
Convenience method that executes a ResignAction.
Args:
game_id: The game ID.
player_id: The resigning player's ID.
Returns:
GameActionResult indicating game over.
"""
return await self.execute_action(
game_id=game_id,
player_id=player_id,
action=ResignAction(),
)
async def end_game(
self,
game_id: str,
winner_id: str | None,
end_reason: GameEndReason,
) -> None:
"""Forcibly end a game (e.g., due to timeout or disconnection).
This should be called by the timeout system or when a player
disconnects without reconnecting within the grace period.
Args:
game_id: The game ID.
winner_id: The winner's player ID, or None for a draw.
end_reason: Why the game ended.
Raises:
GameNotFoundError: If game doesn't exist.
"""
state = await self.get_game_state(game_id)
# Set winner and end reason
state.winner_id = winner_id
state.end_reason = end_reason
# Persist to both cache and DB
await self._state_manager.save_to_cache(state)
await self._state_manager.persist_to_db(state)
logger.info(f"Game {game_id} forcibly ended: winner={winner_id}, reason={end_reason}")
# =========================================================================
# Game Creation (Skeleton - Full implementation in GS-002)
# =========================================================================
async def create_game(
self,
player1_id: str | UUID,
player2_id: str | UUID,
deck1_id: str | UUID | None = None,
deck2_id: str | UUID | None = None,
# Rules come from the frontend request - this is required, not optional
# Defaulting to None here only for the skeleton; GS-002 will make it required
) -> None:
"""Create a new game between two players.
This is a skeleton that will be fully implemented in GS-002.
IMPORTANT: rules_config will be a required parameter - it comes
from the frontend request, not from server-side defaults.
Args:
player1_id: First player's ID.
player2_id: Second player's ID.
deck1_id: First player's deck ID.
deck2_id: Second player's deck ID.
Raises:
NotImplementedError: Until GS-002 is complete.
"""
# TODO (GS-002): Full implementation with:
# - rules_config: RulesConfig parameter (required, from frontend)
# - Load decks via DeckService
# - Load card registry from CardService
# - Convert to CardInstances with unique IDs
# - Create GameState with the provided rules_config
# - Persist to Redis and Postgres
raise NotImplementedError(
"Game creation not yet implemented - see GS-002. "
"Rules will come from frontend request, not server defaults."
)
# Global singleton instance
# Note: This is safe because GameService is stateless regarding game rules.
# Each game's rules are stored in its GameState, not in this service.
game_service = GameService()

View File

@ -0,0 +1,46 @@
"""Socket.IO module for real-time game communication.
This module provides WebSocket-based real-time communication for:
- Active game sessions (actions, state updates, turn notifications)
- Connection management with session recovery
- Turn timeout handling
- JWT-based authentication
Architecture:
- Uses python-socketio with ASGI mode for FastAPI integration
- /game namespace handles all active game communication
- /lobby namespace (Phase 6) will handle matchmaking
- JWT authentication on connect (WS-004)
Usage:
The Socket.IO server is mounted alongside FastAPI in app/main.py.
Clients connect to ws://host/socket.io/ and join the /game namespace.
Client connection example:
socket = io("ws://host", {
auth: { token: "JWT_ACCESS_TOKEN" }
});
"""
from app.socketio.auth import (
AuthResult,
authenticate_connection,
cleanup_authenticated_session,
get_session_user_id,
require_auth,
setup_authenticated_session,
)
from app.socketio.server import create_socketio_app, sio
__all__ = [
# Server
"create_socketio_app",
"sio",
# Auth
"AuthResult",
"authenticate_connection",
"cleanup_authenticated_session",
"get_session_user_id",
"require_auth",
"setup_authenticated_session",
]

View File

@ -0,0 +1,283 @@
"""Socket.IO authentication middleware for WebSocket connections.
This module provides JWT-based authentication for Socket.IO connections.
It validates access tokens and attaches user information to the socket session.
Authentication Flow:
1. Client connects with `auth: { token: "JWT_ACCESS_TOKEN" }`
2. Server extracts and validates the JWT
3. If valid, user_id is stored in socket session
4. If invalid, connection is rejected with appropriate error
Session Data:
After successful authentication, the socket session contains:
- user_id: str (UUID as string)
- authenticated_at: str (ISO timestamp)
Example:
# In connect handler:
auth_result = await authenticate_connection(sid, auth)
if not auth_result.success:
return False # Reject connection
# Later, get user_id:
session = await sio.get_session(sid, namespace="/game")
user_id = session.get("user_id")
"""
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID
from app.services.connection_manager import connection_manager
from app.services.jwt_service import verify_access_token
logger = logging.getLogger(__name__)
@dataclass
class AuthResult:
"""Result of authentication attempt.
Attributes:
success: Whether authentication succeeded.
user_id: User's UUID if successful, None otherwise.
error_code: Error code for client if failed.
error_message: Human-readable error message if failed.
"""
success: bool
user_id: UUID | None = None
error_code: str | None = None
error_message: str | None = None
def extract_token(auth: dict[str, object] | None) -> str | None:
"""Extract JWT token from Socket.IO auth data.
Clients should send the token in the auth dict:
socket.connect({ auth: { token: "JWT_TOKEN" } })
Also supports:
- auth.authorization: "Bearer TOKEN"
- auth.access_token: "TOKEN"
Args:
auth: Authentication data from Socket.IO connect.
Returns:
JWT token string if found, None otherwise.
Example:
token = extract_token({"token": "eyJ..."})
token = extract_token({"authorization": "Bearer eyJ..."})
"""
if auth is None:
return None
# Primary: auth.token
token = auth.get("token")
if token and isinstance(token, str):
return token
# Alternative: auth.authorization (Bearer token)
authorization = auth.get("authorization")
if authorization and isinstance(authorization, str):
if authorization.lower().startswith("bearer "):
return authorization[7:]
return authorization
# Alternative: auth.access_token
access_token = auth.get("access_token")
if access_token and isinstance(access_token, str):
return access_token
return None
async def authenticate_connection(
sid: str,
auth: dict[str, object] | None,
) -> AuthResult:
"""Authenticate a Socket.IO connection using JWT.
Extracts the JWT from auth data, validates it, and returns the result.
Does NOT modify socket session - caller should handle that.
Args:
sid: Socket session ID (for logging).
auth: Authentication data from connect event.
Returns:
AuthResult with success status and user_id or error details.
Example:
result = await authenticate_connection(sid, auth)
if result.success:
await sio.save_session(sid, {"user_id": str(result.user_id)})
else:
logger.warning(f"Auth failed: {result.error_message}")
return False # Reject connection
"""
# Extract token from auth data
token = extract_token(auth)
if token is None:
logger.debug(f"Connection {sid}: No token provided")
return AuthResult(
success=False,
error_code="missing_token",
error_message="Authentication token required",
)
# Validate the token
user_id = verify_access_token(token)
if user_id is None:
logger.debug(f"Connection {sid}: Invalid or expired token")
return AuthResult(
success=False,
error_code="invalid_token",
error_message="Invalid or expired token",
)
logger.debug(f"Connection {sid}: Authenticated as user {user_id}")
return AuthResult(
success=True,
user_id=user_id,
)
async def setup_authenticated_session(
sio: object,
sid: str,
user_id: UUID,
namespace: str = "/game",
) -> None:
"""Set up socket session with authenticated user data.
Saves user_id and authentication timestamp to the socket session,
and registers the connection with ConnectionManager.
Args:
sio: Socket.IO AsyncServer instance.
sid: Socket session ID.
user_id: Authenticated user's UUID.
namespace: Socket.IO namespace.
Example:
if auth_result.success:
await setup_authenticated_session(sio, sid, auth_result.user_id)
"""
# Import here to avoid circular dependency
from app.socketio.server import sio as server_sio
# Use provided sio or fall back to server sio
socket_server = sio if sio is not None else server_sio
# Save to socket session
session_data = {
"user_id": str(user_id),
"authenticated_at": datetime.now(UTC).isoformat(),
}
await socket_server.save_session(sid, session_data, namespace=namespace)
# Register with ConnectionManager
await connection_manager.register_connection(sid, user_id)
logger.info(f"Session established: sid={sid}, user_id={user_id}")
async def cleanup_authenticated_session(
sid: str,
namespace: str = "/game",
) -> str | None:
"""Clean up session data on disconnect.
Unregisters the connection from ConnectionManager and returns
the user_id for any additional cleanup needed.
Args:
sid: Socket session ID.
namespace: Socket.IO namespace.
Returns:
user_id if session was authenticated, None otherwise.
Example:
user_id = await cleanup_authenticated_session(sid)
if user_id:
# Notify opponent, etc.
"""
# Unregister from ConnectionManager
conn_info = await connection_manager.unregister_connection(sid)
if conn_info:
logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}")
return conn_info.user_id
logger.debug(f"No session to clean up for {sid}")
return None
async def get_session_user_id(
sio: object,
sid: str,
namespace: str = "/game",
) -> str | None:
"""Get the authenticated user_id from a socket session.
Convenience function to extract user_id from session data.
Args:
sio: Socket.IO AsyncServer instance.
sid: Socket session ID.
namespace: Socket.IO namespace.
Returns:
user_id string if authenticated, None otherwise.
Example:
user_id = await get_session_user_id(sio, sid)
if not user_id:
await sio.emit("error", {"message": "Not authenticated"}, to=sid)
return
"""
try:
session = await sio.get_session(sid, namespace=namespace)
return session.get("user_id") if session else None
except Exception:
return None
async def require_auth(
sio: object,
sid: str,
namespace: str = "/game",
) -> str | None:
"""Require authentication for an event handler.
Returns the user_id if authenticated, None if not.
Logs a warning if authentication is missing.
Args:
sio: Socket.IO AsyncServer instance.
sid: Socket session ID.
namespace: Socket.IO namespace.
Returns:
user_id string if authenticated, None otherwise.
Example:
@sio.on("game:action", namespace="/game")
async def on_action(sid, data):
user_id = await require_auth(sio, sid)
if not user_id:
return {"error": "Not authenticated"}
# ... handle action
"""
user_id = await get_session_user_id(sio, sid, namespace)
if user_id is None:
logger.warning(f"Unauthenticated event from {sid}")
return user_id

View File

@ -0,0 +1,240 @@
"""Socket.IO server configuration and ASGI app creation.
This module sets up the python-socketio AsyncServer and creates the combined
ASGI application that mounts Socket.IO alongside FastAPI.
Architecture:
- AsyncServer handles WebSocket connections with async_mode='asgi'
- Socket.IO app is mounted at /socket.io path
- CORS settings match FastAPI configuration
- JWT authentication on connect via auth.py
- Namespaces are registered for different communication domains
Namespaces:
/game - Active game communication (actions, state updates)
/lobby - Pre-game lobby (matchmaking, invites) - Phase 6
Authentication:
Clients must provide a JWT access token in the auth parameter:
socket.connect({ auth: { token: "JWT_ACCESS_TOKEN" } })
Example:
from app.socketio import create_socketio_app
from fastapi import FastAPI
fastapi_app = FastAPI()
combined_app = create_socketio_app(fastapi_app)
# Run with: uvicorn app.main:app
"""
import logging
from typing import TYPE_CHECKING
import socketio
from app.config import settings
from app.socketio.auth import (
authenticate_connection,
cleanup_authenticated_session,
require_auth,
setup_authenticated_session,
)
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
# Create the AsyncServer instance
# - async_mode='asgi' for ASGI compatibility with uvicorn
# - cors_allowed_origins matches FastAPI CORS settings
# - logger enables Socket.IO internal logging in debug mode
sio = socketio.AsyncServer(
async_mode="asgi",
cors_allowed_origins=settings.cors_origins if settings.cors_origins else [],
logger=settings.debug,
engineio_logger=settings.debug,
)
# =============================================================================
# /game Namespace - Active Game Communication
# =============================================================================
# These are skeleton handlers that will be fully implemented in WS-005.
# For now, they provide basic connection lifecycle handling.
@sio.event(namespace="/game")
async def connect(
sid: str, environ: dict[str, object], auth: dict[str, object] | None = None
) -> bool | None:
"""Handle client connection to /game namespace.
Authenticates the connection using JWT from auth data.
Rejects connections without valid authentication.
Args:
sid: Socket session ID assigned by Socket.IO.
environ: WSGI/ASGI environ dict with request info.
auth: Authentication data sent by client (JWT token).
Expected format: { token: "JWT_ACCESS_TOKEN" }
Returns:
True to accept connection, False to reject.
None is treated as True (accept).
"""
# Authenticate the connection
auth_result = await authenticate_connection(sid, auth)
if not auth_result.success:
logger.warning(
f"Connection rejected for {sid}: {auth_result.error_code} - {auth_result.error_message}"
)
# Emit error before rejecting (client may receive this)
await sio.emit(
"auth_error",
{
"code": auth_result.error_code,
"message": auth_result.error_message,
},
to=sid,
namespace="/game",
)
return False
# Set up authenticated session and register connection
await setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game")
logger.info(f"Client authenticated to /game: sid={sid}, user_id={auth_result.user_id}")
return True
@sio.event(namespace="/game")
async def disconnect(sid: str) -> None:
"""Handle client disconnection from /game namespace.
Cleans up connection state and notifies other game participants.
Args:
sid: Socket session ID of disconnecting client.
"""
# Clean up session and get user info
user_id = await cleanup_authenticated_session(sid, namespace="/game")
if user_id:
logger.info(f"Client disconnected from /game: sid={sid}, user_id={user_id}")
# TODO (WS-005): Notify opponent of disconnection if in game
else:
logger.debug(f"Unauthenticated client disconnected: {sid}")
@sio.on("game:join", namespace="/game")
async def on_game_join(sid: str, data: dict[str, object]) -> dict[str, object]:
"""Handle request to join/rejoin a game session.
Args:
sid: Socket session ID.
data: Message containing game_id and optional last_event_id for resume.
Returns:
Response with game state or error.
"""
logger.debug(f"game:join from {sid}: {data}")
# TODO (WS-005): Implement with GameService
return {"error": "Not implemented yet"}
@sio.on("game:action", namespace="/game")
async def on_game_action(sid: str, data: dict[str, object]) -> dict[str, object]:
"""Handle game action from player.
Args:
sid: Socket session ID.
data: Action message with type and parameters.
Returns:
Action result or error.
"""
logger.debug(f"game:action from {sid}: {data}")
# TODO (WS-005): Implement with GameService
return {"error": "Not implemented yet"}
@sio.on("game:resign", namespace="/game")
async def on_game_resign(sid: str, data: dict[str, object]) -> dict[str, object]:
"""Handle player resignation.
Args:
sid: Socket session ID.
data: Resignation message (may be empty).
Returns:
Confirmation or error.
"""
logger.debug(f"game:resign from {sid}: {data}")
# TODO (WS-005): Implement with GameService
return {"error": "Not implemented yet"}
@sio.on("game:heartbeat", namespace="/game")
async def on_game_heartbeat(sid: str, data: dict[str, object] | None = None) -> dict[str, object]:
"""Handle heartbeat to keep connection alive.
Updates last_seen timestamp in ConnectionManager to prevent
the connection from being marked as stale.
Args:
sid: Socket session ID.
data: Optional heartbeat data (message_id for tracking).
Returns:
Heartbeat acknowledgment with server timestamp.
"""
from datetime import UTC, datetime
from app.services.connection_manager import connection_manager
# Require authentication
user_id = await require_auth(sio, sid)
if not user_id:
return {"error": "Not authenticated", "code": "unauthenticated"}
# Update last_seen in ConnectionManager
await connection_manager.update_heartbeat(sid)
# Return acknowledgment with timestamp
return {
"type": "heartbeat_ack",
"timestamp": datetime.now(UTC).isoformat(),
"message_id": data.get("message_id") if data else None,
}
# =============================================================================
# ASGI App Creation
# =============================================================================
def create_socketio_app(fastapi_app: "FastAPI") -> socketio.ASGIApp:
"""Create combined ASGI app with Socket.IO mounted alongside FastAPI.
This wraps the FastAPI app with Socket.IO, handling:
- WebSocket connections at /socket.io/
- HTTP requests passed through to FastAPI
Args:
fastapi_app: The FastAPI application instance.
Returns:
Combined ASGI application to use with uvicorn.
Example:
app = FastAPI()
combined = create_socketio_app(app)
# uvicorn will use 'combined' as the ASGI app
"""
return socketio.ASGIApp(
sio,
other_asgi_app=fastapi_app,
socketio_path="socket.io",
)

View File

@ -9,8 +9,8 @@
"description": "Real-time gameplay infrastructure - WebSocket communication, game lifecycle management, reconnection handling, and turn timeout system", "description": "Real-time gameplay infrastructure - WebSocket communication, game lifecycle management, reconnection handling, and turn timeout system",
"totalEstimatedHours": 45, "totalEstimatedHours": 45,
"totalTasks": 18, "totalTasks": 18,
"completedTasks": 0, "completedTasks": 5,
"status": "not_started", "status": "in_progress",
"masterPlan": "../PROJECT_PLAN_MASTER.json" "masterPlan": "../PROJECT_PLAN_MASTER.json"
}, },
@ -105,8 +105,8 @@
"description": "Install and configure python-socketio ASGI server, mount alongside FastAPI app", "description": "Install and configure python-socketio ASGI server, mount alongside FastAPI app",
"category": "infrastructure", "category": "infrastructure",
"priority": 1, "priority": 1,
"completed": false, "completed": true,
"tested": false, "tested": true,
"dependencies": [], "dependencies": [],
"files": [ "files": [
{"path": "app/socketio/__init__.py", "status": "create"}, {"path": "app/socketio/__init__.py", "status": "create"},
@ -130,8 +130,8 @@
"description": "Define Pydantic models for all WebSocket message types", "description": "Define Pydantic models for all WebSocket message types",
"category": "schemas", "category": "schemas",
"priority": 2, "priority": 2,
"completed": false, "completed": true,
"tested": false, "tested": true,
"dependencies": ["WS-001"], "dependencies": ["WS-001"],
"files": [ "files": [
{"path": "app/schemas/ws_messages.py", "status": "create"} {"path": "app/schemas/ws_messages.py", "status": "create"}
@ -153,8 +153,8 @@
"description": "Manage WebSocket connections with Redis-backed session tracking", "description": "Manage WebSocket connections with Redis-backed session tracking",
"category": "services", "category": "services",
"priority": 3, "priority": 3,
"completed": false, "completed": true,
"tested": false, "tested": true,
"dependencies": ["WS-001"], "dependencies": ["WS-001"],
"files": [ "files": [
{"path": "app/services/connection_manager.py", "status": "create"} {"path": "app/services/connection_manager.py", "status": "create"}
@ -177,8 +177,8 @@
"description": "Authenticate WebSocket connections using JWT tokens", "description": "Authenticate WebSocket connections using JWT tokens",
"category": "auth", "category": "auth",
"priority": 4, "priority": 4,
"completed": false, "completed": true,
"tested": false, "tested": true,
"dependencies": ["WS-001", "WS-003"], "dependencies": ["WS-001", "WS-003"],
"files": [ "files": [
{"path": "app/socketio/auth.py", "status": "create"}, {"path": "app/socketio/auth.py", "status": "create"},
@ -614,6 +614,13 @@
"risk": "Turn timeout drift due to server restart", "risk": "Turn timeout drift due to server restart",
"mitigation": "Store absolute deadline in Redis/Postgres, recalculate on startup", "mitigation": "Store absolute deadline in Redis/Postgres, recalculate on startup",
"priority": "medium" "priority": "medium"
},
{
"risk": "ConnectionManager race condition on rapid reconnects",
"mitigation": "Consider Redis Lua script for atomic old-connection cleanup in register_connection. Low probability in practice but could cause connection tracking issues during rapid connect/disconnect cycles.",
"priority": "low",
"status": "identified-in-review",
"notes": "Also consider: periodic cleanup of game_conns sets for long-running games, rate limiting in auth layer"
} }
], ],

View File

@ -107,6 +107,11 @@ module = [
] ]
ignore_missing_imports = true ignore_missing_imports = true
[[tool.mypy.overrides]]
# Socket.IO handlers use untyped decorators from the library
module = "app.socketio.*"
disallow_untyped_decorators = false
# Coverage configuration # Coverage configuration
[tool.coverage.run] [tool.coverage.run]
source = ["app"] source = ["app"]

View File

@ -0,0 +1 @@
# Socket.IO integration tests

View File

@ -0,0 +1,384 @@
"""Tests for Socket.IO authentication middleware.
This module tests JWT-based authentication for WebSocket connections,
including token extraction, validation, and session management.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from app.socketio.auth import (
AuthResult,
authenticate_connection,
cleanup_authenticated_session,
extract_token,
get_session_user_id,
require_auth,
setup_authenticated_session,
)
class TestExtractToken:
"""Tests for token extraction from Socket.IO auth data."""
def test_extract_token_from_token_field(self) -> None:
"""Test extracting token from the primary 'token' field.
The standard way for clients to pass the JWT is via auth.token.
This is the recommended format for Socket.IO clients.
"""
auth = {"token": "my-jwt-token"}
result = extract_token(auth)
assert result == "my-jwt-token"
def test_extract_token_from_authorization_bearer(self) -> None:
"""Test extracting token from Bearer authorization header format.
Some clients may pass the token as a Bearer token for consistency
with HTTP API authentication patterns.
"""
auth = {"authorization": "Bearer my-jwt-token"}
result = extract_token(auth)
assert result == "my-jwt-token"
def test_extract_token_from_authorization_without_bearer(self) -> None:
"""Test extracting token from authorization without Bearer prefix.
If the client provides just the token in authorization field,
we should still accept it.
"""
auth = {"authorization": "my-jwt-token"}
result = extract_token(auth)
assert result == "my-jwt-token"
def test_extract_token_from_access_token_field(self) -> None:
"""Test extracting token from access_token field.
Alternative field name for OAuth-style clients.
"""
auth = {"access_token": "my-jwt-token"}
result = extract_token(auth)
assert result == "my-jwt-token"
def test_extract_token_returns_none_for_none_auth(self) -> None:
"""Test that None auth data returns None token.
Clients that don't provide any auth should get None,
triggering the authentication failure path.
"""
result = extract_token(None)
assert result is None
def test_extract_token_returns_none_for_empty_auth(self) -> None:
"""Test that empty auth dict returns None token.
An empty auth object should be treated as unauthenticated.
"""
result = extract_token({})
assert result is None
def test_extract_token_returns_none_for_non_string_token(self) -> None:
"""Test that non-string token values are rejected.
Only string tokens are valid - reject numbers, objects, etc.
"""
result = extract_token({"token": 12345})
assert result is None
result = extract_token({"token": {"nested": "value"}})
assert result is None
def test_extract_token_prefers_token_field(self) -> None:
"""Test that 'token' field takes precedence over alternatives.
If multiple token fields are present, we should use the
primary 'token' field.
"""
auth = {
"token": "primary-token",
"authorization": "Bearer secondary-token",
"access_token": "tertiary-token",
}
result = extract_token(auth)
assert result == "primary-token"
class TestAuthenticateConnection:
"""Tests for connection authentication."""
@pytest.mark.asyncio
async def test_authenticate_success_with_valid_token(self) -> None:
"""Test successful authentication with a valid JWT.
A valid access token should result in AuthResult with
success=True and the user_id from the token.
"""
user_id = uuid4()
with patch("app.socketio.auth.verify_access_token") as mock_verify:
mock_verify.return_value = user_id
result = await authenticate_connection("test-sid", {"token": "valid-token"})
assert result.success is True
assert result.user_id == user_id
assert result.error_code is None
mock_verify.assert_called_once_with("valid-token")
@pytest.mark.asyncio
async def test_authenticate_fails_with_missing_token(self) -> None:
"""Test authentication failure when no token is provided.
Connections without any auth data should fail with
a 'missing_token' error code.
"""
result = await authenticate_connection("test-sid", None)
assert result.success is False
assert result.user_id is None
assert result.error_code == "missing_token"
assert "required" in result.error_message.lower()
@pytest.mark.asyncio
async def test_authenticate_fails_with_empty_auth(self) -> None:
"""Test authentication failure with empty auth object.
An auth object without any token fields should fail.
"""
result = await authenticate_connection("test-sid", {})
assert result.success is False
assert result.error_code == "missing_token"
@pytest.mark.asyncio
async def test_authenticate_fails_with_invalid_token(self) -> None:
"""Test authentication failure with invalid/expired JWT.
When verify_access_token returns None (invalid token),
we should fail with 'invalid_token' error.
"""
with patch("app.socketio.auth.verify_access_token") as mock_verify:
mock_verify.return_value = None # Token validation failed
result = await authenticate_connection("test-sid", {"token": "invalid-token"})
assert result.success is False
assert result.user_id is None
assert result.error_code == "invalid_token"
assert (
"invalid" in result.error_message.lower()
or "expired" in result.error_message.lower()
)
@pytest.mark.asyncio
async def test_authenticate_extracts_token_from_bearer(self) -> None:
"""Test that authentication works with Bearer format.
The authenticate function should handle Bearer token format
through the extract_token function.
"""
user_id = uuid4()
with patch("app.socketio.auth.verify_access_token") as mock_verify:
mock_verify.return_value = user_id
result = await authenticate_connection("test-sid", {"authorization": "Bearer my-token"})
assert result.success is True
mock_verify.assert_called_once_with("my-token")
class TestSetupAuthenticatedSession:
"""Tests for session setup after authentication."""
@pytest.mark.asyncio
async def test_setup_saves_session_data(self) -> None:
"""Test that session setup saves user_id and timestamp.
After authentication, the socket session should contain
the user_id and authentication timestamp.
"""
user_id = uuid4()
mock_sio = AsyncMock()
with patch("app.socketio.auth.connection_manager") as mock_cm:
mock_cm.register_connection = AsyncMock()
await setup_authenticated_session(mock_sio, "test-sid", user_id)
# Verify session was saved
mock_sio.save_session.assert_called_once()
call_args = mock_sio.save_session.call_args
assert call_args.args[0] == "test-sid"
session_data = call_args.args[1]
assert session_data["user_id"] == str(user_id)
assert "authenticated_at" in session_data
@pytest.mark.asyncio
async def test_setup_registers_with_connection_manager(self) -> None:
"""Test that session setup registers with ConnectionManager.
The connection should be tracked in ConnectionManager for
presence detection and game association.
"""
user_id = uuid4()
mock_sio = AsyncMock()
with patch("app.socketio.auth.connection_manager") as mock_cm:
mock_cm.register_connection = AsyncMock()
await setup_authenticated_session(mock_sio, "test-sid", user_id)
mock_cm.register_connection.assert_called_once_with("test-sid", user_id)
class TestCleanupAuthenticatedSession:
"""Tests for session cleanup on disconnect."""
@pytest.mark.asyncio
async def test_cleanup_unregisters_connection(self) -> None:
"""Test that cleanup unregisters from ConnectionManager.
On disconnect, the connection should be removed from
ConnectionManager to update presence tracking.
"""
with patch("app.socketio.auth.connection_manager") as mock_cm:
mock_conn_info = MagicMock()
mock_conn_info.user_id = "user-123"
mock_cm.unregister_connection = AsyncMock(return_value=mock_conn_info)
result = await cleanup_authenticated_session("test-sid")
assert result == "user-123"
mock_cm.unregister_connection.assert_called_once_with("test-sid")
@pytest.mark.asyncio
async def test_cleanup_returns_none_for_unknown_session(self) -> None:
"""Test cleanup returns None for non-existent sessions.
If the connection wasn't registered (e.g., auth failed),
cleanup should return None gracefully.
"""
with patch("app.socketio.auth.connection_manager") as mock_cm:
mock_cm.unregister_connection = AsyncMock(return_value=None)
result = await cleanup_authenticated_session("unknown-sid")
assert result is None
class TestGetSessionUserId:
"""Tests for session user_id retrieval."""
@pytest.mark.asyncio
async def test_get_session_user_id_returns_id(self) -> None:
"""Test retrieving user_id from authenticated session.
For authenticated sessions, get_session_user_id should
return the stored user_id string.
"""
mock_sio = AsyncMock()
mock_sio.get_session = AsyncMock(
return_value={"user_id": "user-123", "authenticated_at": "2024-01-01"}
)
result = await get_session_user_id(mock_sio, "test-sid")
assert result == "user-123"
mock_sio.get_session.assert_called_once_with("test-sid", namespace="/game")
@pytest.mark.asyncio
async def test_get_session_user_id_returns_none_for_missing(self) -> None:
"""Test that missing session returns None.
If no session exists for the sid, we should return None
rather than raising an error.
"""
mock_sio = AsyncMock()
mock_sio.get_session = AsyncMock(return_value=None)
result = await get_session_user_id(mock_sio, "test-sid")
assert result is None
@pytest.mark.asyncio
async def test_get_session_user_id_handles_exception(self) -> None:
"""Test that exceptions are caught and return None.
If get_session raises an exception, we should catch it
and return None to avoid breaking the event handler.
"""
mock_sio = AsyncMock()
mock_sio.get_session = AsyncMock(side_effect=Exception("Session error"))
result = await get_session_user_id(mock_sio, "test-sid")
assert result is None
class TestRequireAuth:
"""Tests for the require_auth helper."""
@pytest.mark.asyncio
async def test_require_auth_returns_user_id_for_authenticated(self) -> None:
"""Test that require_auth returns user_id for valid sessions.
Authenticated sessions should return the user_id for use
in event handlers.
"""
mock_sio = AsyncMock()
mock_sio.get_session = AsyncMock(return_value={"user_id": "user-123"})
result = await require_auth(mock_sio, "test-sid")
assert result == "user-123"
@pytest.mark.asyncio
async def test_require_auth_returns_none_for_unauthenticated(self) -> None:
"""Test that require_auth returns None for unauthenticated sessions.
Unauthenticated events should get None, allowing handlers
to return an error response.
"""
mock_sio = AsyncMock()
mock_sio.get_session = AsyncMock(return_value=None)
result = await require_auth(mock_sio, "test-sid")
assert result is None
class TestAuthResultDataclass:
"""Tests for the AuthResult dataclass."""
def test_auth_result_success(self) -> None:
"""Test creating a successful AuthResult.
Success results should have user_id and no error fields.
"""
user_id = uuid4()
result = AuthResult(success=True, user_id=user_id)
assert result.success is True
assert result.user_id == user_id
assert result.error_code is None
assert result.error_message is None
def test_auth_result_failure(self) -> None:
"""Test creating a failed AuthResult.
Failure results should have error code/message and no user_id.
"""
result = AuthResult(
success=False,
error_code="invalid_token",
error_message="Token expired",
)
assert result.success is False
assert result.user_id is None
assert result.error_code == "invalid_token"
assert result.error_message == "Token expired"

View File

@ -0,0 +1,113 @@
"""Tests for Socket.IO server setup and configuration.
These tests verify that the Socket.IO server is correctly configured
and can be mounted alongside FastAPI.
"""
class TestSocketIOSetup:
"""Tests for Socket.IO server initialization."""
def test_sio_server_exists(self) -> None:
"""
Verify that the Socket.IO AsyncServer is created.
The server instance should be available as a module-level export
and configured for ASGI mode.
"""
from app.socketio import sio
assert sio is not None
assert sio.async_mode == "asgi"
def test_game_namespace_handlers_registered(self) -> None:
"""
Verify that /game namespace event handlers are registered.
All required event handlers should be available on the sio instance
before any connections are made.
"""
from app.socketio import sio
# Check that handlers are registered for /game namespace
assert "/game" in sio.handlers
game_handlers = sio.handlers["/game"]
expected_events = [
"connect",
"disconnect",
"game:join",
"game:action",
"game:resign",
"game:heartbeat",
]
for event in expected_events:
assert event in game_handlers, f"Missing handler for {event}"
def test_create_socketio_app_returns_asgi_app(self) -> None:
"""
Verify that create_socketio_app returns a valid ASGI application.
The combined app should wrap the FastAPI app and handle both
HTTP and WebSocket connections.
"""
import socketio
from fastapi import FastAPI
from app.socketio import create_socketio_app
# Create a minimal FastAPI app for testing
test_app = FastAPI()
# Create combined ASGI app
combined = create_socketio_app(test_app)
# Verify it's a Socket.IO ASGI app
assert isinstance(combined, socketio.ASGIApp)
def test_cors_configured_from_settings(self) -> None:
"""
Verify that Socket.IO CORS settings match application settings.
The Socket.IO server should allow the same origins as FastAPI
to ensure consistent cross-origin behavior.
"""
from app.config import settings
from app.socketio import sio
# Socket.IO stores CORS origins in eio (Engine.IO) settings
# The cors_allowed_origins should match settings
assert sio.eio.cors_allowed_origins == settings.cors_origins
class TestMainAppIntegration:
"""Tests for Socket.IO integration with main app."""
def test_main_app_is_combined_asgi_app(self) -> None:
"""
Verify that the main app module exports the combined ASGI app.
The 'app' variable in main.py should be the Socket.IO wrapped
FastAPI application, not the raw FastAPI app.
"""
import socketio
from app.main import app
# The exported 'app' should be a Socket.IO ASGI app
assert isinstance(app, socketio.ASGIApp)
def test_fastapi_app_accessible(self) -> None:
"""
Verify that the underlying FastAPI app is still accessible.
The FastAPI app should be available as _fastapi_app in main
for testing and direct access when needed.
"""
from fastapi import FastAPI
from app.main import _fastapi_app
assert isinstance(_fastapi_app, FastAPI)
assert _fastapi_app.title == "Mantimon TCG"

View File

@ -0,0 +1 @@
# Unit tests for schemas

View File

@ -0,0 +1,701 @@
"""Tests for WebSocket message schemas.
This module tests the Pydantic models for WebSocket communication, verifying
discriminated union parsing, field validation, and serialization behavior.
"""
from datetime import UTC, datetime
import pytest
from pydantic import TypeAdapter, ValidationError
from app.core.enums import GameEndReason, TurnPhase
from app.core.models.actions import AttackAction, PassAction
from app.core.visibility import VisibleGameState, VisiblePlayerState
from app.schemas.ws_messages import (
ActionMessage,
ActionResultMessage,
ClientMessage,
ConnectionStatus,
ErrorMessage,
GameOverMessage,
GameStateMessage,
HeartbeatAckMessage,
HeartbeatMessage,
JoinGameMessage,
OpponentStatusMessage,
ResignMessage,
ServerMessage,
TurnStartMessage,
TurnTimeoutMessage,
WSErrorCode,
parse_client_message,
parse_server_message,
)
class TestClientMessageDiscriminator:
"""Tests for client message discriminated union parsing.
The discriminated union pattern allows automatic type resolution based on
the 'type' field, enabling type-safe message handling without explicit
type checking logic.
"""
def test_parse_join_game_message(self) -> None:
"""Test parsing JoinGameMessage from dictionary.
JoinGameMessage is the first message clients send to establish their
presence in a game session. The discriminator should correctly resolve
the type from the raw dictionary.
"""
data = {
"type": "join_game",
"message_id": "test-123",
"game_id": "game-456",
}
message = parse_client_message(data)
assert isinstance(message, JoinGameMessage)
assert message.game_id == "game-456"
assert message.message_id == "test-123"
assert message.last_event_id is None
def test_parse_join_game_with_last_event_id(self) -> None:
"""Test JoinGameMessage with reconnection event ID.
When reconnecting after a disconnect, clients provide last_event_id
to receive any missed events since that point, enabling seamless
session recovery.
"""
data = {
"type": "join_game",
"message_id": "test-123",
"game_id": "game-456",
"last_event_id": "event-789",
}
message = parse_client_message(data)
assert isinstance(message, JoinGameMessage)
assert message.last_event_id == "event-789"
def test_parse_action_message(self) -> None:
"""Test parsing ActionMessage with embedded game action.
ActionMessage wraps the Action union type, so the discriminator must
handle both the message type and the nested action type correctly.
"""
data = {
"type": "action",
"message_id": "test-123",
"game_id": "game-456",
"action": {"type": "attack", "attack_index": 0},
}
message = parse_client_message(data)
assert isinstance(message, ActionMessage)
assert message.game_id == "game-456"
assert isinstance(message.action, AttackAction)
assert message.action.attack_index == 0
def test_parse_action_message_with_pass(self) -> None:
"""Test parsing ActionMessage with PassAction.
PassAction is the simplest action type - verifies that minimal
action payloads are handled correctly.
"""
data = {
"type": "action",
"message_id": "test-123",
"game_id": "game-456",
"action": {"type": "pass"},
}
message = parse_client_message(data)
assert isinstance(message, ActionMessage)
assert isinstance(message.action, PassAction)
def test_parse_resign_message(self) -> None:
"""Test parsing ResignMessage.
ResignMessage allows resignation at the WebSocket layer, separate from
the game engine's ResignAction, for cases where engine interaction
may not be possible.
"""
data = {
"type": "resign",
"message_id": "test-123",
"game_id": "game-456",
}
message = parse_client_message(data)
assert isinstance(message, ResignMessage)
assert message.game_id == "game-456"
def test_parse_heartbeat_message(self) -> None:
"""Test parsing HeartbeatMessage.
Heartbeat messages keep the WebSocket connection alive and should
be the most minimal message type with just the type discriminator.
"""
data = {
"type": "heartbeat",
"message_id": "test-123",
}
message = parse_client_message(data)
assert isinstance(message, HeartbeatMessage)
def test_unknown_message_type_raises_error(self) -> None:
"""Test that unknown message types raise ValidationError.
The discriminated union should fail fast on unknown types rather
than silently accepting invalid messages.
"""
data = {
"type": "unknown_type",
"message_id": "test-123",
}
with pytest.raises(ValidationError):
parse_client_message(data)
def test_message_id_auto_generated(self) -> None:
"""Test that message_id is auto-generated if not provided.
Clients may omit message_id in which case a UUID should be generated,
ensuring all messages have a unique identifier for tracking.
"""
data = {
"type": "heartbeat",
}
message = parse_client_message(data)
assert message.message_id is not None
assert len(message.message_id) > 0
def test_empty_message_id_rejected(self) -> None:
"""Test that empty message_id is rejected.
An explicitly empty message_id would break idempotency tracking,
so it should be rejected by validation.
"""
data = {
"type": "heartbeat",
"message_id": "",
}
with pytest.raises(ValidationError):
parse_client_message(data)
def test_type_adapter_client_message(self) -> None:
"""Test ClientMessage union with TypeAdapter.
TypeAdapter is the recommended way to work with discriminated unions
in Pydantic v2 - verifies it works for both parsing and validation.
"""
adapter = TypeAdapter(ClientMessage)
data = {"type": "heartbeat", "message_id": "test"}
message = adapter.validate_python(data)
assert isinstance(message, HeartbeatMessage)
class TestServerMessageDiscriminator:
"""Tests for server message discriminated union parsing.
Server messages are more complex than client messages, often containing
nested game state. The discriminator must handle all message types and
their embedded data correctly.
"""
def test_parse_game_state_message(self) -> None:
"""Test parsing GameStateMessage with full visible state.
GameStateMessage carries the complete game state from a player's
perspective, used for initial sync and full state updates.
"""
visible_state = VisibleGameState(
game_id="game-456",
viewer_id="player-1",
players={
"player-1": VisiblePlayerState(player_id="player-1"),
"player-2": VisiblePlayerState(player_id="player-2"),
},
current_player_id="player-1",
turn_number=1,
phase=TurnPhase.MAIN,
is_my_turn=True,
)
message = GameStateMessage(
game_id="game-456",
state=visible_state,
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, GameStateMessage)
assert parsed.game_id == "game-456"
assert parsed.state.viewer_id == "player-1"
assert parsed.event_id is not None
def test_parse_action_result_success(self) -> None:
"""Test parsing successful ActionResultMessage.
Successful action results should include the changes that occurred
as a result of the action for client-side state updates.
"""
message = ActionResultMessage(
game_id="game-456",
request_message_id="action-123",
success=True,
action_type="attack",
changes={"damage_dealt": 30, "defender_hp": 70},
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, ActionResultMessage)
assert parsed.success is True
assert parsed.changes["damage_dealt"] == 30
assert parsed.error_code is None
def test_parse_action_result_failure(self) -> None:
"""Test parsing failed ActionResultMessage.
Failed action results should include error code and message
for client-side error handling and user feedback.
"""
message = ActionResultMessage(
game_id="game-456",
request_message_id="action-123",
success=False,
action_type="attack",
error_code=WSErrorCode.NOT_YOUR_TURN,
error_message="It's not your turn",
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, ActionResultMessage)
assert parsed.success is False
assert parsed.error_code == WSErrorCode.NOT_YOUR_TURN
def test_parse_error_message(self) -> None:
"""Test parsing ErrorMessage.
Error messages are for protocol-level errors not tied to a specific
action, like authentication failures or invalid message format.
"""
message = ErrorMessage(
code=WSErrorCode.AUTHENTICATION_FAILED,
message="Invalid token",
details={"reason": "expired"},
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, ErrorMessage)
assert parsed.code == WSErrorCode.AUTHENTICATION_FAILED
assert parsed.details["reason"] == "expired"
def test_parse_turn_start_message(self) -> None:
"""Test parsing TurnStartMessage.
TurnStartMessage notifies both players when a turn begins,
enabling UI updates like starting a turn timer.
"""
message = TurnStartMessage(
game_id="game-456",
player_id="player-1",
turn_number=5,
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, TurnStartMessage)
assert parsed.turn_number == 5
assert parsed.player_id == "player-1"
def test_parse_turn_timeout_warning(self) -> None:
"""Test parsing TurnTimeoutMessage as warning.
Timeout warnings give players a chance to complete their action
before the turn is forcibly ended.
"""
message = TurnTimeoutMessage(
game_id="game-456",
remaining_seconds=30,
is_warning=True,
player_id="player-1",
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, TurnTimeoutMessage)
assert parsed.is_warning is True
assert parsed.remaining_seconds == 30
def test_parse_turn_timeout_expired(self) -> None:
"""Test parsing TurnTimeoutMessage as expired.
When is_warning is False, the timeout has occurred and the game
engine will force the turn to end.
"""
message = TurnTimeoutMessage(
game_id="game-456",
remaining_seconds=0,
is_warning=False,
player_id="player-1",
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, TurnTimeoutMessage)
assert parsed.is_warning is False
assert parsed.remaining_seconds == 0
def test_parse_game_over_with_winner(self) -> None:
"""Test parsing GameOverMessage with a winner.
Most games end with a winner - the final state is included so
clients can display the final board position.
"""
final_state = VisibleGameState(
game_id="game-456",
viewer_id="player-1",
winner_id="player-1",
end_reason=GameEndReason.PRIZES_TAKEN,
)
message = GameOverMessage(
game_id="game-456",
winner_id="player-1",
end_reason=GameEndReason.PRIZES_TAKEN,
final_state=final_state,
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, GameOverMessage)
assert parsed.winner_id == "player-1"
assert parsed.end_reason == GameEndReason.PRIZES_TAKEN
def test_parse_game_over_draw(self) -> None:
"""Test parsing GameOverMessage as a draw.
Games can end in a draw (e.g., timeout with equal scores),
in which case winner_id is None.
"""
final_state = VisibleGameState(
game_id="game-456",
viewer_id="player-1",
end_reason=GameEndReason.DRAW,
)
message = GameOverMessage(
game_id="game-456",
winner_id=None,
end_reason=GameEndReason.DRAW,
final_state=final_state,
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, GameOverMessage)
assert parsed.winner_id is None
assert parsed.end_reason == GameEndReason.DRAW
def test_parse_opponent_status_connected(self) -> None:
"""Test parsing OpponentStatusMessage for connection.
Notifies players when their opponent connects, enabling
connection status display in the UI.
"""
message = OpponentStatusMessage(
game_id="game-456",
opponent_id="player-2",
status=ConnectionStatus.CONNECTED,
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, OpponentStatusMessage)
assert parsed.status == ConnectionStatus.CONNECTED
def test_parse_opponent_status_disconnected(self) -> None:
"""Test parsing OpponentStatusMessage for disconnection.
When an opponent disconnects, the UI can show a waiting indicator
and possibly pause the game timer.
"""
message = OpponentStatusMessage(
game_id="game-456",
opponent_id="player-2",
status=ConnectionStatus.DISCONNECTED,
)
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, OpponentStatusMessage)
assert parsed.status == ConnectionStatus.DISCONNECTED
def test_parse_heartbeat_ack(self) -> None:
"""Test parsing HeartbeatAckMessage.
HeartbeatAck confirms the connection is alive, allowing clients
to measure round-trip latency.
"""
message = HeartbeatAckMessage()
data = message.model_dump()
parsed = parse_server_message(data)
assert isinstance(parsed, HeartbeatAckMessage)
def test_server_message_has_timestamp(self) -> None:
"""Test that all server messages have a timestamp.
Timestamps enable client-side latency tracking and event ordering
in case of out-of-order message delivery.
"""
message = HeartbeatAckMessage()
assert message.timestamp is not None
assert isinstance(message.timestamp, datetime)
# Should be UTC
assert message.timestamp.tzinfo == UTC
def test_type_adapter_server_message(self) -> None:
"""Test ServerMessage union with TypeAdapter.
Verifies TypeAdapter works correctly for the more complex
server message union with nested types.
"""
adapter = TypeAdapter(ServerMessage)
error_data = {
"type": "error",
"message_id": "test",
"timestamp": datetime.now(UTC).isoformat(),
"code": "game_not_found",
"message": "Game not found",
}
message = adapter.validate_python(error_data)
assert isinstance(message, ErrorMessage)
class TestMessageSerialization:
"""Tests for message JSON serialization and round-trip behavior.
Messages must serialize cleanly to JSON for WebSocket transmission
and deserialize back to the same logical structure.
"""
def test_client_message_round_trip(self) -> None:
"""Test client message JSON round-trip.
Messages should survive serialization to JSON and back without
losing data or changing types.
"""
original = ActionMessage(
message_id="test-123",
game_id="game-456",
action=AttackAction(attack_index=1, targets=["target-1"]),
)
json_str = original.model_dump_json()
import json
data = json.loads(json_str)
restored = parse_client_message(data)
assert isinstance(restored, ActionMessage)
assert restored.message_id == original.message_id
assert isinstance(restored.action, AttackAction)
assert restored.action.attack_index == 1
assert restored.action.targets == ["target-1"]
def test_server_message_round_trip(self) -> None:
"""Test server message JSON round-trip.
Server messages with nested state should serialize and deserialize
correctly, preserving all game state information.
"""
visible_state = VisibleGameState(
game_id="game-456",
viewer_id="player-1",
turn_number=3,
phase=TurnPhase.ATTACK,
)
original = GameStateMessage(
message_id="test-123",
game_id="game-456",
state=visible_state,
)
json_str = original.model_dump_json()
import json
data = json.loads(json_str)
restored = parse_server_message(data)
assert isinstance(restored, GameStateMessage)
assert restored.state.turn_number == 3
assert restored.state.phase == TurnPhase.ATTACK
def test_error_code_serializes_as_string(self) -> None:
"""Test that WSErrorCode serializes as string value.
Error codes should serialize to their string values for clean
JSON output that clients can easily handle.
"""
message = ErrorMessage(
code=WSErrorCode.GAME_NOT_FOUND,
message="Game not found",
)
json_str = message.model_dump_json()
import json
data = json.loads(json_str)
assert data["code"] == "game_not_found"
class TestMessageValidation:
"""Tests for message field validation rules.
Validation ensures messages contain valid data before processing,
preventing invalid state from propagating through the system.
"""
def test_join_game_requires_game_id(self) -> None:
"""Test that JoinGameMessage requires game_id.
game_id is mandatory - clients must specify which game to join.
"""
data = {
"type": "join_game",
"message_id": "test-123",
# missing game_id
}
with pytest.raises(ValidationError):
parse_client_message(data)
def test_action_message_requires_action(self) -> None:
"""Test that ActionMessage requires action field.
The action field is mandatory - messages without an action
cannot be processed.
"""
data = {
"type": "action",
"message_id": "test-123",
"game_id": "game-456",
# missing action
}
with pytest.raises(ValidationError):
parse_client_message(data)
def test_turn_number_must_be_positive(self) -> None:
"""Test that turn_number must be >= 1.
Turn numbers start at 1 - zero or negative values are invalid.
"""
with pytest.raises(ValidationError):
TurnStartMessage(
game_id="game-456",
player_id="player-1",
turn_number=0, # Invalid
)
def test_remaining_seconds_non_negative(self) -> None:
"""Test that remaining_seconds must be >= 0.
Negative timeout values are meaningless and should be rejected.
"""
with pytest.raises(ValidationError):
TurnTimeoutMessage(
game_id="game-456",
remaining_seconds=-1, # Invalid
is_warning=True,
player_id="player-1",
)
def test_action_result_requires_request_id(self) -> None:
"""Test that ActionResultMessage requires request_message_id.
Results must link back to the original request for client-side
correlation of actions and responses.
"""
with pytest.raises(ValidationError):
ActionResultMessage(
game_id="game-456",
# missing request_message_id
success=True,
action_type="attack",
)
class TestWSErrorCode:
"""Tests for WebSocket error code enumeration.
Error codes provide machine-readable error classification for
programmatic error handling on the client.
"""
def test_all_error_codes_are_strings(self) -> None:
"""Test that all error codes are string values.
StrEnum ensures all values serialize as strings for JSON
compatibility.
"""
for code in WSErrorCode:
assert isinstance(code.value, str)
def test_error_codes_are_lowercase_snake_case(self) -> None:
"""Test error code naming convention.
Error codes follow lowercase_snake_case for consistency with
other API identifiers and JSON conventions.
"""
for code in WSErrorCode:
assert code.value == code.value.lower()
assert " " not in code.value
class TestConnectionStatus:
"""Tests for connection status enumeration."""
def test_all_statuses_are_strings(self) -> None:
"""Test that all connection statuses are string values."""
for status in ConnectionStatus:
assert isinstance(status.value, str)
def test_expected_statuses_exist(self) -> None:
"""Test that expected connection statuses are defined.
The three key states (connected, disconnected, reconnecting)
cover all connection lifecycle phases.
"""
assert ConnectionStatus.CONNECTED
assert ConnectionStatus.DISCONNECTED
assert ConnectionStatus.RECONNECTING

View File

@ -0,0 +1,665 @@
"""Tests for ConnectionManager service.
This module tests WebSocket connection tracking with Redis. Since these are
unit tests, we mock the Redis operations to test the ConnectionManager logic
without requiring a real Redis instance.
"""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from app.services.connection_manager import (
CONN_PREFIX,
GAME_CONNS_PREFIX,
HEARTBEAT_INTERVAL_SECONDS,
USER_CONN_PREFIX,
ConnectionInfo,
ConnectionManager,
)
@pytest.fixture
def manager() -> ConnectionManager:
"""Create a ConnectionManager instance for testing."""
return ConnectionManager(conn_ttl_seconds=3600)
@pytest.fixture
def mock_redis() -> AsyncMock:
"""Create a mock Redis client."""
redis = AsyncMock()
redis.hset = AsyncMock()
redis.hget = AsyncMock()
redis.hgetall = AsyncMock(return_value={})
redis.set = AsyncMock()
redis.get = AsyncMock(return_value=None)
redis.delete = AsyncMock()
redis.exists = AsyncMock(return_value=False)
redis.expire = AsyncMock()
redis.sadd = AsyncMock()
redis.srem = AsyncMock()
redis.smembers = AsyncMock(return_value=set())
redis.scard = AsyncMock(return_value=0)
return redis
class TestConnectionInfoDataclass:
"""Tests for the ConnectionInfo dataclass."""
def test_is_stale_returns_false_for_recent_connection(self) -> None:
"""Test that recent connections are not marked as stale.
Connections that have been seen within the threshold should be
considered active, not stale.
"""
now = datetime.now(UTC)
info = ConnectionInfo(
sid="test-sid",
user_id="user-123",
game_id=None,
connected_at=now,
last_seen=now,
)
assert info.is_stale() is False
def test_is_stale_returns_true_for_old_connection(self) -> None:
"""Test that old connections are marked as stale.
Connections that haven't been seen for longer than the threshold
should be considered stale and eligible for cleanup.
"""
now = datetime.now(UTC)
old_time = now - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4)
info = ConnectionInfo(
sid="test-sid",
user_id="user-123",
game_id=None,
connected_at=old_time,
last_seen=old_time,
)
assert info.is_stale() is True
def test_is_stale_with_custom_threshold(self) -> None:
"""Test is_stale with a custom threshold.
The threshold can be adjusted for different use cases like
more aggressive cleanup or more lenient timeout.
"""
now = datetime.now(UTC)
last_seen = now - timedelta(seconds=60)
info = ConnectionInfo(
sid="test-sid",
user_id="user-123",
game_id=None,
connected_at=now,
last_seen=last_seen,
)
# 60 seconds old, with 30 second threshold = stale
assert info.is_stale(threshold_seconds=30) is True
# 60 seconds old, with 120 second threshold = not stale
assert info.is_stale(threshold_seconds=120) is False
class TestRegisterConnection:
"""Tests for connection registration."""
@pytest.mark.asyncio
async def test_register_connection_creates_records(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that registering a connection creates all necessary Redis records.
Registration should create:
1. Connection hash with user_id, game_id, timestamps
2. User-to-connection mapping
"""
sid = "test-sid-123"
user_id = str(uuid4())
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
await manager.register_connection(sid, user_id)
# Verify connection hash was created
conn_key = f"{CONN_PREFIX}{sid}"
mock_redis.hset.assert_called()
hset_call = mock_redis.hset.call_args
assert hset_call.args[0] == conn_key
mapping = hset_call.kwargs["mapping"]
assert mapping["user_id"] == user_id
assert mapping["game_id"] == ""
# Verify user-to-connection mapping was created
user_conn_key = f"{USER_CONN_PREFIX}{user_id}"
mock_redis.set.assert_called_with(user_conn_key, sid)
@pytest.mark.asyncio
async def test_register_connection_replaces_old_connection(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that registering a new connection replaces the old one.
When a user reconnects, their old connection should be cleaned up
to prevent stale connection data from lingering.
"""
old_sid = "old-sid"
new_sid = "new-sid"
user_id = str(uuid4())
# Mock: user has existing connection
mock_redis.get.return_value = old_sid
mock_redis.hgetall.return_value = {
"user_id": user_id,
"game_id": "",
"connected_at": datetime.now(UTC).isoformat(),
"last_seen": datetime.now(UTC).isoformat(),
}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
await manager.register_connection(new_sid, user_id)
# Verify old connection was cleaned up (delete called for old conn key)
old_conn_key = f"{CONN_PREFIX}{old_sid}"
delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list]
assert old_conn_key in delete_calls
@pytest.mark.asyncio
async def test_register_connection_accepts_uuid(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that register_connection accepts UUID objects.
The user_id can be passed as either a string or UUID object
for convenience.
"""
sid = "test-sid"
user_uuid = uuid4()
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
await manager.register_connection(sid, user_uuid)
# Verify user_id was converted to string
hset_call = mock_redis.hset.call_args
mapping = hset_call.kwargs["mapping"]
assert mapping["user_id"] == str(user_uuid)
class TestUnregisterConnection:
"""Tests for connection unregistration."""
@pytest.mark.asyncio
async def test_unregister_returns_none_for_unknown_sid(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that unregistering unknown connection returns None.
If the connection doesn't exist, we should return None rather than
raising an error.
"""
mock_redis.hgetall.return_value = {}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.unregister_connection("unknown-sid")
assert result is None
@pytest.mark.asyncio
async def test_unregister_cleans_up_all_data(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that unregistering cleans up all related Redis data.
Cleanup should remove:
1. Connection hash
2. User-to-connection mapping (if it still points to this sid)
3. Game connection set membership
"""
sid = "test-sid"
user_id = "user-123"
game_id = "game-456"
# Mock: connection exists with game
mock_redis.hgetall.return_value = {
"user_id": user_id,
"game_id": game_id,
"connected_at": datetime.now(UTC).isoformat(),
"last_seen": datetime.now(UTC).isoformat(),
}
mock_redis.get.return_value = sid # user mapping points to this sid
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.unregister_connection(sid)
# Verify result
assert result is not None
assert result.sid == sid
assert result.user_id == user_id
assert result.game_id == game_id
# Verify cleanup
mock_redis.srem.assert_called() # Removed from game set
mock_redis.delete.assert_called() # Connection deleted
class TestGameAssociation:
"""Tests for game join/leave operations."""
@pytest.mark.asyncio
async def test_join_game_adds_to_game_set(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that joining a game adds the connection to the game's set.
When a connection joins a game, it should:
1. Update the connection's game_id
2. Add the sid to the game's connection set
"""
sid = "test-sid"
game_id = "game-123"
mock_redis.exists.return_value = True
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.join_game(sid, game_id)
assert result is True
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id)
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
@pytest.mark.asyncio
async def test_join_game_returns_false_for_unknown_connection(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that joining a game fails for unknown connections.
If the connection doesn't exist, we should return False rather
than creating orphan game association data.
"""
mock_redis.exists.return_value = False
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.join_game("unknown-sid", "game-123")
assert result is False
mock_redis.sadd.assert_not_called()
@pytest.mark.asyncio
async def test_join_game_leaves_previous_game(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that joining a new game leaves the previous game.
When switching games, the connection should be removed from the
old game's connection set before joining the new one.
"""
sid = "test-sid"
old_game = "game-old"
new_game = "game-new"
mock_redis.exists.return_value = True
mock_redis.hget.return_value = old_game # Currently in old game
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
await manager.join_game(sid, new_game)
# Verify left old game
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid)
# Verify joined new game
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid)
@pytest.mark.asyncio
async def test_leave_game_removes_from_set(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that leaving a game removes the connection from the set.
Leave should:
1. Remove sid from game's connection set
2. Clear game_id on connection record
"""
sid = "test-sid"
game_id = "game-123"
mock_redis.hget.return_value = game_id
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.leave_game(sid)
assert result == game_id
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")
@pytest.mark.asyncio
async def test_leave_game_returns_none_when_not_in_game(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that leave_game returns None when not in a game.
If the connection isn't associated with any game, we should
return None without making unnecessary Redis calls.
"""
mock_redis.hget.return_value = "" # No game_id
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.leave_game("test-sid")
assert result is None
mock_redis.srem.assert_not_called()
class TestHeartbeat:
"""Tests for heartbeat/activity tracking."""
@pytest.mark.asyncio
async def test_update_heartbeat_refreshes_last_seen(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that heartbeat updates the last_seen timestamp.
Heartbeats keep the connection alive by updating the timestamp
and refreshing TTLs on Redis records.
"""
sid = "test-sid"
mock_redis.exists.return_value = True
mock_redis.hget.return_value = "user-123"
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.update_heartbeat(sid)
assert result is True
# Verify last_seen was updated
hset_call = mock_redis.hset.call_args
assert hset_call.args[0] == f"{CONN_PREFIX}{sid}"
assert hset_call.args[1] == "last_seen"
# Verify TTL was refreshed
mock_redis.expire.assert_called()
@pytest.mark.asyncio
async def test_update_heartbeat_returns_false_for_unknown(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that heartbeat returns False for unknown connections.
If the connection doesn't exist, we shouldn't try to update it.
"""
mock_redis.exists.return_value = False
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.update_heartbeat("unknown-sid")
assert result is False
class TestQueryMethods:
"""Tests for connection query methods."""
@pytest.mark.asyncio
async def test_get_connection_returns_info(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_connection returns ConnectionInfo for valid sid.
The returned ConnectionInfo should have all fields populated
from the Redis hash.
"""
sid = "test-sid"
now = datetime.now(UTC)
mock_redis.hgetall.return_value = {
"user_id": "user-123",
"game_id": "game-456",
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.get_connection(sid)
assert result is not None
assert result.sid == sid
assert result.user_id == "user-123"
assert result.game_id == "game-456"
@pytest.mark.asyncio
async def test_get_connection_returns_none_for_unknown(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_connection returns None for unknown sid."""
mock_redis.hgetall.return_value = {}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.get_connection("unknown-sid")
assert result is None
@pytest.mark.asyncio
async def test_is_user_online_returns_true_for_connected_user(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that is_user_online returns True for connected users.
A user is online if they have an active connection that isn't stale.
"""
user_id = "user-123"
now = datetime.now(UTC)
mock_redis.get.return_value = "test-sid"
mock_redis.hgetall.return_value = {
"user_id": user_id,
"game_id": "",
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.is_user_online(user_id)
assert result is True
@pytest.mark.asyncio
async def test_is_user_online_returns_false_for_stale_connection(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that is_user_online returns False for stale connections.
Even if a connection record exists, if it's stale (no recent heartbeat),
the user should be considered offline.
"""
user_id = "user-123"
old_time = datetime.now(UTC) - timedelta(minutes=5)
mock_redis.get.return_value = "test-sid"
mock_redis.hgetall.return_value = {
"user_id": user_id,
"game_id": "",
"connected_at": old_time.isoformat(),
"last_seen": old_time.isoformat(),
}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.is_user_online(user_id)
assert result is False
@pytest.mark.asyncio
async def test_get_game_connections_returns_all_participants(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_game_connections returns all game participants.
Should return ConnectionInfo for each sid in the game's connection set.
"""
game_id = "game-123"
now = datetime.now(UTC)
mock_redis.smembers.return_value = {"sid-1", "sid-2"}
mock_redis.hgetall.side_effect = [
{
"user_id": "user-1",
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
},
{
"user_id": "user-2",
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
},
]
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.get_game_connections(game_id)
assert len(result) == 2
user_ids = {conn.user_id for conn in result}
assert user_ids == {"user-1", "user-2"}
@pytest.mark.asyncio
async def test_get_opponent_sid_returns_other_player(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_opponent_sid returns the other player's sid.
In a 2-player game, this should return the sid of the player
who is not the current user.
"""
game_id = "game-123"
current_user = "user-1"
opponent_user = "user-2"
opponent_sid = "sid-2"
now = datetime.now(UTC)
mock_redis.smembers.return_value = {"sid-1", "sid-2"}
# Use a function to return the right data based on the key queried
# This avoids dependency on set iteration order
def hgetall_by_key(key: str) -> dict[str, str]:
if key == f"{CONN_PREFIX}sid-1":
return {
"user_id": current_user,
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
elif key == f"{CONN_PREFIX}sid-2":
return {
"user_id": opponent_user,
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
return {}
mock_redis.hgetall.side_effect = hgetall_by_key
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.get_opponent_sid(game_id, current_user)
assert result == opponent_sid
class TestKeyGeneration:
"""Tests for Redis key generation methods."""
def test_conn_key_format(self, manager: ConnectionManager) -> None:
"""Test that connection keys have correct format.
Keys should follow the pattern conn:{sid} for easy identification
and pattern matching.
"""
key = manager._conn_key("test-sid-123")
assert key == "conn:test-sid-123"
def test_user_conn_key_format(self, manager: ConnectionManager) -> None:
"""Test that user connection keys have correct format.
Keys should follow the pattern user_conn:{user_id}.
"""
key = manager._user_conn_key("user-456")
assert key == "user_conn:user-456"
def test_game_conns_key_format(self, manager: ConnectionManager) -> None:
"""Test that game connections keys have correct format.
Keys should follow the pattern game_conns:{game_id}.
"""
key = manager._game_conns_key("game-789")
assert key == "game_conns:game-789"

View File

@ -0,0 +1,809 @@
"""Tests for GameService.
This module tests the game service layer that orchestrates between
WebSocket communication and the core GameEngine.
The GameService is STATELESS regarding game rules:
- No GameEngine is stored in the service
- Engine is created per-operation using rules from GameState
- Rules come from frontend at game creation, stored in GameState
"""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from app.core.engine import ActionResult
from app.core.enums import GameEndReason, TurnPhase
from app.core.models.actions import AttackAction, PassAction, ResignAction
from app.core.models.game_state import GameState, PlayerState
from app.core.win_conditions import WinResult
from app.services.game_service import (
GameAlreadyEndedError,
GameNotFoundError,
GameService,
InvalidActionError,
NotPlayerTurnError,
PlayerNotInGameError,
)
@pytest.fixture
def mock_state_manager() -> AsyncMock:
"""Create a mock GameStateManager.
The state manager handles persistence to Redis (cache) and
Postgres (durable storage).
"""
manager = AsyncMock()
manager.load_state = AsyncMock(return_value=None)
manager.save_to_cache = AsyncMock()
manager.persist_to_db = AsyncMock()
manager.cache_exists = AsyncMock(return_value=False)
return manager
@pytest.fixture
def mock_card_service() -> MagicMock:
"""Create a mock CardService.
CardService provides card definitions for game creation.
"""
return MagicMock()
@pytest.fixture
def game_service(
mock_state_manager: AsyncMock,
mock_card_service: MagicMock,
) -> GameService:
"""Create a GameService with mocked dependencies.
Note: No engine is passed - GameService creates engines per-operation
using rules stored in each game's GameState.
"""
return GameService(
state_manager=mock_state_manager,
card_service=mock_card_service,
)
@pytest.fixture
def sample_game_state() -> GameState:
"""Create a sample game state for testing.
The game state includes two players and basic turn tracking.
The rules are stored in the state itself (default RulesConfig).
"""
player1 = PlayerState(player_id="player-1")
player2 = PlayerState(player_id="player-2")
return GameState(
game_id="game-123",
players={"player-1": player1, "player-2": player2},
current_player_id="player-1",
turn_number=1,
phase=TurnPhase.MAIN,
)
class TestGameStateAccess:
"""Tests for game state access methods."""
@pytest.mark.asyncio
async def test_get_game_state_returns_state(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that get_game_state returns the game state when found.
The state manager should be called to load the state, and the
result should be returned to the caller.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.get_game_state("game-123")
assert result == sample_game_state
mock_state_manager.load_state.assert_called_once_with("game-123")
@pytest.mark.asyncio
async def test_get_game_state_raises_not_found(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
) -> None:
"""Test that get_game_state raises GameNotFoundError when not found.
When the state manager returns None, we should raise a specific
exception rather than returning None.
"""
mock_state_manager.load_state.return_value = None
with pytest.raises(GameNotFoundError) as exc_info:
await game_service.get_game_state("nonexistent")
assert exc_info.value.game_id == "nonexistent"
@pytest.mark.asyncio
async def test_get_player_view_returns_visible_state(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that get_player_view returns visibility-filtered state.
The returned state should be filtered for the requesting player,
hiding the opponent's private information.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.get_player_view("game-123", "player-1")
assert result.game_id == "game-123"
assert result.viewer_id == "player-1"
assert result.is_my_turn is True
@pytest.mark.asyncio
async def test_get_player_view_raises_not_in_game(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that get_player_view raises error for non-participants.
Only players in the game should be able to view its state.
"""
mock_state_manager.load_state.return_value = sample_game_state
with pytest.raises(PlayerNotInGameError) as exc_info:
await game_service.get_player_view("game-123", "stranger")
assert exc_info.value.player_id == "stranger"
@pytest.mark.asyncio
async def test_is_player_turn_returns_true_for_current_player(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that is_player_turn returns True for the current player.
The current player is determined by the game state's current_player_id.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.is_player_turn("game-123", "player-1")
assert result is True
@pytest.mark.asyncio
async def test_is_player_turn_returns_false_for_other_player(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that is_player_turn returns False for non-current player.
Players who are not the current player should get False.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.is_player_turn("game-123", "player-2")
assert result is False
@pytest.mark.asyncio
async def test_game_exists_returns_true_when_cached(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
) -> None:
"""Test that game_exists returns True when game is in cache.
This is a quick check that doesn't need to load the full state.
"""
mock_state_manager.cache_exists.return_value = True
result = await game_service.game_exists("game-123")
assert result is True
mock_state_manager.cache_exists.assert_called_once_with("game-123")
class TestJoinGame:
"""Tests for the join_game method."""
@pytest.mark.asyncio
async def test_join_game_success(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test successful game join returns visible state.
When a player joins their game, they should receive the
visibility-filtered state and know if it's their turn.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.join_game("game-123", "player-1")
assert result.success is True
assert result.game_id == "game-123"
assert result.player_id == "player-1"
assert result.visible_state is not None
assert result.is_your_turn is True
@pytest.mark.asyncio
async def test_join_game_not_your_turn(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that is_your_turn is False when it's opponent's turn.
The second player should see is_your_turn=False when the first
player is the current player.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.join_game("game-123", "player-2")
assert result.success is True
assert result.is_your_turn is False
@pytest.mark.asyncio
async def test_join_game_not_found(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
) -> None:
"""Test join_game returns failure for non-existent game.
Rather than raising an exception, join_game returns a failed
result for better WebSocket error handling.
"""
mock_state_manager.load_state.return_value = None
result = await game_service.join_game("nonexistent", "player-1")
assert result.success is False
assert "not found" in result.message.lower()
@pytest.mark.asyncio
async def test_join_game_not_participant(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test join_game returns failure for non-participants.
Players who are not in the game should not be able to join.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.join_game("game-123", "stranger")
assert result.success is False
assert "not a participant" in result.message.lower()
@pytest.mark.asyncio
async def test_join_ended_game(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test joining an ended game still succeeds but indicates game over.
Players should be able to rejoin ended games to see the final
state, but is_your_turn should be False.
"""
sample_game_state.winner_id = "player-1"
sample_game_state.end_reason = GameEndReason.PRIZES_TAKEN
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.join_game("game-123", "player-2")
assert result.success is True
assert result.is_your_turn is False
assert "ended" in result.message.lower()
class TestExecuteAction:
"""Tests for the execute_action method.
These tests verify action execution through GameService. Since GameService
creates engines per-operation, we patch _create_engine_for_game to return
a mock engine with controlled behavior.
"""
@pytest.mark.asyncio
async def test_execute_action_success(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test successful action execution.
A valid action by the current player should be executed and
the state should be saved to cache.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=True,
message="Attack executed",
state_changes=[{"type": "damage", "amount": 30}],
)
)
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
action = AttackAction(attack_index=0)
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "attack"
mock_state_manager.save_to_cache.assert_called_once()
@pytest.mark.asyncio
async def test_execute_action_game_not_found(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
) -> None:
"""Test execute_action raises error when game not found.
Missing games should raise GameNotFoundError for proper
error handling in the WebSocket layer.
"""
mock_state_manager.load_state.return_value = None
with pytest.raises(GameNotFoundError):
await game_service.execute_action("nonexistent", "player-1", PassAction())
@pytest.mark.asyncio
async def test_execute_action_not_in_game(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test execute_action raises error for non-participants.
Only players in the game can execute actions.
"""
mock_state_manager.load_state.return_value = sample_game_state
with pytest.raises(PlayerNotInGameError):
await game_service.execute_action("game-123", "stranger", PassAction())
@pytest.mark.asyncio
async def test_execute_action_game_ended(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test execute_action raises error on ended games.
No actions should be allowed once a game has ended.
"""
sample_game_state.winner_id = "player-1"
sample_game_state.end_reason = GameEndReason.RESIGNATION
mock_state_manager.load_state.return_value = sample_game_state
with pytest.raises(GameAlreadyEndedError):
await game_service.execute_action("game-123", "player-1", PassAction())
@pytest.mark.asyncio
async def test_execute_action_not_your_turn(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test execute_action raises error when not player's turn.
Only the current player can execute actions.
"""
mock_state_manager.load_state.return_value = sample_game_state
with pytest.raises(NotPlayerTurnError) as exc_info:
await game_service.execute_action("game-123", "player-2", PassAction())
assert exc_info.value.player_id == "player-2"
assert exc_info.value.current_player_id == "player-1"
@pytest.mark.asyncio
async def test_execute_action_resign_allowed_out_of_turn(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that resignation is allowed even when not your turn.
Resignation is a special action that can be executed anytime
by either player.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=True,
message="Player resigned",
win_result=WinResult(
winner_id="player-1",
loser_id="player-2",
end_reason=GameEndReason.RESIGNATION,
reason="Player resigned",
),
)
)
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
# player-2 resigns even though it's player-1's turn
result = await game_service.execute_action("game-123", "player-2", ResignAction())
assert result.success is True
assert result.game_over is True
assert result.winner_id == "player-1"
@pytest.mark.asyncio
async def test_execute_action_invalid_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test execute_action raises error for invalid actions.
When the GameEngine rejects an action, we should raise
InvalidActionError with the reason.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=False,
message="Not enough energy to attack",
)
)
with (
patch.object(game_service, "_create_engine_for_game", return_value=mock_engine),
pytest.raises(InvalidActionError) as exc_info,
):
await game_service.execute_action("game-123", "player-1", AttackAction(attack_index=0))
assert "Not enough energy" in exc_info.value.reason
@pytest.mark.asyncio
async def test_execute_action_game_over(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test execute_action detects game over and persists to DB.
When an action results in a win, the state should be persisted
to the database for durability.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=True,
message="Final prize taken!",
win_result=WinResult(
winner_id="player-1",
loser_id="player-2",
end_reason=GameEndReason.PRIZES_TAKEN,
reason="All prizes taken",
),
)
)
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
result = await game_service.execute_action(
"game-123", "player-1", AttackAction(attack_index=0)
)
assert result.game_over is True
assert result.winner_id == "player-1"
assert result.end_reason == GameEndReason.PRIZES_TAKEN
mock_state_manager.persist_to_db.assert_called_once()
@pytest.mark.asyncio
async def test_execute_action_uses_game_rules(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that execute_action creates engine with game's rules.
The engine should be created using the rules stored in the game
state, not any service-level defaults.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(success=True, message="OK")
)
with patch.object(
game_service, "_create_engine_for_game", return_value=mock_engine
) as mock_create:
await game_service.execute_action("game-123", "player-1", PassAction())
# Verify engine was created with the game state
mock_create.assert_called_once_with(sample_game_state)
class TestResignGame:
"""Tests for the resign_game convenience method."""
@pytest.mark.asyncio
async def test_resign_game_executes_resign_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that resign_game is a convenience wrapper for execute_action.
The resign_game method should internally create a ResignAction
and call execute_action.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=True,
message="Player resigned",
win_result=WinResult(
winner_id="player-2",
loser_id="player-1",
end_reason=GameEndReason.RESIGNATION,
reason="Player resigned",
),
)
)
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
result = await game_service.resign_game("game-123", "player-1")
assert result.success is True
assert result.action_type == "resign"
assert result.game_over is True
class TestEndGame:
"""Tests for the end_game method (forced ending)."""
@pytest.mark.asyncio
async def test_end_game_sets_winner_and_reason(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that end_game sets winner and end reason.
Used for timeout or disconnection scenarios where the game
needs to be forcibly ended.
"""
mock_state_manager.load_state.return_value = sample_game_state
await game_service.end_game(
"game-123",
winner_id="player-1",
end_reason=GameEndReason.TIMEOUT,
)
# Verify state was modified
assert sample_game_state.winner_id == "player-1"
assert sample_game_state.end_reason == GameEndReason.TIMEOUT
# Verify persistence
mock_state_manager.save_to_cache.assert_called_once()
mock_state_manager.persist_to_db.assert_called_once()
@pytest.mark.asyncio
async def test_end_game_draw(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test ending a game as a draw (no winner).
Some scenarios (timeout with equal scores) can result in a draw.
"""
mock_state_manager.load_state.return_value = sample_game_state
await game_service.end_game(
"game-123",
winner_id=None,
end_reason=GameEndReason.DRAW,
)
assert sample_game_state.winner_id is None
assert sample_game_state.end_reason == GameEndReason.DRAW
@pytest.mark.asyncio
async def test_end_game_not_found(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
) -> None:
"""Test end_game raises error when game not found."""
mock_state_manager.load_state.return_value = None
with pytest.raises(GameNotFoundError):
await game_service.end_game(
"nonexistent",
winner_id="player-1",
end_reason=GameEndReason.TIMEOUT,
)
class TestCreateGame:
"""Tests for the create_game method (skeleton)."""
@pytest.mark.asyncio
async def test_create_game_raises_not_implemented(
self,
game_service: GameService,
) -> None:
"""Test that create_game raises NotImplementedError.
The full implementation will be done in GS-002. For now,
it should raise NotImplementedError with a clear message.
"""
with pytest.raises(NotImplementedError) as exc_info:
await game_service.create_game(
player1_id=str(uuid4()),
player2_id=str(uuid4()),
)
assert "GS-002" in str(exc_info.value)
class TestCreateEngineForGame:
"""Tests for the _create_engine_for_game method.
This method is responsible for creating a GameEngine configured
with the rules from a specific game's state.
"""
def test_create_engine_uses_game_rules(
self,
game_service: GameService,
sample_game_state: GameState,
) -> None:
"""Test that engine is created with the game's rules.
The engine should use the RulesConfig stored in the game state,
not any default configuration.
"""
engine = game_service._create_engine_for_game(sample_game_state)
# Engine should have the game's rules
assert engine.rules == sample_game_state.rules
def test_create_engine_with_rng_seed(
self,
game_service: GameService,
sample_game_state: GameState,
) -> None:
"""Test that engine uses seeded RNG when game has rng_seed.
When a game has an rng_seed set, the engine should use a
deterministic RNG for replay support.
"""
sample_game_state.rng_seed = 12345
engine = game_service._create_engine_for_game(sample_game_state)
# Engine should have been created (we can't easily verify seed,
# but we can verify it doesn't error)
assert engine is not None
def test_create_engine_without_rng_seed(
self,
game_service: GameService,
sample_game_state: GameState,
) -> None:
"""Test that engine uses secure RNG when no seed is set.
Without an rng_seed, the engine should use cryptographically
secure random number generation.
"""
sample_game_state.rng_seed = None
engine = game_service._create_engine_for_game(sample_game_state)
assert engine is not None
def test_create_engine_derives_unique_seed_per_action(
self,
game_service: GameService,
sample_game_state: GameState,
) -> None:
"""Test that different action counts produce different RNG sequences.
For deterministic replay, each action needs a unique but
reproducible RNG seed based on game seed + action count.
"""
sample_game_state.rng_seed = 12345
# Simulate first action (action_log is empty)
sample_game_state.action_log = []
engine1 = game_service._create_engine_for_game(sample_game_state)
# Simulate second action (one action in log)
sample_game_state.action_log = [{"type": "pass"}]
engine2 = game_service._create_engine_for_game(sample_game_state)
# Both engines should be created successfully
# (They will have different seeds due to action count)
assert engine1 is not None
assert engine2 is not None
class TestExceptionMessages:
"""Tests for exception message formatting."""
def test_game_not_found_error_message(self) -> None:
"""Test GameNotFoundError has descriptive message."""
error = GameNotFoundError("game-123")
assert "game-123" in str(error)
assert error.game_id == "game-123"
def test_not_player_turn_error_message(self) -> None:
"""Test NotPlayerTurnError has descriptive message."""
error = NotPlayerTurnError("game-123", "player-2", "player-1")
assert "player-2" in str(error)
assert "player-1" in str(error)
assert error.game_id == "game-123"
def test_invalid_action_error_message(self) -> None:
"""Test InvalidActionError has descriptive message."""
error = InvalidActionError("game-123", "player-1", "Not enough energy")
assert "Not enough energy" in str(error)
assert error.reason == "Not enough energy"
def test_player_not_in_game_error_message(self) -> None:
"""Test PlayerNotInGameError has descriptive message."""
error = PlayerNotInGameError("game-123", "stranger")
assert "stranger" in str(error)
assert "game-123" in str(error)
def test_game_already_ended_error_message(self) -> None:
"""Test GameAlreadyEndedError has descriptive message."""
error = GameAlreadyEndedError("game-123")
assert "game-123" in str(error)
assert "ended" in str(error).lower()