Complete Phase 4 implementation files

- TurnTimeoutService with percentage-based warnings (35 tests)
- ConnectionManager enhancements for spectators and reconnection
- GameService with timer integration, spectator support, handle_timeout
- GameNamespace with spectate/leave_spectate handlers, reconnection
- WebSocket message schemas for spectator events
- WinConditionsConfig additions for turn timer thresholds
- 83 GameService tests, 37 ConnectionManager tests, 37 GameNamespace tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-01-30 08:03:43 -06:00
parent 6f871d7187
commit f452e69999
11 changed files with 3820 additions and 7 deletions

View File

@ -180,6 +180,9 @@ class WinConditionsConfig(BaseModel):
turn counts as one turn (so 30 = 15 turns per player).
turn_timer_enabled: Enable per-turn time limits (multiplayer).
turn_timer_seconds: Seconds per turn before timeout (default 90).
turn_timer_warning_thresholds: Percentage of time remaining to send warnings.
Default [50, 25] means warnings at 50% and 25% remaining.
turn_timer_grace_seconds: Extra seconds granted on reconnection.
game_timer_enabled: Enable total game time limit (multiplayer).
game_timer_minutes: Total game time in minutes.
"""
@ -191,6 +194,8 @@ class WinConditionsConfig(BaseModel):
turn_limit: int = 30
turn_timer_enabled: bool = False
turn_timer_seconds: int = 90
turn_timer_warning_thresholds: list[int] = Field(default_factory=lambda: [50, 25])
turn_timer_grace_seconds: int = 15
game_timer_enabled: bool = False
game_timer_minutes: int = 30

View File

@ -231,6 +231,7 @@ class GameStateMessage(BaseServerMessage):
game_id: The game this state is for.
state: The full visible game state.
event_id: Monotonic event ID for reconnection replay.
spectator_count: Number of users spectating this game.
"""
type: Literal["game_state"] = "game_state"
@ -240,6 +241,11 @@ class GameStateMessage(BaseServerMessage):
default_factory=_generate_message_id,
description="Event ID for reconnection replay",
)
spectator_count: int = Field(
default=0,
description="Number of users spectating this game",
ge=0,
)
class ActionResultMessage(BaseServerMessage):

View File

@ -55,6 +55,7 @@ RedisFactory = Callable[[], AsyncIterator["Redis"]]
CONN_PREFIX = "conn:"
USER_CONN_PREFIX = "user_conn:"
GAME_CONNS_PREFIX = "game_conns:"
SPECTATORS_PREFIX = "spectators:"
# Connection TTL (auto-expire stale connections)
DEFAULT_CONN_TTL_SECONDS = 3600 # 1 hour
@ -133,6 +134,10 @@ class ConnectionManager:
"""Generate Redis key for game connection set."""
return f"{GAME_CONNS_PREFIX}{game_id}"
def _spectators_key(self, game_id: str) -> str:
"""Generate Redis key for game spectators set."""
return f"{SPECTATORS_PREFIX}{game_id}"
# =========================================================================
# Connection Lifecycle
# =========================================================================
@ -234,8 +239,13 @@ class ConnectionManager:
if not user_id:
logger.warning(f"Connection {sid} has no user_id - data may be corrupted")
# Remove from game connection set if in a game
if game_id:
# Check if spectating (game_id format: "spectating:{actual_game_id}")
if game_id and game_id.startswith("spectating:"):
actual_game_id = game_id[len("spectating:") :]
spectators_key = self._spectators_key(actual_game_id)
await redis.srem(spectators_key, sid)
elif game_id:
# Remove from game connection set if in a game as player
game_conns_key = self._game_conns_key(game_id)
await redis.srem(game_conns_key, sid)
@ -538,6 +548,28 @@ class ConnectionManager:
return None
async def get_user_active_game(self, user_id: str | UUID) -> str | None:
"""Get the game ID of a user's active game from their connection.
This checks the ConnectionManager's tracking, NOT the database.
Used during reconnection to find if the user was in a game.
Args:
user_id: User's UUID or string ID.
Returns:
Game ID if user was in a game, None otherwise.
Example:
game_id = await manager.get_user_active_game(user_id)
if game_id:
# User was in a game, auto-rejoin
"""
conn_info = await self.get_user_connection(user_id)
if conn_info is None:
return None
return conn_info.game_id
# =========================================================================
# Maintenance
# =========================================================================
@ -605,6 +637,134 @@ class ConnectionManager:
game_conns_key = self._game_conns_key(game_id)
return await redis.scard(game_conns_key)
# =========================================================================
# Spectator Management
# =========================================================================
async def register_spectator(
self,
sid: str,
user_id: str | UUID,
game_id: str,
) -> bool:
"""Register a connection as a spectator for a game.
Adds the sid to the spectators set for the game and updates
the connection record with spectator status.
Args:
sid: Socket.IO session ID.
user_id: User's ID (for logging).
game_id: Game to spectate.
Returns:
True if successful, False if connection not found.
Example:
await manager.register_spectator("abc123", user_id, "game-456")
"""
user_id_str = str(user_id)
async with self._get_redis() as redis:
conn_key = self._conn_key(sid)
# Check connection exists
exists = await redis.exists(conn_key)
if not exists:
logger.warning(f"Cannot register spectator: connection not found {sid}")
return False
# Update connection's game_id to indicate spectating
# We use a special format to distinguish from playing
await redis.hset(conn_key, "game_id", f"spectating:{game_id}")
# Add to spectators set
spectators_key = self._spectators_key(game_id)
await redis.sadd(spectators_key, sid)
# Set TTL on spectators set
await redis.expire(spectators_key, self.conn_ttl_seconds)
logger.debug(f"User {user_id_str} ({sid}) spectating game {game_id}")
return True
async def unregister_spectator(self, sid: str, game_id: str) -> bool:
"""Remove a connection from a game's spectator list.
Args:
sid: Socket.IO session ID.
game_id: Game being spectated.
Returns:
True if removed, False if not in spectator list.
Example:
await manager.unregister_spectator("abc123", "game-456")
"""
async with self._get_redis() as redis:
spectators_key = self._spectators_key(game_id)
# Remove from spectators set
removed = await redis.srem(spectators_key, sid)
# Clear game_id on connection if it was spectating this game
conn_key = self._conn_key(sid)
current_game = await redis.hget(conn_key, "game_id")
if current_game == f"spectating:{game_id}":
await redis.hset(conn_key, "game_id", "")
if removed:
logger.debug(f"Connection {sid} stopped spectating game {game_id}")
return bool(removed)
async def get_spectator_count(self, game_id: str) -> int:
"""Get the number of spectators for a game.
Args:
game_id: Game ID.
Returns:
Number of spectators.
Example:
count = await manager.get_spectator_count("game-456")
"""
async with self._get_redis() as redis:
spectators_key = self._spectators_key(game_id)
return await redis.scard(spectators_key)
async def get_game_spectators(self, game_id: str) -> list[str]:
"""Get all spectator sids for a game.
Args:
game_id: Game ID.
Returns:
List of spectator sids.
Example:
sids = await manager.get_game_spectators("game-456")
"""
async with self._get_redis() as redis:
spectators_key = self._spectators_key(game_id)
sids = await redis.smembers(spectators_key)
return list(sids)
async def is_spectating(self, sid: str, game_id: str) -> bool:
"""Check if a connection is spectating a specific game.
Args:
sid: Socket.IO session ID.
game_id: Game ID.
Returns:
True if spectating, False otherwise.
"""
async with self._get_redis() as redis:
spectators_key = self._spectators_key(game_id)
return await redis.sismember(spectators_key, sid)
# Global singleton instance
connection_manager = ConnectionManager()

View File

@ -49,15 +49,16 @@ if TYPE_CHECKING:
from app.core.config import RulesConfig
from app.core.engine import ActionResult, GameCreationResult, GameEngine
from app.core.enums import GameEndReason
from app.core.enums import GameEndReason, TurnPhase
from app.core.models.actions import Action, ResignAction
from app.core.models.card import CardInstance
from app.core.models.game_state import GameState
from app.core.rng import create_rng
from app.core.visibility import VisibleGameState, get_visible_state
from app.core.visibility import VisibleGameState, get_spectator_state, get_visible_state
from app.db.models.game import EndReason, GameType
from app.services.card_service import CardService, get_card_service
from app.services.game_state_manager import GameStateManager, game_state_manager
from app.services.turn_timeout_service import TurnTimeoutService, turn_timeout_service
logger = logging.getLogger(__name__)
@ -189,6 +190,15 @@ class ForcedActionRequiredError(GameServiceError):
)
class CannotSpectateOwnGameError(GameServiceError):
"""Raised when a player tries to spectate a game they are participating in."""
def __init__(self, game_id: str, player_id: str) -> None:
self.game_id = game_id
self.player_id = player_id
super().__init__(f"Cannot spectate your own game: {game_id}")
# =============================================================================
# Result Types
# =============================================================================
@ -227,6 +237,8 @@ class GameActionResult:
turn_changed: Whether the turn changed as a result of this action.
current_player_id: The current player after action execution.
pending_forced_action: If set, the next action must be this forced action.
turn_timeout_seconds: Seconds remaining on turn timer (None if disabled).
turn_deadline: Unix timestamp when current turn expires (None if disabled).
"""
success: bool
@ -240,6 +252,8 @@ class GameActionResult:
turn_changed: bool = False
current_player_id: str | None = None
pending_forced_action: PendingForcedAction | None = None
turn_timeout_seconds: int | None = None
turn_deadline: float | None = None
@dataclass
@ -258,6 +272,8 @@ class GameJoinResult:
game_over: Whether the game has already ended.
pending_forced_action: If set, this action must be taken before any other.
message: Additional information or error message.
turn_timeout_seconds: Seconds remaining on turn timer (None if disabled).
turn_deadline: Unix timestamp when current turn expires (None if disabled).
"""
success: bool
@ -268,6 +284,8 @@ class GameJoinResult:
game_over: bool = False
pending_forced_action: PendingForcedAction | None = None
message: str = ""
turn_timeout_seconds: int | None = None
turn_deadline: float | None = None
@dataclass
@ -325,6 +343,25 @@ class GameEndResult:
message: str = ""
@dataclass
class SpectateResult:
"""Result of spectating a game.
Attributes:
success: Whether spectating succeeded.
game_id: The game being spectated.
visible_state: Spectator-filtered game state (no hands visible).
game_over: Whether the game has already ended.
message: Additional information or error message.
"""
success: bool
game_id: str
visible_state: VisibleGameState | None = None
game_over: bool = False
message: str = ""
# =============================================================================
# GameService
# =============================================================================
@ -344,6 +381,7 @@ class GameService:
Attributes:
_state_manager: GameStateManager for persistence.
_card_service: CardService for card definitions.
_timeout_service: TurnTimeoutService for turn timer management.
_engine_factory: Factory for creating GameEngine for action execution.
_creation_engine_factory: Factory for creating GameEngine for game creation.
"""
@ -352,6 +390,7 @@ class GameService:
self,
state_manager: GameStateManager | None = None,
card_service: CardService | None = None,
timeout_service: TurnTimeoutService | None = None,
engine_factory: EngineFactory | None = None,
creation_engine_factory: CreationEngineFactory | None = None,
) -> None:
@ -364,6 +403,7 @@ class GameService:
Args:
state_manager: GameStateManager instance. Uses global if not provided.
card_service: CardService instance. Uses global if not provided.
timeout_service: TurnTimeoutService instance. Uses global if not provided.
engine_factory: Optional factory for creating GameEngine for action
execution. Takes GameState, returns GameEngine. If not provided,
uses the default _default_engine_factory method.
@ -373,6 +413,7 @@ class GameService:
"""
self._state_manager = state_manager or game_state_manager
self._card_service = card_service or get_card_service()
self._timeout_service = timeout_service or turn_timeout_service
self._engine_factory = engine_factory or self._default_engine_factory
self._creation_engine_factory = (
creation_engine_factory or self._default_creation_engine_factory
@ -575,6 +616,29 @@ class GameService:
if forced is not None and forced.player_id == player_id:
is_turn = True
# Handle turn timer for reconnection
turn_timeout_seconds: int | None = None
turn_deadline: float | None = None
if state.rules.win_conditions.turn_timer_enabled:
# Check if there's an active timer
timeout_info = await self._timeout_service.get_timeout_info(game_id)
if timeout_info is not None:
# Timer exists - extend if this is the current player reconnecting
if is_turn and timeout_info.player_id == player_id:
grace_seconds = state.rules.win_conditions.turn_timer_grace_seconds
extended_info = await self._timeout_service.extend_timer(game_id, grace_seconds)
if extended_info is not None:
timeout_info = extended_info
logger.debug(
f"Extended turn timer on reconnect: game={game_id}, "
f"player={player_id}, grace={grace_seconds}s"
)
turn_timeout_seconds = timeout_info.remaining_seconds
turn_deadline = timeout_info.deadline
logger.info(f"Player {player_id} joined game {game_id}")
return GameJoinResult(
@ -585,6 +649,50 @@ class GameService:
is_your_turn=is_turn,
game_over=False,
pending_forced_action=pending_forced,
turn_timeout_seconds=turn_timeout_seconds,
turn_deadline=turn_deadline,
)
async def spectate_game(
self,
game_id: str,
user_id: str,
) -> SpectateResult:
"""Get spectator view of a game.
Returns a visibility-filtered game state suitable for spectators.
Spectators cannot see any player's hand, deck, or prizes.
Args:
game_id: The game to spectate.
user_id: The user wanting to spectate.
Returns:
SpectateResult with the spectator-visible state.
Raises:
GameNotFoundError: If game doesn't exist.
CannotSpectateOwnGameError: If user is a participant in the game.
"""
state = await self.get_game_state(game_id)
# Players cannot spectate their own game
if user_id in state.players:
raise CannotSpectateOwnGameError(game_id, user_id)
visible = get_spectator_state(state)
# Check if game already ended
game_over = state.winner_id is not None or state.end_reason is not None
logger.info(f"User {user_id} spectating game {game_id}")
return SpectateResult(
success=True,
game_id=game_id,
visible_state=visible,
game_over=game_over,
message="Spectating game" if not game_over else "Game has ended",
)
async def execute_action(
@ -650,6 +758,7 @@ class GameService:
# Track turn state before action for boundary detection
turn_before = state.turn_number
player_before = state.current_player_id
phase_before = state.phase
# Create engine with this game's rules via factory
engine = self._engine_factory(state)
@ -700,6 +809,9 @@ class GameService:
action_result.winner_id = result.win_result.winner_id
action_result.end_reason = result.win_result.end_reason
# Cancel turn timer on game over
await self._timeout_service.cancel_timer(game_id)
# Persist final state to DB
await self._state_manager.persist_to_db(state)
@ -707,6 +819,27 @@ class GameService:
f"Game {game_id} ended: winner={result.win_result.winner_id}, "
f"reason={result.win_result.end_reason}"
)
elif state.rules.win_conditions.turn_timer_enabled:
# Determine if we should start the turn timer:
# 1. Turn changed (player switched turns)
# 2. SETUP phase just ended (first real turn began)
setup_ended = phase_before == TurnPhase.SETUP and state.phase != TurnPhase.SETUP
should_start_timer = turn_changed or setup_ended
if should_start_timer:
timeout_info = await self._timeout_service.start_turn_timer(
game_id=game_id,
player_id=state.current_player_id,
timeout_seconds=state.rules.win_conditions.turn_timer_seconds,
warning_thresholds=state.rules.win_conditions.turn_timer_warning_thresholds,
)
action_result.turn_timeout_seconds = timeout_info.remaining_seconds
action_result.turn_deadline = timeout_info.deadline
logger.debug(
f"Started turn timer: game={game_id}, player={state.current_player_id}, "
f"timeout={timeout_info.timeout_seconds}s, "
f"reason={'setup_ended' if setup_ended else 'turn_changed'}"
)
logger.debug(f"Action executed: game={game_id}, player={player_id}, type={action.type}")
@ -734,6 +867,50 @@ class GameService:
action=ResignAction(),
)
async def handle_timeout(
self,
game_id: str,
timed_out_player_id: str,
) -> GameEndResult:
"""Handle a turn timeout.
Called by the background timeout polling task when a player's
turn timer expires. Declares the timed-out player as the loser.
Future enhancement: Could implement auto-pass for first timeout,
loss only after N consecutive timeouts.
Args:
game_id: The game ID.
timed_out_player_id: The player who timed out.
Returns:
GameEndResult with timeout as the end reason.
Raises:
GameNotFoundError: If game doesn't exist.
"""
state = await self.get_game_state(game_id)
# Determine winner (the opponent)
player_ids = list(state.players.keys())
winner_id: str | None = None
for pid in player_ids:
if pid != timed_out_player_id:
winner_id = pid
break
logger.info(
f"Turn timeout: game={game_id}, timed_out_player={timed_out_player_id}, "
f"winner={winner_id}"
)
return await self.end_game(
game_id=game_id,
winner_id=winner_id,
end_reason=GameEndReason.TIMEOUT,
)
async def end_game(
self,
game_id: str,
@ -766,6 +943,9 @@ class GameService:
"""
state = await self.get_game_state(game_id)
# Cancel any active turn timer
await self._timeout_service.cancel_timer(game_id)
# Set winner and end reason on game state
state.winner_id = winner_id
state.end_reason = end_reason
@ -1015,6 +1195,10 @@ class GameService:
logger.error(f"Failed to persist game state: {e}")
raise GameCreationError(f"Failed to persist game: {e}") from e
# NOTE: Turn timer is NOT started here during SETUP phase.
# Timer starts when SETUP completes (both players select basic pokemon)
# and the first real turn begins. See execute_action() for timer start logic.
# Get player-visible views
player1_view = get_visible_state(game, p1_str)
player2_view = get_visible_state(game, p2_str)

View File

@ -0,0 +1,550 @@
"""Turn timeout management for Mantimon TCG.
This module manages turn time limits using a polling-based approach with
Redis for state storage. It handles:
- Starting and canceling turn timers
- Checking for expired timers
- Sending percentage-based warnings (e.g., 50%, 25% remaining)
- Granting grace periods on reconnection
Key Patterns:
turn_timeout:{game_id} - Hash with timeout data:
- player_id: The player whose turn it is
- deadline: Unix timestamp when the turn expires
- timeout_seconds: Original timeout duration (for % calculation)
- warnings_sent: JSON array of thresholds already sent
Design Decisions:
- Polling approach (not keyspace notifications) for simplicity and
reliability. A background task polls check_expired_timers() periodically.
- Warnings are percentage-based (configurable, default 50% and 25%)
so they scale with different timeout durations.
- Grace period on reconnect extends the deadline without resetting
the warning state.
Example:
from app.services.turn_timeout_service import turn_timeout_service
# Start a timer when a turn begins
await turn_timeout_service.start_turn_timer(
game_id="game-123",
player_id="player-1",
timeout_seconds=180,
warning_thresholds=[50, 25],
)
# Check for warnings to send
warning = await turn_timeout_service.get_pending_warning("game-123")
if warning:
# Send warning to player
await turn_timeout_service.mark_warning_sent("game-123", warning.threshold)
# On reconnect, grant grace period
await turn_timeout_service.extend_timer("game-123", extension_seconds=15)
# Background task polls for expired timers
expired = await turn_timeout_service.check_expired_timers()
for game_id in expired:
# Handle timeout (auto-pass or loss)
"""
import json
import logging
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from app.db.redis import get_redis
if TYPE_CHECKING:
from redis.asyncio import Redis
logger = logging.getLogger(__name__)
# Type alias for redis factory
RedisFactory = Callable[[], AsyncIterator["Redis"]]
# Redis key patterns
TURN_TIMEOUT_PREFIX = "turn_timeout:"
# Default TTL for timeout keys (cleanup buffer beyond actual timeout)
DEFAULT_KEY_TTL_BUFFER = 300 # 5 minutes beyond deadline
@dataclass
class TurnTimeoutInfo:
"""Information about a turn timeout.
Attributes:
game_id: The game ID.
player_id: The player whose turn is timing out.
deadline: Unix timestamp when the turn expires.
timeout_seconds: Original timeout duration.
remaining_seconds: Seconds remaining until timeout.
warnings_sent: List of warning thresholds already sent.
warning_thresholds: Configured warning thresholds.
"""
game_id: str
player_id: str
deadline: float
timeout_seconds: int
remaining_seconds: int
warnings_sent: list[int]
warning_thresholds: list[int]
@property
def is_expired(self) -> bool:
"""Check if the timeout has expired."""
return self.remaining_seconds <= 0
@property
def percent_remaining(self) -> float:
"""Get the percentage of time remaining."""
if self.timeout_seconds <= 0:
return 0.0
return (self.remaining_seconds / self.timeout_seconds) * 100
@dataclass
class PendingWarning:
"""Information about a warning that should be sent.
Attributes:
game_id: The game ID.
player_id: The player to warn.
threshold: The warning threshold percentage (e.g., 50, 25).
remaining_seconds: Seconds remaining when warning triggered.
"""
game_id: str
player_id: str
threshold: int
remaining_seconds: int
class TurnTimeoutService:
"""Service for managing turn timeouts.
Uses Redis for persistent storage of timeout state. Designed for a
polling model where a background task periodically checks for expired
timers and pending warnings.
Attributes:
_get_redis: Factory for Redis connections.
"""
def __init__(
self,
redis_factory: RedisFactory | None = None,
) -> None:
"""Initialize the TurnTimeoutService.
Args:
redis_factory: Optional factory for Redis connections. If not provided,
uses the default get_redis from app.db.redis. Useful for testing.
"""
self._get_redis = redis_factory if redis_factory is not None else get_redis
def _timeout_key(self, game_id: str) -> str:
"""Generate Redis key for a game's timeout data."""
return f"{TURN_TIMEOUT_PREFIX}{game_id}"
# =========================================================================
# Timer Lifecycle
# =========================================================================
async def start_turn_timer(
self,
game_id: str,
player_id: str,
timeout_seconds: int,
warning_thresholds: list[int] | None = None,
) -> TurnTimeoutInfo:
"""Start a turn timer for a game.
Creates or replaces the timeout data for the game. Resets warnings
since this is a new turn.
Args:
game_id: The game ID.
player_id: The player whose turn is starting.
timeout_seconds: Seconds until the turn times out.
warning_thresholds: Percentage thresholds for warnings (e.g., [50, 25]).
Defaults to [50, 25] if not provided.
Returns:
TurnTimeoutInfo with the new timer state.
Example:
info = await service.start_turn_timer("game-123", "player-1", 180)
print(f"Turn expires at {info.deadline}")
"""
if warning_thresholds is None:
warning_thresholds = [50, 25]
# Sort thresholds descending so we warn at highest % first
warning_thresholds = sorted(warning_thresholds, reverse=True)
deadline = datetime.now(UTC).timestamp() + timeout_seconds
async with self._get_redis() as redis:
key = self._timeout_key(game_id)
await redis.hset(
key,
mapping={
"player_id": player_id,
"deadline": str(deadline),
"timeout_seconds": str(timeout_seconds),
"warnings_sent": json.dumps([]),
"warning_thresholds": json.dumps(warning_thresholds),
},
)
# Set TTL slightly beyond deadline for auto-cleanup
ttl = timeout_seconds + DEFAULT_KEY_TTL_BUFFER
await redis.expire(key, ttl)
logger.debug(
f"Started turn timer: game={game_id}, player={player_id}, "
f"timeout={timeout_seconds}s, thresholds={warning_thresholds}"
)
return TurnTimeoutInfo(
game_id=game_id,
player_id=player_id,
deadline=deadline,
timeout_seconds=timeout_seconds,
remaining_seconds=timeout_seconds,
warnings_sent=[],
warning_thresholds=warning_thresholds,
)
async def cancel_timer(self, game_id: str) -> bool:
"""Cancel a turn timer.
Removes the timeout data for the game. Use when a turn ends
normally or the game ends.
Args:
game_id: The game ID.
Returns:
True if a timer was canceled, False if none existed.
Example:
canceled = await service.cancel_timer("game-123")
"""
async with self._get_redis() as redis:
key = self._timeout_key(game_id)
deleted = await redis.delete(key)
if deleted:
logger.debug(f"Canceled turn timer: game={game_id}")
return deleted > 0
async def extend_timer(
self,
game_id: str,
extension_seconds: int,
) -> TurnTimeoutInfo | None:
"""Extend a turn timer (e.g., on reconnection).
Adds time to the existing deadline without resetting warnings.
The extension is capped so the total time doesn't exceed the
original timeout.
Args:
game_id: The game ID.
extension_seconds: Seconds to add to the deadline.
Returns:
Updated TurnTimeoutInfo, or None if no timer exists.
Example:
# Grant 15 seconds grace on reconnect
info = await service.extend_timer("game-123", 15)
"""
info = await self.get_timeout_info(game_id)
if info is None:
return None
# Calculate new deadline, capped at original timeout
now = datetime.now(UTC).timestamp()
new_deadline = info.deadline + extension_seconds
max_deadline = now + info.timeout_seconds
if new_deadline > max_deadline:
new_deadline = max_deadline
async with self._get_redis() as redis:
key = self._timeout_key(game_id)
await redis.hset(key, "deadline", str(new_deadline))
# Refresh TTL
remaining = int(new_deadline - now) + DEFAULT_KEY_TTL_BUFFER
if remaining > 0:
await redis.expire(key, remaining)
new_remaining = max(0, int(new_deadline - now))
logger.debug(
f"Extended turn timer: game={game_id}, "
f"added={extension_seconds}s, remaining={new_remaining}s"
)
return TurnTimeoutInfo(
game_id=info.game_id,
player_id=info.player_id,
deadline=new_deadline,
timeout_seconds=info.timeout_seconds,
remaining_seconds=new_remaining,
warnings_sent=info.warnings_sent,
warning_thresholds=info.warning_thresholds,
)
# =========================================================================
# Query Methods
# =========================================================================
async def get_timeout_info(self, game_id: str) -> TurnTimeoutInfo | None:
"""Get timeout information for a game.
Args:
game_id: The game ID.
Returns:
TurnTimeoutInfo if a timer exists, None otherwise.
Example:
info = await service.get_timeout_info("game-123")
if info:
print(f"{info.remaining_seconds}s remaining")
"""
async with self._get_redis() as redis:
key = self._timeout_key(game_id)
data = await redis.hgetall(key)
if not data:
return None
# Validate required fields
player_id = data.get("player_id")
deadline_str = data.get("deadline")
timeout_str = data.get("timeout_seconds")
if not player_id or not deadline_str or not timeout_str:
logger.warning(f"Corrupted timeout data for game {game_id}")
return None
try:
deadline = float(deadline_str)
timeout_seconds = int(timeout_str)
warnings_sent = json.loads(data.get("warnings_sent", "[]"))
warning_thresholds = json.loads(data.get("warning_thresholds", "[50, 25]"))
except (ValueError, json.JSONDecodeError) as e:
logger.warning(f"Invalid timeout data for game {game_id}: {e}")
return None
now = datetime.now(UTC).timestamp()
remaining = max(0, int(deadline - now))
return TurnTimeoutInfo(
game_id=game_id,
player_id=player_id,
deadline=deadline,
timeout_seconds=timeout_seconds,
remaining_seconds=remaining,
warnings_sent=warnings_sent,
warning_thresholds=warning_thresholds,
)
async def get_remaining_time(self, game_id: str) -> int | None:
"""Get remaining seconds for a game's turn timer.
Convenience method that returns just the remaining time.
Args:
game_id: The game ID.
Returns:
Seconds remaining, or None if no timer exists.
Example:
remaining = await service.get_remaining_time("game-123")
if remaining is not None and remaining < 30:
# Show low time warning in UI
"""
info = await self.get_timeout_info(game_id)
return info.remaining_seconds if info else None
# =========================================================================
# Warning Management
# =========================================================================
async def get_pending_warning(self, game_id: str) -> PendingWarning | None:
"""Check if a warning should be sent for a game.
Returns the highest-priority unsent warning if the time remaining
has dropped below a warning threshold.
Args:
game_id: The game ID.
Returns:
PendingWarning if one should be sent, None otherwise.
Example:
warning = await service.get_pending_warning("game-123")
if warning:
await send_warning_to_player(warning.player_id, warning.remaining_seconds)
await service.mark_warning_sent("game-123", warning.threshold)
"""
info = await self.get_timeout_info(game_id)
if info is None or info.is_expired:
return None
percent_remaining = info.percent_remaining
# Check each threshold (sorted descending)
for threshold in info.warning_thresholds:
if threshold in info.warnings_sent:
continue # Already sent this warning
if percent_remaining <= threshold:
return PendingWarning(
game_id=game_id,
player_id=info.player_id,
threshold=threshold,
remaining_seconds=info.remaining_seconds,
)
return None
async def mark_warning_sent(self, game_id: str, threshold: int) -> bool:
"""Mark a warning threshold as sent.
Call this after successfully sending a warning to the player
to prevent duplicate warnings.
Args:
game_id: The game ID.
threshold: The warning threshold that was sent.
Returns:
True if marked successfully, False if timer doesn't exist.
Example:
await service.mark_warning_sent("game-123", 50)
"""
info = await self.get_timeout_info(game_id)
if info is None:
return False
if threshold in info.warnings_sent:
return True # Already marked
warnings_sent = info.warnings_sent + [threshold]
async with self._get_redis() as redis:
key = self._timeout_key(game_id)
await redis.hset(key, "warnings_sent", json.dumps(warnings_sent))
logger.debug(f"Marked warning sent: game={game_id}, threshold={threshold}%")
return True
# =========================================================================
# Expiration Checking
# =========================================================================
async def check_expired_timers(self) -> list[TurnTimeoutInfo]:
"""Check for expired turn timers.
Scans all active timers and returns those that have expired.
The caller is responsible for handling the timeouts (auto-pass,
loss declaration, etc.).
Returns:
List of TurnTimeoutInfo for expired timers.
Example:
# In background task, poll every 5 seconds
expired = await service.check_expired_timers()
for info in expired:
await handle_turn_timeout(info.game_id, info.player_id)
await service.cancel_timer(info.game_id)
"""
expired: list[TurnTimeoutInfo] = []
async with self._get_redis() as redis:
# Scan for all timeout keys
async for key in redis.scan_iter(match=f"{TURN_TIMEOUT_PREFIX}*"):
game_id = key[len(TURN_TIMEOUT_PREFIX) :]
info = await self.get_timeout_info(game_id)
if info is not None and info.is_expired:
expired.append(info)
logger.debug(f"Found expired timer: game={game_id}")
return expired
async def get_all_pending_warnings(self) -> list[PendingWarning]:
"""Get all pending warnings across all games.
Scans all active timers and returns warnings that should be sent.
Useful for a background task that handles warnings in batch.
Returns:
List of PendingWarning for all games needing warnings.
Example:
# In background task
warnings = await service.get_all_pending_warnings()
for warning in warnings:
await send_warning(warning)
await service.mark_warning_sent(warning.game_id, warning.threshold)
"""
warnings: list[PendingWarning] = []
async with self._get_redis() as redis:
async for key in redis.scan_iter(match=f"{TURN_TIMEOUT_PREFIX}*"):
game_id = key[len(TURN_TIMEOUT_PREFIX) :]
warning = await self.get_pending_warning(game_id)
if warning is not None:
warnings.append(warning)
return warnings
# =========================================================================
# Utility Methods
# =========================================================================
async def get_active_timer_count(self) -> int:
"""Get the number of active turn timers.
Useful for monitoring and admin dashboards.
Returns:
Number of active timers.
"""
count = 0
async with self._get_redis() as redis:
async for _ in redis.scan_iter(match=f"{TURN_TIMEOUT_PREFIX}*"):
count += 1
return count
async def has_active_timer(self, game_id: str) -> bool:
"""Check if a game has an active turn timer.
Args:
game_id: The game ID.
Returns:
True if a timer exists (even if expired), False otherwise.
"""
async with self._get_redis() as redis:
key = self._timeout_key(game_id)
return await redis.exists(key) > 0
# Global singleton instance
turn_timeout_service = TurnTimeoutService()

View File

@ -7,10 +7,12 @@ GameService to handle:
- Action execution and result broadcasting
- Resignation handling
- Disconnect notifications to opponents
- Automatic reconnection to active games
Architecture:
Socket.IO Events -> GameNamespaceHandler -> GameService
-> ConnectionManager (for routing)
-> GameStateManager (for active game lookup)
-> Socket.IO Emits (responses)
The handler is designed with dependency injection for testability.
@ -27,11 +29,12 @@ Example:
import logging
from typing import TYPE_CHECKING, Any
from uuid import UUID
from pydantic import ValidationError
from app.core.models.actions import parse_action
from app.core.visibility import get_visible_state
from app.core.visibility import get_spectator_state, get_visible_state
from app.schemas.ws_messages import (
ConnectionStatus,
GameOverMessage,
@ -41,6 +44,7 @@ from app.schemas.ws_messages import (
)
from app.services.connection_manager import ConnectionManager, connection_manager
from app.services.game_service import (
CannotSpectateOwnGameError,
ForcedActionRequiredError,
GameAlreadyEndedError,
GameNotFoundError,
@ -50,6 +54,7 @@ from app.services.game_service import (
PlayerNotInGameError,
game_service,
)
from app.services.game_state_manager import GameStateManager, game_state_manager
if TYPE_CHECKING:
import socketio
@ -66,21 +71,25 @@ class GameNamespaceHandler:
Attributes:
_game_service: GameService for game operations.
_connection_manager: ConnectionManager for connection tracking.
_state_manager: GameStateManager for active game lookup.
"""
def __init__(
self,
game_svc: GameService | None = None,
conn_manager: ConnectionManager | None = None,
state_manager: GameStateManager | None = None,
) -> None:
"""Initialize the GameNamespaceHandler.
Args:
game_svc: GameService instance. Uses global if not provided.
conn_manager: ConnectionManager instance. Uses global if not provided.
state_manager: GameStateManager instance. Uses global if not provided.
"""
self._game_service = game_svc or game_service
self._connection_manager = conn_manager or connection_manager
self._state_manager = state_manager or game_state_manager
# =========================================================================
# Event Handlers
@ -163,6 +172,11 @@ class GameNamespaceHandler:
"params": result.pending_forced_action.params,
}
# Include turn timer info if enabled
if result.turn_timeout_seconds is not None:
response["turn_timeout_seconds"] = result.turn_timeout_seconds
response["turn_deadline"] = result.turn_deadline
logger.info(f"Player {user_id} joined game {game_id}")
return response
@ -255,6 +269,11 @@ class GameNamespaceHandler:
"params": result.pending_forced_action.params,
}
# Include turn timer info if enabled and turn changed
if result.turn_timeout_seconds is not None:
response["turn_timeout_seconds"] = result.turn_timeout_seconds
response["turn_deadline"] = result.turn_deadline
# Handle game over
if result.game_over:
response["game_over"] = True
@ -411,6 +430,8 @@ class GameNamespaceHandler:
"""Handle disconnect event.
Notifies opponents in any active game that the player disconnected.
Note: In production, consider debouncing disconnect notifications
to handle rapid disconnect/reconnect cycles (e.g., network hiccups).
Args:
sio: Socket.IO server instance.
@ -424,6 +445,15 @@ class GameNamespaceHandler:
game_id = conn_info.game_id
# Check if spectating (game_id format: "spectating:{actual_game_id}")
if game_id.startswith("spectating:"):
actual_game_id = game_id[len("spectating:") :]
# Unregister spectator and broadcast updated count
await self._connection_manager.unregister_spectator(sid, actual_game_id)
await self._broadcast_spectator_count(sio, actual_game_id)
logger.info(f"Spectator {user_id} left game {actual_game_id}")
return
# Notify opponent of disconnect
await self._notify_opponent_status(
sio, sid, game_id, user_id, ConnectionStatus.DISCONNECTED
@ -431,6 +461,245 @@ class GameNamespaceHandler:
logger.info(f"Player {user_id} disconnected from game {game_id}")
async def handle_spectate(
self,
sio: "socketio.AsyncServer",
sid: str,
user_id: str,
data: dict[str, Any],
) -> dict[str, Any]:
"""Handle game:spectate event.
Allows a user to spectate a game they are not participating in.
Spectators receive a filtered view with no hands visible.
Args:
sio: Socket.IO server instance.
sid: Socket session ID.
user_id: Authenticated user's ID.
data: Message data with game_id.
Returns:
Response dict with success status and spectator state or error.
"""
game_id = data.get("game_id")
message_id = data.get("message_id", "")
if not game_id:
logger.warning(f"game:spectate missing game_id from {sid}")
return self._error_response(
WSErrorCode.INVALID_MESSAGE,
"game_id is required",
message_id,
)
try:
# Get spectator view via GameService
result = await self._game_service.spectate_game(
game_id=game_id,
user_id=user_id,
)
if not result.success:
return self._error_response(
WSErrorCode.INTERNAL_ERROR,
result.message,
message_id,
)
# Register this connection as a spectator
await self._connection_manager.register_spectator(sid, user_id, game_id)
# Join the spectators room for this game
await sio.enter_room(sid, f"spectators:{game_id}", namespace="/game")
# Also join the game room to receive general updates
await sio.enter_room(sid, f"game:{game_id}", namespace="/game")
# Broadcast updated spectator count to players
await self._broadcast_spectator_count(sio, game_id)
# Build response with spectator state
response: dict[str, Any] = {
"success": True,
"game_id": game_id,
"game_over": result.game_over,
"spectator_count": await self._connection_manager.get_spectator_count(game_id),
}
if result.visible_state:
response["state"] = result.visible_state.model_dump(mode="json")
logger.info(f"User {user_id} started spectating game {game_id}")
return response
except GameNotFoundError:
return self._error_response(
WSErrorCode.GAME_NOT_FOUND,
f"Game {game_id} not found",
message_id,
)
except CannotSpectateOwnGameError:
return self._error_response(
WSErrorCode.ACTION_NOT_ALLOWED,
"You cannot spectate a game you are playing in",
message_id,
)
except Exception as e:
logger.exception(f"Error spectating game {game_id}: {e}")
return self._error_response(
WSErrorCode.INTERNAL_ERROR,
"Failed to spectate game",
message_id,
)
async def handle_leave_spectate(
self,
sio: "socketio.AsyncServer",
sid: str,
user_id: str,
data: dict[str, Any],
) -> dict[str, Any]:
"""Handle game:leave_spectate event.
Allows a spectator to stop watching a game.
Args:
sio: Socket.IO server instance.
sid: Socket session ID.
user_id: Authenticated user's ID.
data: Message data with game_id.
Returns:
Response dict with success status.
"""
game_id = data.get("game_id")
message_id = data.get("message_id", "")
if not game_id:
return self._error_response(
WSErrorCode.INVALID_MESSAGE,
"game_id is required",
message_id,
)
# Unregister spectator
await self._connection_manager.unregister_spectator(sid, game_id)
# Leave the rooms
await sio.leave_room(sid, f"spectators:{game_id}", namespace="/game")
await sio.leave_room(sid, f"game:{game_id}", namespace="/game")
# Broadcast updated spectator count to players
await self._broadcast_spectator_count(sio, game_id)
logger.info(f"User {user_id} stopped spectating game {game_id}")
return {
"success": True,
"game_id": game_id,
}
async def handle_reconnect(
self,
sio: "socketio.AsyncServer",
sid: str,
user_id: str,
) -> dict[str, Any] | None:
"""Handle automatic reconnection to active games on connect.
Called after successful authentication to check if the user has
an active game and automatically rejoin them.
This method:
1. Queries for active games via GameStateManager
2. If found, auto-joins the game room
3. Sends full game state to reconnecting player
4. Extends turn timer if it's their turn
5. Notifies opponent of reconnection
Args:
sio: Socket.IO server instance.
sid: Socket session ID.
user_id: Authenticated user's ID (as string, may be UUID).
Returns:
Dict with reconnection info if rejoined, None if no active game.
The dict contains:
- game_id: The game that was rejoined
- is_your_turn: Whether it's the player's turn
- state: The visible game state
- pending_forced_action: Any pending forced action
- turn_timeout_seconds: Remaining turn time if applicable
- turn_deadline: Unix timestamp when turn expires
"""
# Try to get active games from database
try:
player_uuid = UUID(user_id)
except ValueError:
# User ID is not a valid UUID (e.g., NPC) - no active games
return None
active_games = await self._state_manager.get_player_active_games(player_uuid)
if not active_games:
logger.debug(f"No active games for user {user_id}")
return None
# If multiple active games, use the most recent one
# (sorted by last_action_at descending)
active_games.sort(key=lambda g: g.last_action_at or g.started_at, reverse=True)
active_game = active_games[0]
game_id = str(active_game.id)
logger.info(f"Auto-rejoining user {user_id} to game {game_id}")
# Join the game via GameService (handles timer extension)
result = await self._game_service.join_game(
game_id=game_id,
player_id=user_id,
last_event_id=None, # Full state refresh on reconnect
)
if not result.success:
logger.warning(f"Failed to auto-rejoin game {game_id}: {result.message}")
return None
# Register connection with game
await self._connection_manager.join_game(sid, game_id)
# Join the Socket.IO room
await sio.enter_room(sid, f"game:{game_id}", namespace="/game")
# Notify opponent of reconnection
await self._notify_opponent_status(sio, sid, game_id, user_id, ConnectionStatus.CONNECTED)
# Build reconnection response
response: dict[str, Any] = {
"game_id": game_id,
"is_your_turn": result.is_your_turn,
"game_over": result.game_over,
}
if result.visible_state:
response["state"] = result.visible_state.model_dump(mode="json")
if result.pending_forced_action:
response["pending_forced_action"] = {
"player_id": result.pending_forced_action.player_id,
"action_type": result.pending_forced_action.action_type,
"reason": result.pending_forced_action.reason,
"params": result.pending_forced_action.params,
}
if result.turn_timeout_seconds is not None:
response["turn_timeout_seconds"] = result.turn_timeout_seconds
response["turn_deadline"] = result.turn_deadline
logger.info(f"Player {user_id} auto-rejoined game {game_id}")
return response
# =========================================================================
# Broadcast Helpers
# =========================================================================
@ -440,9 +709,10 @@ class GameNamespaceHandler:
sio: "socketio.AsyncServer",
game_id: str,
) -> None:
"""Broadcast filtered game state to all participants.
"""Broadcast filtered game state to all participants and spectators.
Each player receives their own visibility-filtered view of the game.
Spectators receive a view with no hands visible.
Args:
sio: Socket.IO server instance.
@ -452,6 +722,9 @@ class GameNamespaceHandler:
# Get full game state
state = await self._game_service.get_game_state(game_id)
# Get spectator count for inclusion in player messages
spectator_count = await self._connection_manager.get_spectator_count(game_id)
# Get all connected players for this game
user_sids = await self._connection_manager.get_game_user_sids(game_id)
@ -462,6 +735,7 @@ class GameNamespaceHandler:
message = GameStateMessage(
game_id=game_id,
state=visible_state,
spectator_count=spectator_count,
)
await sio.emit(
"game:state",
@ -474,6 +748,23 @@ class GameNamespaceHandler:
logger.warning(f"Player {player_id} not in game {game_id}")
continue
# Send spectator state to all spectators
spectator_sids = await self._connection_manager.get_game_spectators(game_id)
if spectator_sids:
spectator_state = get_spectator_state(state)
spectator_message = GameStateMessage(
game_id=game_id,
state=spectator_state,
spectator_count=spectator_count,
)
for spectator_sid in spectator_sids:
await sio.emit(
"game:state",
spectator_message.model_dump(mode="json"),
to=spectator_sid,
namespace="/game",
)
except GameNotFoundError:
logger.warning(f"Cannot broadcast state: game {game_id} not found")
except Exception as e:
@ -486,7 +777,7 @@ class GameNamespaceHandler:
winner_id: str | None,
end_reason: Any,
) -> None:
"""Broadcast game over notification to all participants.
"""Broadcast game over notification to all participants and spectators.
Args:
sio: Socket.IO server instance.
@ -517,6 +808,24 @@ class GameNamespaceHandler:
except ValueError:
continue
# Send to spectators
spectator_sids = await self._connection_manager.get_game_spectators(game_id)
if spectator_sids:
spectator_state = get_spectator_state(state)
spectator_message = GameOverMessage(
game_id=game_id,
winner_id=winner_id,
end_reason=end_reason,
final_state=spectator_state,
)
for spectator_sid in spectator_sids:
await sio.emit(
"game:game_over",
spectator_message.model_dump(mode="json"),
to=spectator_sid,
namespace="/game",
)
except GameNotFoundError:
# Game already archived - just emit to room without state
message = GameOverMessage(
@ -572,6 +881,36 @@ class GameNamespaceHandler:
except Exception as e:
logger.exception(f"Error notifying opponent status: {e}")
async def _broadcast_spectator_count(
self,
sio: "socketio.AsyncServer",
game_id: str,
) -> None:
"""Broadcast spectator count to all players in a game.
Called when spectators join or leave to keep players informed.
Args:
sio: Socket.IO server instance.
game_id: The game ID.
"""
try:
spectator_count = await self._connection_manager.get_spectator_count(game_id)
user_sids = await self._connection_manager.get_game_user_sids(game_id)
for player_sid in user_sids.values():
await sio.emit(
"game:spectator_count",
{"game_id": game_id, "spectator_count": spectator_count},
to=player_sid,
namespace="/game",
)
logger.debug(f"Broadcast spectator count for game {game_id}: {spectator_count}")
except Exception as e:
logger.exception(f"Error broadcasting spectator count for {game_id}: {e}")
# =========================================================================
# Helper Methods
# =========================================================================

View File

@ -69,6 +69,7 @@ async def connect(
Authenticates the connection using JWT from auth data.
Rejects connections without valid authentication.
After successful auth, checks for active games and auto-rejoins if found.
Args:
sid: Socket session ID assigned by Socket.IO.
@ -103,6 +104,21 @@ async def connect(
await auth_handler.setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game")
logger.info(f"Client authenticated to /game: sid={sid}, user_id={auth_result.user_id}")
# Check for active games and auto-rejoin
user_id_str = str(auth_result.user_id)
reconnect_info = await game_namespace_handler.handle_reconnect(sio, sid, user_id_str)
if reconnect_info:
# Emit reconnection event to inform client they're back in a game
await sio.emit(
"game:reconnected",
reconnect_info,
to=sid,
namespace="/game",
)
logger.info(f"Emitted game:reconnected for {sid}: game_id={reconnect_info.get('game_id')}")
return True
@ -247,6 +263,60 @@ async def on_game_heartbeat(sid: str, data: dict[str, object] | None = None) ->
}
@sio.on("game:spectate", namespace="/game")
async def on_game_spectate(sid: str, data: dict[str, object]) -> dict[str, object]:
"""Handle request to spectate a game.
Authenticates the request and delegates to GameNamespaceHandler.
On success, the user receives a spectator-filtered game state
(no hands visible).
Args:
sid: Socket session ID.
data: Message containing game_id to spectate.
Returns:
Response with spectator game state or error.
"""
logger.debug(f"game:spectate from {sid}: {data}")
# Require authentication
user_id = await require_auth(sio, sid)
if not user_id:
return {
"success": False,
"error": {"code": "unauthenticated", "message": "Not authenticated"},
}
return await game_namespace_handler.handle_spectate(sio, sid, user_id, dict(data))
@sio.on("game:leave_spectate", namespace="/game")
async def on_game_leave_spectate(sid: str, data: dict[str, object]) -> dict[str, object]:
"""Handle request to stop spectating a game.
Authenticates the request and removes the user from spectator list.
Args:
sid: Socket session ID.
data: Message containing game_id to stop spectating.
Returns:
Confirmation of spectate leave.
"""
logger.debug(f"game:leave_spectate from {sid}: {data}")
# Require authentication
user_id = await require_auth(sio, sid)
if not user_id:
return {
"success": False,
"error": {"code": "unauthenticated", "message": "Not authenticated"},
}
return await game_namespace_handler.handle_leave_spectate(sio, sid, user_id, dict(data))
# =============================================================================
# ASGI App Creation
# =============================================================================

View File

@ -622,6 +622,104 @@ class TestQueryMethods:
assert result == opponent_sid
@pytest.mark.asyncio
async def test_get_user_active_game_returns_game_id(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_user_active_game returns the game_id from connection.
This method looks up the user's connection and returns their
current game_id if they are in a game.
"""
user_id = "user-123"
game_id = "game-456"
now = datetime.now(UTC)
# Mock: user has a connection with a game
mock_redis.get.return_value = "test-sid"
mock_redis.hgetall.return_value = {
"user_id": user_id,
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
result = await manager.get_user_active_game(user_id)
assert result == game_id
@pytest.mark.asyncio
async def test_get_user_active_game_returns_none_when_not_in_game(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_user_active_game returns None when user has no game.
If the user is connected but not in a game, their game_id will
be None (or empty string in Redis), so we return None.
"""
user_id = "user-123"
now = datetime.now(UTC)
mock_redis.get.return_value = "test-sid"
mock_redis.hgetall.return_value = {
"user_id": user_id,
"game_id": "", # Empty string means not in a game
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
result = await manager.get_user_active_game(user_id)
# Empty string game_id becomes None in ConnectionInfo
assert result is None
@pytest.mark.asyncio
async def test_get_user_active_game_returns_none_when_not_connected(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_user_active_game returns None when user is offline.
If the user has no active connection, there's no game to return.
"""
mock_redis.get.return_value = None # No connection for user
result = await manager.get_user_active_game("user-123")
assert result is None
@pytest.mark.asyncio
async def test_get_user_active_game_accepts_uuid(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_user_active_game accepts UUID objects.
The method should work with both string and UUID user IDs
for convenience.
"""
user_uuid = uuid4()
game_id = "game-789"
now = datetime.now(UTC)
mock_redis.get.return_value = "test-sid"
mock_redis.hgetall.return_value = {
"user_id": str(user_uuid),
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
result = await manager.get_user_active_game(user_uuid)
assert result == game_id
class TestKeyGeneration:
"""Tests for Redis key generation methods."""
@ -653,3 +751,236 @@ class TestKeyGeneration:
manager = ConnectionManager()
key = manager._game_conns_key("game-789")
assert key == "game_conns:game-789"
def test_spectators_key_format(self) -> None:
"""Test that spectators keys have correct format.
Keys should follow the pattern spectators:{game_id}.
"""
manager = ConnectionManager()
key = manager._spectators_key("game-789")
assert key == "spectators:game-789"
class TestSpectatorManagement:
"""Tests for spectator-related methods.
Spectators are tracked separately from game participants using
a dedicated Redis set per game.
"""
@pytest.mark.asyncio
async def test_register_spectator_success(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test registering a spectator adds them to the spectators set.
When a user starts spectating a game, they should be added to
the spectators:{game_id} set and their connection's game_id should
be updated to indicate spectating.
"""
sid = "spectator-sid"
user_id = "user-123"
game_id = "game-456"
mock_redis.exists.return_value = True
result = await manager.register_spectator(sid, user_id, game_id)
assert result is True
# Should update connection's game_id to indicate spectating
mock_redis.hset.assert_called_with(
f"{CONN_PREFIX}{sid}",
"game_id",
f"spectating:{game_id}",
)
# Should add to spectators set
mock_redis.sadd.assert_called_with(f"spectators:{game_id}", sid)
# Should set TTL on spectators set
mock_redis.expire.assert_called()
@pytest.mark.asyncio
async def test_register_spectator_connection_not_found(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test registering spectator fails if connection doesn't exist.
A user must have an active connection before they can spectate.
"""
mock_redis.exists.return_value = False
result = await manager.register_spectator("unknown-sid", "user-123", "game-456")
assert result is False
mock_redis.sadd.assert_not_called()
@pytest.mark.asyncio
async def test_unregister_spectator_success(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test unregistering a spectator removes them from the set.
When a spectator leaves, they should be removed from the
spectators set and their connection's game_id should be cleared.
"""
sid = "spectator-sid"
game_id = "game-456"
mock_redis.srem.return_value = 1 # 1 member removed
mock_redis.hget.return_value = f"spectating:{game_id}"
result = await manager.unregister_spectator(sid, game_id)
assert result is True
mock_redis.srem.assert_called_with(f"spectators:{game_id}", sid)
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")
@pytest.mark.asyncio
async def test_unregister_spectator_not_in_set(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test unregistering non-spectator returns False.
If the sid is not in the spectators set, unregister should
return False to indicate nothing was removed.
"""
mock_redis.srem.return_value = 0 # No members removed
mock_redis.hget.return_value = ""
result = await manager.unregister_spectator("unknown-sid", "game-456")
assert result is False
@pytest.mark.asyncio
async def test_get_spectator_count(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test get_spectator_count returns the set cardinality.
Should use Redis SCARD to get the number of spectators efficiently.
"""
mock_redis.scard.return_value = 5
result = await manager.get_spectator_count("game-456")
assert result == 5
mock_redis.scard.assert_called_with("spectators:game-456")
@pytest.mark.asyncio
async def test_get_game_spectators(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test get_game_spectators returns all spectator sids.
Should return all sids from the spectators set for the game.
"""
expected_sids = {"sid-1", "sid-2", "sid-3"}
mock_redis.smembers.return_value = expected_sids
result = await manager.get_game_spectators("game-456")
assert set(result) == expected_sids
mock_redis.smembers.assert_called_with("spectators:game-456")
@pytest.mark.asyncio
async def test_is_spectating_returns_true(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test is_spectating returns True when sid is spectating.
Should use Redis SISMEMBER for efficient membership check.
"""
mock_redis.sismember.return_value = True
result = await manager.is_spectating("spectator-sid", "game-456")
assert result is True
mock_redis.sismember.assert_called_with("spectators:game-456", "spectator-sid")
@pytest.mark.asyncio
async def test_is_spectating_returns_false(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test is_spectating returns False when not spectating.
Should return False for sids not in the spectators set.
"""
mock_redis.sismember.return_value = False
result = await manager.is_spectating("player-sid", "game-456")
assert result is False
class TestCleanupWithSpectators:
"""Tests for connection cleanup including spectator state."""
@pytest.mark.asyncio
async def test_cleanup_removes_spectator_from_set(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that cleanup removes spectator from spectators set.
When a spectating connection is cleaned up (disconnect), the
sid should be removed from the spectators set.
"""
sid = "spectator-sid"
game_id = "game-456"
mock_redis.hgetall.return_value = {
"user_id": "user-123",
"game_id": f"spectating:{game_id}",
"connected_at": "2024-01-01T00:00:00+00:00",
"last_seen": "2024-01-01T00:00:00+00:00",
}
mock_redis.get.return_value = sid
await manager._cleanup_connection(sid)
# Should remove from spectators set
mock_redis.srem.assert_called_with(f"spectators:{game_id}", sid)
@pytest.mark.asyncio
async def test_cleanup_player_removes_from_game_conns(
self,
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that cleanup removes player from game_conns set.
When a playing (not spectating) connection is cleaned up,
the sid should be removed from game_conns, not spectators.
"""
sid = "player-sid"
game_id = "game-456"
mock_redis.hgetall.return_value = {
"user_id": "user-123",
"game_id": game_id, # Regular game_id, not spectating:
"connected_at": "2024-01-01T00:00:00+00:00",
"last_seen": "2024-01-01T00:00:00+00:00",
}
mock_redis.get.return_value = sid
await manager._cleanup_connection(sid)
# Should remove from game_conns, not spectators
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)

View File

@ -22,6 +22,7 @@ from app.core.models.actions import AttackAction, PassAction, ResignAction, Sele
from app.core.models.game_state import ForcedAction, GameState, PlayerState
from app.core.win_conditions import WinResult
from app.services.game_service import (
CannotSpectateOwnGameError,
ForcedActionRequiredError,
GameAlreadyEndedError,
GameCreationError,
@ -84,11 +85,39 @@ def mock_engine() -> MagicMock:
return engine
@pytest.fixture
def mock_timeout_service() -> AsyncMock:
"""Create a mock TurnTimeoutService.
The timeout service manages turn timers using Redis.
For tests, we mock all Redis interactions.
"""
from app.services.turn_timeout_service import TurnTimeoutInfo
service = AsyncMock()
service.start_turn_timer = AsyncMock(
return_value=TurnTimeoutInfo(
game_id="game-123",
player_id="player-1",
deadline=0.0,
timeout_seconds=180,
remaining_seconds=180,
warnings_sent=[],
warning_thresholds=[50, 25],
)
)
service.cancel_timer = AsyncMock(return_value=True)
service.extend_timer = AsyncMock(return_value=None)
service.get_timeout_info = AsyncMock(return_value=None)
return service
@pytest.fixture
def game_service(
mock_state_manager: AsyncMock,
mock_card_service: MagicMock,
mock_engine: MagicMock,
mock_timeout_service: AsyncMock,
) -> GameService:
"""Create a GameService with injected mock dependencies.
@ -98,6 +127,7 @@ def game_service(
return GameService(
state_manager=mock_state_manager,
card_service=mock_card_service,
timeout_service=mock_timeout_service,
engine_factory=lambda game: mock_engine,
)
@ -1556,3 +1586,849 @@ class TestExceptionMessages:
assert "pass" in str(error)
assert error.required_action_type == "select_active"
assert error.attempted_action_type == "pass"
class TestTurnTimerIntegration:
"""Tests for turn timer integration in GameService.
The turn timer should:
- NOT start during SETUP phase (when selecting basic pokemon)
- Start when SETUP phase ends (first real turn begins)
- Start when turn changes during normal play
- Be canceled when game ends
"""
@pytest.fixture
def game_state_in_setup(self) -> GameState:
"""Create a game state in SETUP phase.
During SETUP, players are selecting their basic pokemon.
Timer should NOT be running yet.
"""
player1 = PlayerState(player_id="player-1")
player2 = PlayerState(player_id="player-2")
return GameState(
game_id="game-123",
players={"player-1": player1, "player-2": player2},
current_player_id="player-1",
turn_number=0, # Setup phase
phase=TurnPhase.SETUP,
)
@pytest.mark.asyncio
async def test_timer_starts_when_setup_ends(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
mock_timeout_service: AsyncMock,
game_state_in_setup: GameState,
) -> None:
"""Test that turn timer starts when SETUP phase ends.
When an action causes the phase to transition from SETUP to
a real game phase (DRAW/MAIN), the turn timer should start.
This ensures players have unlimited time for initial pokemon
selection but are timed once actual gameplay begins.
"""
# Enable turn timer in rules
game_state_in_setup.rules.win_conditions.turn_timer_enabled = True
game_state_in_setup.rules.win_conditions.turn_timer_seconds = 180
mock_state_manager.load_state.return_value = game_state_in_setup
def mock_execute(state, player_id, action):
# Simulate SETUP ending - phase transitions to MAIN
state.phase = TurnPhase.MAIN
state.turn_number = 1
return ActionResult(success=True, message="Setup complete, game started")
mock_engine.execute_action = AsyncMock(side_effect=mock_execute)
result = await game_service.execute_action(
"game-123", "player-1", SelectActiveAction(pokemon_id="basic-1")
)
assert result.success is True
# Timer should have been started
mock_timeout_service.start_turn_timer.assert_called_once_with(
game_id="game-123",
player_id="player-1",
timeout_seconds=180,
warning_thresholds=[50, 25],
)
assert result.turn_timeout_seconds == 180
@pytest.mark.asyncio
async def test_timer_not_started_during_setup(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
mock_timeout_service: AsyncMock,
game_state_in_setup: GameState,
) -> None:
"""Test that turn timer is NOT started during SETUP phase actions.
When actions are executed during SETUP (before both players
have selected basic pokemon), the timer should not start.
"""
# Enable turn timer in rules
game_state_in_setup.rules.win_conditions.turn_timer_enabled = True
mock_state_manager.load_state.return_value = game_state_in_setup
def mock_execute(state, player_id, action):
# Action during SETUP - phase stays in SETUP (e.g., first player selected)
# Phase does NOT change
return ActionResult(success=True, message="First player selected basic")
mock_engine.execute_action = AsyncMock(side_effect=mock_execute)
await game_service.execute_action(
"game-123", "player-1", SelectActiveAction(pokemon_id="basic-1")
)
# Timer should NOT have been started (still in SETUP)
mock_timeout_service.start_turn_timer.assert_not_called()
@pytest.mark.asyncio
async def test_timer_starts_on_turn_change(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
mock_timeout_service: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that turn timer starts when turn changes during normal play.
When a player ends their turn (e.g., via pass action), the
timer should start for the next player.
"""
# Enable turn timer in rules
sample_game_state.rules.win_conditions.turn_timer_enabled = True
sample_game_state.rules.win_conditions.turn_timer_seconds = 180
mock_state_manager.load_state.return_value = sample_game_state
def mock_execute(state, player_id, action):
# Pass action ends turn - next player's turn
state.current_player_id = "player-2"
state.turn_number = 2
return ActionResult(success=True, message="Turn ended")
mock_engine.execute_action = AsyncMock(side_effect=mock_execute)
result = await game_service.execute_action("game-123", "player-1", PassAction())
assert result.success is True
assert result.turn_changed is True
# Timer should have been started for the new current player
mock_timeout_service.start_turn_timer.assert_called_once_with(
game_id="game-123",
player_id="player-2", # New current player
timeout_seconds=180,
warning_thresholds=[50, 25],
)
@pytest.mark.asyncio
async def test_timer_not_started_when_disabled(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
mock_timeout_service: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that timer is not started when turn timer is disabled.
If turn_timer_enabled is False in rules, no timer operations
should be performed even on turn changes.
"""
# Disable turn timer (default)
sample_game_state.rules.win_conditions.turn_timer_enabled = False
mock_state_manager.load_state.return_value = sample_game_state
def mock_execute(state, player_id, action):
state.current_player_id = "player-2"
state.turn_number = 2
return ActionResult(success=True, message="Turn ended")
mock_engine.execute_action = AsyncMock(side_effect=mock_execute)
result = await game_service.execute_action("game-123", "player-1", PassAction())
assert result.success is True
# Timer should NOT have been started (disabled in rules)
mock_timeout_service.start_turn_timer.assert_not_called()
assert result.turn_timeout_seconds is None
assert result.turn_deadline is None
@pytest.mark.asyncio
async def test_timer_canceled_on_game_over(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
mock_timeout_service: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that turn timer is canceled when game ends.
When an action results in game over (win condition met),
the timer should be canceled to prevent spurious timeout events.
"""
sample_game_state.rules.win_conditions.turn_timer_enabled = True
mock_state_manager.load_state.return_value = sample_game_state
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Game over",
win_result=WinResult(
winner_id="player-1",
loser_id="player-2",
end_reason=GameEndReason.PRIZES_TAKEN,
reason="All prizes taken",
),
)
result = await game_service.execute_action(
"game-123", "player-1", AttackAction(attack_index=0)
)
assert result.success is True
assert result.game_over is True
# Timer should have been canceled
mock_timeout_service.cancel_timer.assert_called_once_with("game-123")
class TestSpectateGame:
"""Tests for the spectate_game method.
Spectator mode allows users to watch games they are not participating in.
Spectators receive a filtered view with no hands visible.
"""
@pytest.mark.asyncio
async def test_spectate_game_success(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test successful game spectating returns spectator-filtered state.
When a non-participant spectates a game, they should receive a
SpectateResult with the game state filtered to hide all hands.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.spectate_game("game-123", "spectator-user")
assert result.success is True
assert result.game_id == "game-123"
assert result.visible_state is not None
assert result.game_over is False
# Spectator view should have special viewer_id
assert result.visible_state.viewer_id == "__spectator__"
# Spectator should not see any hands (is_my_turn always False)
assert result.visible_state.is_my_turn is False
@pytest.mark.asyncio
async def test_spectate_game_not_found(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
) -> None:
"""Test spectate_game raises GameNotFoundError when game doesn't exist.
Spectating a non-existent game should raise an appropriate error.
"""
mock_state_manager.load_state.return_value = None
with pytest.raises(GameNotFoundError) as exc_info:
await game_service.spectate_game("nonexistent", "spectator-user")
assert exc_info.value.game_id == "nonexistent"
@pytest.mark.asyncio
async def test_spectate_own_game_raises_error(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that players cannot spectate their own game.
A player who is participating in the game should not be able to
spectate it - they should use join_game instead.
"""
mock_state_manager.load_state.return_value = sample_game_state
with pytest.raises(CannotSpectateOwnGameError) as exc_info:
await game_service.spectate_game("game-123", "player-1")
assert exc_info.value.game_id == "game-123"
assert exc_info.value.player_id == "player-1"
@pytest.mark.asyncio
async def test_spectate_ended_game(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test spectating an ended game succeeds but indicates game_over.
Users should be able to spectate completed games to view the
final state, but game_over should be True.
"""
sample_game_state.winner_id = "player-1"
sample_game_state.end_reason = GameEndReason.PRIZES_TAKEN
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.spectate_game("game-123", "spectator-user")
assert result.success is True
assert result.game_over is True
assert "ended" in result.message.lower()
@pytest.mark.asyncio
async def test_spectate_game_hides_both_hands(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that spectator view hides both players' hands.
Unlike player views where you can see your own hand, spectators
cannot see either player's hand.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.spectate_game("game-123", "spectator-user")
assert result.visible_state is not None
# Both players' hands should show count only, no cards
for _player_id, player_state in result.visible_state.players.items():
assert player_state.hand.cards == []
# is_current_player should be False for all players from spectator view
assert player_state.is_current_player is False
class TestCannotSpectateOwnGameError:
"""Tests for CannotSpectateOwnGameError exception."""
def test_error_message_contains_game_and_player(self) -> None:
"""Test CannotSpectateOwnGameError has descriptive message.
The error message should clearly indicate which game and player
triggered the error.
"""
error = CannotSpectateOwnGameError("game-123", "player-1")
assert "game-123" in str(error)
assert error.game_id == "game-123"
assert error.player_id == "player-1"
assert "spectate" in str(error).lower()
class TestHandleTimeout:
"""Tests for the handle_timeout method.
handle_timeout is called by the background timeout polling task when
a player's turn timer expires. It should end the game with the
timed-out player as the loser.
"""
@pytest.mark.asyncio
async def test_handle_timeout_declares_opponent_winner(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that handle_timeout declares the opponent as winner.
When a player times out, their opponent should win by timeout.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.handle_timeout("game-123", "player-1")
assert result.success is True
assert result.game_id == "game-123"
assert result.winner_id == "player-2" # Opponent wins
assert result.loser_id == "player-1" # Timed out player loses
assert result.end_reason == GameEndReason.TIMEOUT
@pytest.mark.asyncio
async def test_handle_timeout_calls_end_game(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that handle_timeout archives the game to history.
The game should be properly archived with timeout as the end reason.
"""
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.handle_timeout("game-123", "player-2")
# Should archive to history
mock_state_manager.archive_to_history.assert_called_once()
call_kwargs = mock_state_manager.archive_to_history.call_args.kwargs
assert call_kwargs["game_id"] == "game-123"
assert call_kwargs["end_reason"].value == "timeout"
# Winner should be the non-timed-out player
assert result.winner_id == "player-1"
@pytest.mark.asyncio
async def test_handle_timeout_game_not_found(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
) -> None:
"""Test handle_timeout raises error when game doesn't exist.
If the game was already cleaned up or never existed, we should
get a GameNotFoundError.
"""
mock_state_manager.load_state.return_value = None
with pytest.raises(GameNotFoundError) as exc_info:
await game_service.handle_timeout("nonexistent", "player-1")
assert exc_info.value.game_id == "nonexistent"
@pytest.mark.asyncio
async def test_handle_timeout_cancels_timer(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_timeout_service: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that handle_timeout cancels the turn timer.
The timer should be canceled to prevent any further timeout events.
"""
mock_state_manager.load_state.return_value = sample_game_state
await game_service.handle_timeout("game-123", "player-1")
# Timer should have been canceled via end_game
mock_timeout_service.cancel_timer.assert_called_with("game-123")
class TestJoinGameTimerExtension:
"""Tests for timer extension during reconnection in join_game.
When a player reconnects mid-turn with an active timer, the timer
should be extended by the grace period to give them time to act.
"""
@pytest.mark.asyncio
async def test_join_game_extends_timer_on_reconnect(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_timeout_service: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that joining a game extends turn timer on reconnect.
When a player reconnects and it's their turn, the timer should
be extended by the grace period (default 15 seconds).
"""
from app.services.turn_timeout_service import TurnTimeoutInfo
# Enable turn timer
sample_game_state.rules.win_conditions.turn_timer_enabled = True
sample_game_state.rules.win_conditions.turn_timer_grace_seconds = 15
mock_state_manager.load_state.return_value = sample_game_state
# Mock existing timer for current player
existing_timer = TurnTimeoutInfo(
game_id="game-123",
player_id="player-1",
deadline=1000.0,
timeout_seconds=180,
remaining_seconds=60,
warnings_sent=[],
warning_thresholds=[50, 25],
)
mock_timeout_service.get_timeout_info.return_value = existing_timer
# Mock extended timer
extended_timer = TurnTimeoutInfo(
game_id="game-123",
player_id="player-1",
deadline=1015.0, # Extended by grace period
timeout_seconds=180,
remaining_seconds=75, # 60 + 15
warnings_sent=[],
warning_thresholds=[50, 25],
)
mock_timeout_service.extend_timer.return_value = extended_timer
# Player-1 reconnects (it's their turn)
result = await game_service.join_game("game-123", "player-1")
assert result.success is True
assert result.is_your_turn is True
# Timer should have been extended
mock_timeout_service.extend_timer.assert_called_once_with("game-123", 15)
# Result should show extended timer info
assert result.turn_timeout_seconds == 75
assert result.turn_deadline == 1015.0
@pytest.mark.asyncio
async def test_join_game_no_extension_when_not_your_turn(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_timeout_service: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that timer is not extended when it's not your turn.
Only the current player's reconnection should extend the timer.
The opponent reconnecting should not affect the timer.
"""
from app.services.turn_timeout_service import TurnTimeoutInfo
# Enable turn timer, player-1's turn
sample_game_state.rules.win_conditions.turn_timer_enabled = True
sample_game_state.current_player_id = "player-1"
mock_state_manager.load_state.return_value = sample_game_state
# Mock timer for player-1 (current player)
existing_timer = TurnTimeoutInfo(
game_id="game-123",
player_id="player-1",
deadline=1000.0,
timeout_seconds=180,
remaining_seconds=60,
warnings_sent=[],
warning_thresholds=[50, 25],
)
mock_timeout_service.get_timeout_info.return_value = existing_timer
# Player-2 reconnects (NOT their turn)
result = await game_service.join_game("game-123", "player-2")
assert result.success is True
assert result.is_your_turn is False
# Timer should NOT have been extended (not player-2's turn)
mock_timeout_service.extend_timer.assert_not_called()
# Result should still show timer info
assert result.turn_timeout_seconds == 60
assert result.turn_deadline == 1000.0
@pytest.mark.asyncio
async def test_join_game_no_timer_info_when_disabled(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_timeout_service: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that timer info is None when timer is disabled.
If turn_timer_enabled is False, no timer operations should occur.
"""
# Timer disabled (default)
sample_game_state.rules.win_conditions.turn_timer_enabled = False
mock_state_manager.load_state.return_value = sample_game_state
result = await game_service.join_game("game-123", "player-1")
assert result.success is True
# No timer operations
mock_timeout_service.get_timeout_info.assert_not_called()
mock_timeout_service.extend_timer.assert_not_called()
# Timer fields should be None
assert result.turn_timeout_seconds is None
assert result.turn_deadline is None
class TestAdditionalActionTypes:
"""Tests for additional action types in execute_action.
These tests verify that various action types are properly passed through
to the engine and handled correctly by the service layer.
"""
@pytest.mark.asyncio
async def test_execute_play_pokemon_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test executing a PlayPokemonAction.
Playing a Pokemon from hand should be validated and executed
through the engine like any other action.
"""
from app.core.models.actions import PlayPokemonAction
mock_state_manager.load_state.return_value = sample_game_state
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Pikachu placed on bench",
state_changes=[{"type": "play_pokemon", "zone": "bench"}],
)
action = PlayPokemonAction(card_instance_id="pikachu-001")
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "play_pokemon"
mock_engine.execute_action.assert_called_once()
@pytest.mark.asyncio
async def test_execute_attach_energy_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test executing an AttachEnergyAction.
Attaching energy should update the Pokemon's energy and be
reflected in the state changes.
"""
from app.core.models.actions import AttachEnergyAction
mock_state_manager.load_state.return_value = sample_game_state
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Lightning energy attached to Pikachu",
state_changes=[{"type": "attach_energy", "energy_type": "lightning"}],
)
action = AttachEnergyAction(energy_card_id="energy-001", target_pokemon_id="pikachu-001")
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "attach_energy"
@pytest.mark.asyncio
async def test_execute_retreat_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test executing a RetreatAction.
Retreating should switch the active Pokemon and discard energy
equal to the retreat cost.
"""
from app.core.models.actions import RetreatAction
mock_state_manager.load_state.return_value = sample_game_state
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Retreated to Raichu",
state_changes=[
{"type": "retreat", "old_active": "pikachu-001", "new_active": "raichu-001"}
],
)
action = RetreatAction(new_active_id="raichu-001", energy_to_discard=["energy-001"])
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "retreat"
@pytest.mark.asyncio
async def test_execute_evolve_pokemon_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test executing an EvolvePokemonAction.
Evolution should place the evolution card on top of the target
Pokemon and update its stats.
"""
from app.core.models.actions import EvolvePokemonAction
mock_state_manager.load_state.return_value = sample_game_state
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Pikachu evolved into Raichu",
state_changes=[{"type": "evolve", "from": "pikachu-001", "to": "raichu-001"}],
)
action = EvolvePokemonAction(
evolution_card_id="raichu-001", target_pokemon_id="pikachu-001"
)
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "evolve"
@pytest.mark.asyncio
async def test_execute_play_trainer_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test executing a PlayTrainerAction.
Trainer cards should be played and their effects resolved
through the effect handler system.
"""
from app.core.models.actions import PlayTrainerAction
mock_state_manager.load_state.return_value = sample_game_state
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Professor Oak played - drew 7 cards",
state_changes=[{"type": "play_trainer", "cards_drawn": 7}],
)
action = PlayTrainerAction(card_instance_id="prof-oak-001")
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "play_trainer"
@pytest.mark.asyncio
async def test_execute_select_prize_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test executing a SelectPrizeAction.
After a knockout, the player should be able to select a prize
card to add to their hand.
"""
from app.core.models.actions import SelectPrizeAction
# Set up forced action for prize selection
sample_game_state.forced_actions = [
ForcedAction(
player_id="player-1",
action_type="select_prize",
reason="Select a prize card",
)
]
mock_state_manager.load_state.return_value = sample_game_state
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Prize card taken",
state_changes=[{"type": "select_prize", "prize_index": 2}],
)
action = SelectPrizeAction(prize_index=2)
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "select_prize"
@pytest.mark.asyncio
async def test_execute_use_ability_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test executing a UseAbilityAction.
Pokemon abilities should be activated and their effects resolved.
"""
from app.core.models.actions import UseAbilityAction
mock_state_manager.load_state.return_value = sample_game_state
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Energy Trans activated",
state_changes=[{"type": "use_ability", "ability": "Energy Trans"}],
)
action = UseAbilityAction(pokemon_id="venusaur-001", ability_index=0)
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "use_ability"
class TestEndReasonMapping:
"""Tests for the _map_end_reason helper function.
This function maps core GameEndReason to database EndReason,
ensuring proper enum synchronization between modules.
"""
def test_map_all_end_reasons(self) -> None:
"""Test that all GameEndReason values can be mapped.
Every core end reason should have a corresponding database
end reason to prevent runtime errors during game archival.
"""
from app.db.models.game import EndReason
from app.services.game_service import _map_end_reason
# All core end reasons should be mappable
for core_reason in GameEndReason:
db_reason = _map_end_reason(core_reason)
assert isinstance(db_reason, EndReason)
def test_map_prizes_taken(self) -> None:
"""Test mapping PRIZES_TAKEN end reason."""
from app.db.models.game import EndReason
from app.services.game_service import _map_end_reason
result = _map_end_reason(GameEndReason.PRIZES_TAKEN)
assert result == EndReason.PRIZES_TAKEN
def test_map_resignation(self) -> None:
"""Test mapping RESIGNATION end reason."""
from app.db.models.game import EndReason
from app.services.game_service import _map_end_reason
result = _map_end_reason(GameEndReason.RESIGNATION)
assert result == EndReason.RESIGNATION
def test_map_timeout(self) -> None:
"""Test mapping TIMEOUT end reason."""
from app.db.models.game import EndReason
from app.services.game_service import _map_end_reason
result = _map_end_reason(GameEndReason.TIMEOUT)
assert result == EndReason.TIMEOUT
def test_map_deck_empty_to_cannot_draw(self) -> None:
"""Test mapping DECK_EMPTY to CANNOT_DRAW.
The core uses DECK_EMPTY for clarity, but the DB schema
uses CANNOT_DRAW as the canonical name.
"""
from app.db.models.game import EndReason
from app.services.game_service import _map_end_reason
result = _map_end_reason(GameEndReason.DECK_EMPTY)
assert result == EndReason.CANNOT_DRAW

View File

@ -0,0 +1,930 @@
"""Unit tests for TurnTimeoutService.
This module tests the turn timeout management functionality including
timer lifecycle (start, cancel, extend), warning detection, and
expiration checking.
All tests use mocked Redis to avoid external dependencies.
"""
import json
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from unittest.mock import AsyncMock
import pytest
from app.services.turn_timeout_service import (
DEFAULT_KEY_TTL_BUFFER,
TURN_TIMEOUT_PREFIX,
PendingWarning,
TurnTimeoutInfo,
TurnTimeoutService,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_redis() -> AsyncMock:
"""Create a mock Redis client with common methods."""
redis = AsyncMock()
redis.hset = AsyncMock()
redis.hgetall = AsyncMock(return_value={})
redis.delete = AsyncMock(return_value=1)
redis.expire = AsyncMock()
redis.exists = AsyncMock(return_value=0)
# Mock scan_iter to return an empty async iterator by default
async def empty_scan_iter(match=None):
return
yield # Make this an async generator
redis.scan_iter = empty_scan_iter
return redis
@pytest.fixture
def turn_timeout_service(mock_redis: AsyncMock) -> TurnTimeoutService:
"""Create a TurnTimeoutService with injected mock Redis."""
@asynccontextmanager
async def mock_redis_factory():
yield mock_redis
return TurnTimeoutService(redis_factory=mock_redis_factory)
# =============================================================================
# Timer Lifecycle Tests
# =============================================================================
class TestStartTurnTimer:
"""Tests for TurnTimeoutService.start_turn_timer."""
@pytest.mark.asyncio
async def test_start_turn_timer_creates_redis_entry(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that starting a timer creates the correct Redis hash entry.
Verifies that all required fields are stored: player_id, deadline,
timeout_seconds, warnings_sent, and warning_thresholds.
"""
await turn_timeout_service.start_turn_timer(
game_id="game-123",
player_id="player-1",
timeout_seconds=180,
warning_thresholds=[50, 25],
)
# Verify Redis hset was called with correct data
mock_redis.hset.assert_called_once()
call_args = mock_redis.hset.call_args
assert call_args[0][0] == f"{TURN_TIMEOUT_PREFIX}game-123"
mapping = call_args[1]["mapping"]
assert mapping["player_id"] == "player-1"
assert mapping["timeout_seconds"] == "180"
assert mapping["warnings_sent"] == "[]"
assert json.loads(mapping["warning_thresholds"]) == [50, 25]
# Verify deadline is approximately correct (within 2 seconds)
deadline = float(mapping["deadline"])
expected = datetime.now(UTC).timestamp() + 180
assert abs(deadline - expected) < 2
@pytest.mark.asyncio
async def test_start_turn_timer_returns_info(
self, turn_timeout_service: TurnTimeoutService
) -> None:
"""
Test that starting a timer returns correct TurnTimeoutInfo.
The returned info should have all fields populated correctly
with remaining_seconds equal to timeout_seconds (just started).
"""
result = await turn_timeout_service.start_turn_timer(
game_id="game-123",
player_id="player-1",
timeout_seconds=180,
)
assert isinstance(result, TurnTimeoutInfo)
assert result.game_id == "game-123"
assert result.player_id == "player-1"
assert result.timeout_seconds == 180
assert result.remaining_seconds == 180
assert result.warnings_sent == []
assert not result.is_expired
@pytest.mark.asyncio
async def test_start_turn_timer_default_thresholds(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that default warning thresholds [50, 25] are used when not specified.
This ensures games always have reasonable warning points even if
the caller doesn't specify custom thresholds.
"""
await turn_timeout_service.start_turn_timer(
game_id="game-123",
player_id="player-1",
timeout_seconds=180,
)
mapping = mock_redis.hset.call_args[1]["mapping"]
thresholds = json.loads(mapping["warning_thresholds"])
assert thresholds == [50, 25]
@pytest.mark.asyncio
async def test_start_turn_timer_sorts_thresholds_descending(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that warning thresholds are sorted in descending order.
Warnings should be processed from highest to lowest percentage
so players get warned at 50% before 25%, regardless of input order.
"""
await turn_timeout_service.start_turn_timer(
game_id="game-123",
player_id="player-1",
timeout_seconds=180,
warning_thresholds=[25, 75, 50], # Unsorted input
)
mapping = mock_redis.hset.call_args[1]["mapping"]
thresholds = json.loads(mapping["warning_thresholds"])
assert thresholds == [75, 50, 25] # Should be sorted descending
@pytest.mark.asyncio
async def test_start_turn_timer_sets_ttl(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that the Redis key TTL is set beyond the timeout duration.
The TTL includes a buffer to ensure the key persists long enough
for the expiration to be detected and handled.
"""
await turn_timeout_service.start_turn_timer(
game_id="game-123",
player_id="player-1",
timeout_seconds=180,
)
mock_redis.expire.assert_called_once()
call_args = mock_redis.expire.call_args
assert call_args[0][0] == f"{TURN_TIMEOUT_PREFIX}game-123"
assert call_args[0][1] == 180 + DEFAULT_KEY_TTL_BUFFER
class TestCancelTimer:
"""Tests for TurnTimeoutService.cancel_timer."""
@pytest.mark.asyncio
async def test_cancel_timer_deletes_redis_key(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that canceling a timer deletes the Redis key.
When a turn ends normally, the timer should be cleaned up
to prevent false expiration detection.
"""
result = await turn_timeout_service.cancel_timer("game-123")
mock_redis.delete.assert_called_once_with(f"{TURN_TIMEOUT_PREFIX}game-123")
assert result is True
@pytest.mark.asyncio
async def test_cancel_timer_returns_false_when_not_found(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that canceling a non-existent timer returns False.
This allows callers to know if there was actually a timer
to cancel, useful for debugging and logging.
"""
mock_redis.delete = AsyncMock(return_value=0)
result = await turn_timeout_service.cancel_timer("nonexistent-game")
assert result is False
class TestExtendTimer:
"""Tests for TurnTimeoutService.extend_timer."""
@pytest.mark.asyncio
async def test_extend_timer_increases_deadline(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that extending a timer increases the deadline.
When a player reconnects, they should get additional time
to complete their turn.
"""
now = datetime.now(UTC).timestamp()
original_deadline = now + 60 # 60 seconds remaining
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(original_deadline),
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.extend_timer("game-123", extension_seconds=15)
assert result is not None
# Deadline should be extended by 15 seconds
assert result.deadline > original_deadline
assert result.remaining_seconds >= 60 # At least original remaining
@pytest.mark.asyncio
async def test_extend_timer_capped_at_original_timeout(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that timer extension is capped at the original timeout.
Players shouldn't be able to get more time than the original
timeout by repeatedly reconnecting.
"""
now = datetime.now(UTC).timestamp()
original_deadline = now + 170 # 170 seconds remaining (almost full)
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(original_deadline),
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.extend_timer("game-123", extension_seconds=30)
assert result is not None
# Should be capped at original timeout (180s from now)
max_deadline = now + 180
assert result.deadline <= max_deadline + 1 # Allow 1s tolerance
@pytest.mark.asyncio
async def test_extend_timer_returns_none_when_not_found(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that extending a non-existent timer returns None.
This handles the case where the game ended or the timer
was already canceled.
"""
mock_redis.hgetall = AsyncMock(return_value={})
result = await turn_timeout_service.extend_timer("nonexistent-game", extension_seconds=15)
assert result is None
@pytest.mark.asyncio
async def test_extend_timer_preserves_warnings_sent(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that extending a timer preserves the warnings_sent state.
If a 50% warning was already sent, extending shouldn't cause
it to be sent again.
"""
now = datetime.now(UTC).timestamp()
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(now + 30), # 30 seconds remaining
"timeout_seconds": "180",
"warnings_sent": "[50]", # 50% warning already sent
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.extend_timer("game-123", extension_seconds=15)
assert result is not None
assert 50 in result.warnings_sent
# =============================================================================
# Query Tests
# =============================================================================
class TestGetTimeoutInfo:
"""Tests for TurnTimeoutService.get_timeout_info."""
@pytest.mark.asyncio
async def test_get_timeout_info_returns_correct_data(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that get_timeout_info returns all stored data correctly.
This is the primary method for getting timer state and must
return accurate information for all fields.
"""
now = datetime.now(UTC).timestamp()
deadline = now + 120 # 120 seconds remaining
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(deadline),
"timeout_seconds": "180",
"warnings_sent": "[50]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.get_timeout_info("game-123")
assert result is not None
assert result.game_id == "game-123"
assert result.player_id == "player-1"
assert result.timeout_seconds == 180
assert result.warnings_sent == [50]
assert result.warning_thresholds == [50, 25]
# Allow small variance due to test execution time
assert 118 <= result.remaining_seconds <= 122
@pytest.mark.asyncio
async def test_get_timeout_info_returns_none_when_not_found(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that get_timeout_info returns None for non-existent timers.
This allows callers to check if a game has an active timer.
"""
mock_redis.hgetall = AsyncMock(return_value={})
result = await turn_timeout_service.get_timeout_info("nonexistent-game")
assert result is None
@pytest.mark.asyncio
async def test_get_timeout_info_handles_corrupted_data(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that corrupted Redis data is handled gracefully.
If a key exists but has missing or invalid fields, we should
return None rather than crashing.
"""
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
# Missing deadline and timeout_seconds
}
)
result = await turn_timeout_service.get_timeout_info("game-123")
assert result is None
@pytest.mark.asyncio
async def test_get_timeout_info_expired_timer(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that expired timers are correctly identified.
The is_expired property should be True when the deadline
has passed.
"""
now = datetime.now(UTC).timestamp()
past_deadline = now - 10 # 10 seconds ago
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(past_deadline),
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.get_timeout_info("game-123")
assert result is not None
assert result.is_expired
assert result.remaining_seconds == 0
class TestGetRemainingTime:
"""Tests for TurnTimeoutService.get_remaining_time."""
@pytest.mark.asyncio
async def test_get_remaining_time_returns_seconds(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that get_remaining_time returns the correct seconds.
This convenience method should return just the remaining
time without requiring full TurnTimeoutInfo parsing.
"""
now = datetime.now(UTC).timestamp()
deadline = now + 90
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(deadline),
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.get_remaining_time("game-123")
assert result is not None
assert 88 <= result <= 92 # Allow small variance
@pytest.mark.asyncio
async def test_get_remaining_time_returns_none_when_not_found(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that get_remaining_time returns None for non-existent timers.
"""
mock_redis.hgetall = AsyncMock(return_value={})
result = await turn_timeout_service.get_remaining_time("nonexistent-game")
assert result is None
# =============================================================================
# Warning Tests
# =============================================================================
class TestGetPendingWarning:
"""Tests for TurnTimeoutService.get_pending_warning."""
@pytest.mark.asyncio
async def test_get_pending_warning_at_50_percent(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that a warning is returned when time drops below 50%.
With 180s timeout, 50% = 90s. If remaining is 85s (47%),
we should get a warning for the 50% threshold.
"""
now = datetime.now(UTC).timestamp()
deadline = now + 85 # 85s remaining out of 180s = 47%
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(deadline),
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.get_pending_warning("game-123")
assert result is not None
assert isinstance(result, PendingWarning)
assert result.threshold == 50
assert result.player_id == "player-1"
@pytest.mark.asyncio
async def test_get_pending_warning_skips_already_sent(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that warnings already sent are not returned again.
If 50% warning was sent, and time is now at 20% (below 25%),
we should get the 25% warning, not the 50% warning again.
"""
now = datetime.now(UTC).timestamp()
deadline = now + 36 # 36s remaining out of 180s = 20%
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(deadline),
"timeout_seconds": "180",
"warnings_sent": "[50]", # 50% already sent
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.get_pending_warning("game-123")
assert result is not None
assert result.threshold == 25 # Should be 25%, not 50%
@pytest.mark.asyncio
async def test_get_pending_warning_none_when_above_threshold(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that no warning is returned when time is above all thresholds.
If 60% of time remains and thresholds are [50, 25], no warning
should be pending.
"""
now = datetime.now(UTC).timestamp()
deadline = now + 108 # 108s remaining out of 180s = 60%
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(deadline),
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.get_pending_warning("game-123")
assert result is None
@pytest.mark.asyncio
async def test_get_pending_warning_none_when_all_sent(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that no warning is returned when all warnings have been sent.
Even if time is very low, if all warnings were already sent,
there's nothing pending.
"""
now = datetime.now(UTC).timestamp()
deadline = now + 18 # 18s remaining out of 180s = 10%
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(deadline),
"timeout_seconds": "180",
"warnings_sent": "[50, 25]", # All sent
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.get_pending_warning("game-123")
assert result is None
@pytest.mark.asyncio
async def test_get_pending_warning_none_when_expired(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that no warning is returned for expired timers.
Once a timer expires, warnings are no longer relevant - the
timeout handler takes over.
"""
now = datetime.now(UTC).timestamp()
past_deadline = now - 5 # Expired 5 seconds ago
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(past_deadline),
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.get_pending_warning("game-123")
assert result is None
class TestMarkWarningSent:
"""Tests for TurnTimeoutService.mark_warning_sent."""
@pytest.mark.asyncio
async def test_mark_warning_sent_updates_redis(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that marking a warning updates the warnings_sent list.
After sending a warning, it should be added to the list
to prevent duplicate sends.
"""
now = datetime.now(UTC).timestamp()
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(now + 85),
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.mark_warning_sent("game-123", 50)
assert result is True
# Check that hset was called with updated warnings_sent
call_args = mock_redis.hset.call_args
assert call_args[0][0] == f"{TURN_TIMEOUT_PREFIX}game-123"
assert call_args[0][1] == "warnings_sent"
assert 50 in json.loads(call_args[0][2])
@pytest.mark.asyncio
async def test_mark_warning_sent_returns_false_when_not_found(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that marking a warning returns False if timer doesn't exist.
"""
mock_redis.hgetall = AsyncMock(return_value={})
result = await turn_timeout_service.mark_warning_sent("nonexistent-game", 50)
assert result is False
@pytest.mark.asyncio
async def test_mark_warning_sent_idempotent(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that marking the same warning twice is idempotent.
Re-marking shouldn't cause errors or duplicate entries.
"""
now = datetime.now(UTC).timestamp()
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(now + 85),
"timeout_seconds": "180",
"warnings_sent": "[50]", # Already marked
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.mark_warning_sent("game-123", 50)
assert result is True
# hset should not be called since it's already marked
mock_redis.hset.assert_not_called()
# =============================================================================
# Expiration Tests
# =============================================================================
class TestCheckExpiredTimers:
"""Tests for TurnTimeoutService.check_expired_timers."""
@pytest.mark.asyncio
async def test_check_expired_timers_returns_expired(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that check_expired_timers returns expired timers.
This is the main polling method that background tasks use
to detect timeouts.
"""
now = datetime.now(UTC).timestamp()
# Mock scan_iter to return one game
async def mock_scan_iter(match=None):
yield f"{TURN_TIMEOUT_PREFIX}game-123"
mock_redis.scan_iter = mock_scan_iter
# Mock hgetall to return expired timer
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(now - 10), # Expired 10 seconds ago
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.check_expired_timers()
assert len(result) == 1
assert result[0].game_id == "game-123"
assert result[0].is_expired
@pytest.mark.asyncio
async def test_check_expired_timers_excludes_active(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that check_expired_timers excludes active (non-expired) timers.
Only expired timers should be returned; active games should
continue uninterrupted.
"""
now = datetime.now(UTC).timestamp()
async def mock_scan_iter(match=None):
yield f"{TURN_TIMEOUT_PREFIX}game-123"
mock_redis.scan_iter = mock_scan_iter
# Mock hgetall to return active timer
mock_redis.hgetall = AsyncMock(
return_value={
"player_id": "player-1",
"deadline": str(now + 60), # 60 seconds remaining
"timeout_seconds": "180",
"warnings_sent": "[]",
"warning_thresholds": "[50, 25]",
}
)
result = await turn_timeout_service.check_expired_timers()
assert len(result) == 0
@pytest.mark.asyncio
async def test_check_expired_timers_empty_when_no_timers(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that check_expired_timers returns empty list when no timers exist.
"""
async def mock_scan_iter(match=None):
return
yield # Empty async generator
mock_redis.scan_iter = mock_scan_iter
result = await turn_timeout_service.check_expired_timers()
assert result == []
# =============================================================================
# TurnTimeoutInfo Tests
# =============================================================================
class TestTurnTimeoutInfo:
"""Tests for TurnTimeoutInfo dataclass properties."""
def test_is_expired_true_when_remaining_zero(self) -> None:
"""
Test that is_expired returns True when remaining_seconds is 0.
"""
info = TurnTimeoutInfo(
game_id="game-123",
player_id="player-1",
deadline=0,
timeout_seconds=180,
remaining_seconds=0,
warnings_sent=[],
warning_thresholds=[50, 25],
)
assert info.is_expired is True
def test_is_expired_false_when_time_remaining(self) -> None:
"""
Test that is_expired returns False when time remains.
"""
info = TurnTimeoutInfo(
game_id="game-123",
player_id="player-1",
deadline=0,
timeout_seconds=180,
remaining_seconds=60,
warnings_sent=[],
warning_thresholds=[50, 25],
)
assert info.is_expired is False
def test_percent_remaining_calculation(self) -> None:
"""
Test that percent_remaining calculates correctly.
90 seconds remaining out of 180 should be 50%.
"""
info = TurnTimeoutInfo(
game_id="game-123",
player_id="player-1",
deadline=0,
timeout_seconds=180,
remaining_seconds=90,
warnings_sent=[],
warning_thresholds=[50, 25],
)
assert info.percent_remaining == 50.0
def test_percent_remaining_zero_timeout(self) -> None:
"""
Test that percent_remaining handles zero timeout gracefully.
Prevents division by zero errors.
"""
info = TurnTimeoutInfo(
game_id="game-123",
player_id="player-1",
deadline=0,
timeout_seconds=0,
remaining_seconds=0,
warnings_sent=[],
warning_thresholds=[50, 25],
)
assert info.percent_remaining == 0.0
# =============================================================================
# Utility Method Tests
# =============================================================================
class TestUtilityMethods:
"""Tests for utility methods."""
@pytest.mark.asyncio
async def test_has_active_timer_true(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that has_active_timer returns True when timer exists.
"""
mock_redis.exists = AsyncMock(return_value=1)
result = await turn_timeout_service.has_active_timer("game-123")
assert result is True
@pytest.mark.asyncio
async def test_has_active_timer_false(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that has_active_timer returns False when no timer exists.
"""
mock_redis.exists = AsyncMock(return_value=0)
result = await turn_timeout_service.has_active_timer("nonexistent-game")
assert result is False
@pytest.mark.asyncio
async def test_get_active_timer_count(
self, turn_timeout_service: TurnTimeoutService, mock_redis: AsyncMock
) -> None:
"""
Test that get_active_timer_count returns correct count.
Useful for monitoring active games with turn timers.
"""
async def mock_scan_iter(match=None):
yield f"{TURN_TIMEOUT_PREFIX}game-1"
yield f"{TURN_TIMEOUT_PREFIX}game-2"
yield f"{TURN_TIMEOUT_PREFIX}game-3"
mock_redis.scan_iter = mock_scan_iter
result = await turn_timeout_service.get_active_timer_count()
assert result == 3

View File

@ -57,9 +57,25 @@ def mock_connection_manager() -> AsyncMock:
cm.get_connection = AsyncMock(return_value=None)
cm.get_game_user_sids = AsyncMock(return_value={})
cm.get_opponent_sid = AsyncMock(return_value=None)
cm.get_user_active_game = AsyncMock(return_value=None)
return cm
@pytest.fixture
def mock_state_manager() -> AsyncMock:
"""Create a mock GameStateManager.
The GameStateManager handles persistence to Redis/Postgres
and provides methods to look up active games.
"""
sm = AsyncMock()
sm.get_player_active_games = AsyncMock(return_value=[])
sm.load_state = AsyncMock(return_value=None)
sm.save_to_cache = AsyncMock()
sm.persist_to_db = AsyncMock()
return sm
@pytest.fixture
def mock_sio() -> AsyncMock:
"""Create a mock Socket.IO AsyncServer.
@ -77,11 +93,13 @@ def mock_sio() -> AsyncMock:
def handler(
mock_game_service: AsyncMock,
mock_connection_manager: AsyncMock,
mock_state_manager: AsyncMock,
) -> GameNamespaceHandler:
"""Create a GameNamespaceHandler with injected mock dependencies."""
return GameNamespaceHandler(
game_svc=mock_game_service,
conn_manager=mock_connection_manager,
state_manager=mock_state_manager,
)
@ -931,3 +949,347 @@ class TestErrorResponse:
assert result["error"]["code"] == WSErrorCode.GAME_NOT_FOUND.value
assert result["error"]["message"] == "Game not found"
assert result["request_message_id"] == "msg-123"
class TestHandleReconnect:
"""Tests for the handle_reconnect event handler.
The reconnection handler is called after successful authentication
to check if the user has an active game and auto-rejoin them.
"""
@pytest.fixture
def mock_active_game(self) -> AsyncMock:
"""Create a mock ActiveGame record.
ActiveGame represents a game in progress stored in Postgres.
"""
from datetime import UTC, datetime
from uuid import UUID
game = AsyncMock()
game.id = UUID("12345678-1234-5678-1234-567812345678")
game.started_at = datetime.now(UTC)
game.last_action_at = datetime.now(UTC)
return game
@pytest.mark.asyncio
async def test_reconnect_no_active_games(
self,
handler: GameNamespaceHandler,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
) -> None:
"""Test that reconnect returns None when user has no active games.
If the user isn't in any games, there's nothing to reconnect to.
"""
mock_state_manager.get_player_active_games.return_value = []
result = await handler.handle_reconnect(
mock_sio,
"sid-123",
"12345678-1234-5678-1234-567812345678",
)
assert result is None
mock_state_manager.get_player_active_games.assert_called_once()
@pytest.mark.asyncio
async def test_reconnect_invalid_uuid_returns_none(
self,
handler: GameNamespaceHandler,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
) -> None:
"""Test that invalid user_id (non-UUID) returns None.
NPC IDs are not valid UUIDs and shouldn't have active games.
"""
result = await handler.handle_reconnect(
mock_sio,
"sid-123",
"npc-grass-trainer-1", # Not a valid UUID
)
assert result is None
# Should not have queried for active games
mock_state_manager.get_player_active_games.assert_not_called()
@pytest.mark.asyncio
async def test_reconnect_success(
self,
handler: GameNamespaceHandler,
mock_game_service: AsyncMock,
mock_connection_manager: AsyncMock,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
mock_active_game: AsyncMock,
sample_visible_state: VisibleGameState,
) -> None:
"""Test successful reconnection to an active game.
When a user connects with an active game, they should be
automatically rejoined to that game.
"""
mock_state_manager.get_player_active_games.return_value = [mock_active_game]
mock_game_service.join_game.return_value = GameJoinResult(
success=True,
game_id=str(mock_active_game.id),
player_id="12345678-1234-5678-1234-567812345678",
visible_state=sample_visible_state,
is_your_turn=True,
game_over=False,
)
result = await handler.handle_reconnect(
mock_sio,
"sid-123",
"12345678-1234-5678-1234-567812345678",
)
assert result is not None
assert result["game_id"] == str(mock_active_game.id)
assert result["is_your_turn"] is True
assert "state" in result
# Should have joined connection manager
mock_connection_manager.join_game.assert_called_once_with(
"sid-123", str(mock_active_game.id)
)
# Should have entered the room
mock_sio.enter_room.assert_called_once()
@pytest.mark.asyncio
async def test_reconnect_notifies_opponent(
self,
handler: GameNamespaceHandler,
mock_game_service: AsyncMock,
mock_connection_manager: AsyncMock,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
mock_active_game: AsyncMock,
sample_visible_state: VisibleGameState,
) -> None:
"""Test that reconnection notifies the opponent.
When a player reconnects, their opponent should be notified
that they're back online.
"""
mock_state_manager.get_player_active_games.return_value = [mock_active_game]
mock_game_service.join_game.return_value = GameJoinResult(
success=True,
game_id=str(mock_active_game.id),
player_id="12345678-1234-5678-1234-567812345678",
visible_state=sample_visible_state,
is_your_turn=True,
)
mock_connection_manager.get_opponent_sid.return_value = "opponent-sid"
await handler.handle_reconnect(
mock_sio,
"sid-123",
"12345678-1234-5678-1234-567812345678",
)
# Should have tried to notify opponent
mock_connection_manager.get_opponent_sid.assert_called()
# Should have emitted opponent status
emit_calls = [
call for call in mock_sio.emit.call_args_list if call[0][0] == "game:opponent_status"
]
assert len(emit_calls) == 1
@pytest.mark.asyncio
async def test_reconnect_includes_pending_forced_action(
self,
handler: GameNamespaceHandler,
mock_game_service: AsyncMock,
mock_connection_manager: AsyncMock,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
mock_active_game: AsyncMock,
sample_visible_state: VisibleGameState,
) -> None:
"""Test that reconnect includes pending forced action.
If the game has a pending forced action when the player
reconnects, it should be included in the response.
"""
mock_state_manager.get_player_active_games.return_value = [mock_active_game]
mock_game_service.join_game.return_value = GameJoinResult(
success=True,
game_id=str(mock_active_game.id),
player_id="12345678-1234-5678-1234-567812345678",
visible_state=sample_visible_state,
is_your_turn=True,
pending_forced_action=PendingForcedAction(
player_id="12345678-1234-5678-1234-567812345678",
action_type="select_active",
reason="Your active Pokemon was knocked out.",
params={"available_bench_ids": ["bench-1"]},
),
)
result = await handler.handle_reconnect(
mock_sio,
"sid-123",
"12345678-1234-5678-1234-567812345678",
)
assert result is not None
assert "pending_forced_action" in result
assert result["pending_forced_action"]["action_type"] == "select_active"
@pytest.mark.asyncio
async def test_reconnect_includes_turn_timer(
self,
handler: GameNamespaceHandler,
mock_game_service: AsyncMock,
mock_connection_manager: AsyncMock,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
mock_active_game: AsyncMock,
sample_visible_state: VisibleGameState,
) -> None:
"""Test that reconnect includes turn timer information.
If turn timers are enabled, the remaining time should be
included in the reconnection response.
"""
mock_state_manager.get_player_active_games.return_value = [mock_active_game]
mock_game_service.join_game.return_value = GameJoinResult(
success=True,
game_id=str(mock_active_game.id),
player_id="12345678-1234-5678-1234-567812345678",
visible_state=sample_visible_state,
is_your_turn=True,
turn_timeout_seconds=120,
turn_deadline=1700000000.0,
)
result = await handler.handle_reconnect(
mock_sio,
"sid-123",
"12345678-1234-5678-1234-567812345678",
)
assert result is not None
assert result["turn_timeout_seconds"] == 120
assert result["turn_deadline"] == 1700000000.0
@pytest.mark.asyncio
async def test_reconnect_join_game_fails(
self,
handler: GameNamespaceHandler,
mock_game_service: AsyncMock,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
mock_active_game: AsyncMock,
) -> None:
"""Test that reconnect returns None when join_game fails.
If the GameService fails to join the game (e.g., game was archived
between lookup and join), we should return None gracefully.
"""
mock_state_manager.get_player_active_games.return_value = [mock_active_game]
mock_game_service.join_game.return_value = GameJoinResult(
success=False,
game_id=str(mock_active_game.id),
player_id="12345678-1234-5678-1234-567812345678",
message="Game not found",
)
result = await handler.handle_reconnect(
mock_sio,
"sid-123",
"12345678-1234-5678-1234-567812345678",
)
assert result is None
@pytest.mark.asyncio
async def test_reconnect_multiple_games_uses_most_recent(
self,
handler: GameNamespaceHandler,
mock_game_service: AsyncMock,
mock_connection_manager: AsyncMock,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
sample_visible_state: VisibleGameState,
) -> None:
"""Test that with multiple active games, the most recent is used.
If a player has multiple active games (edge case), we should
reconnect to the one with the most recent activity.
"""
from datetime import UTC, datetime, timedelta
from uuid import UUID
older_game = AsyncMock()
older_game.id = UUID("11111111-1111-1111-1111-111111111111")
older_game.started_at = datetime.now(UTC) - timedelta(hours=2)
older_game.last_action_at = datetime.now(UTC) - timedelta(hours=1)
newer_game = AsyncMock()
newer_game.id = UUID("22222222-2222-2222-2222-222222222222")
newer_game.started_at = datetime.now(UTC) - timedelta(hours=1)
newer_game.last_action_at = datetime.now(UTC) - timedelta(minutes=5)
# Return older game first to verify sorting works
mock_state_manager.get_player_active_games.return_value = [older_game, newer_game]
mock_game_service.join_game.return_value = GameJoinResult(
success=True,
game_id=str(newer_game.id),
player_id="12345678-1234-5678-1234-567812345678",
visible_state=sample_visible_state,
is_your_turn=True,
)
result = await handler.handle_reconnect(
mock_sio,
"sid-123",
"12345678-1234-5678-1234-567812345678",
)
assert result is not None
# Should have joined the newer game
assert result["game_id"] == str(newer_game.id)
@pytest.mark.asyncio
async def test_reconnect_to_ended_game(
self,
handler: GameNamespaceHandler,
mock_game_service: AsyncMock,
mock_connection_manager: AsyncMock,
mock_state_manager: AsyncMock,
mock_sio: AsyncMock,
mock_active_game: AsyncMock,
sample_visible_state: VisibleGameState,
) -> None:
"""Test reconnection to a game that has ended.
If the game ended while the player was disconnected, they
should still see the final state with game_over=True.
"""
mock_state_manager.get_player_active_games.return_value = [mock_active_game]
mock_game_service.join_game.return_value = GameJoinResult(
success=True,
game_id=str(mock_active_game.id),
player_id="12345678-1234-5678-1234-567812345678",
visible_state=sample_visible_state,
is_your_turn=False,
game_over=True,
message="Game has ended",
)
result = await handler.handle_reconnect(
mock_sio,
"sid-123",
"12345678-1234-5678-1234-567812345678",
)
assert result is not None
assert result["game_over"] is True