"""WebSocket connection management with Redis-backed session tracking.

This module manages WebSocket connection state using Redis for persistence
and cross-instance coordination. It tracks:

- Individual connection sessions (sid -> user_id, game_id, timestamps)
- User to connection mapping (user_id -> sid)
- Game participants (game_id -> set of sids)

Key Patterns:
    conn:{sid}           - Hash with connection details
                           (user_id, game_id, connected_at, last_seen)
    user_conn:{user_id}  - String with active sid
    game_conns:{game_id} - Set of sids for game participants

All three key types are given a TTL (``conn_ttl_seconds``) so that records
left behind by connections that never disconnect cleanly eventually expire.
Each heartbeat refreshes the TTL on every key the live connection belongs to.

Lifecycle:
    1. on_connect: Register connection, map user to sid
    2. on_join_game: Add sid to game's connection set
    3. on_heartbeat: Update last_seen timestamp and refresh TTLs
    4. on_disconnect: Remove from all tracking structures

Example:
    manager = ConnectionManager()

    # On WebSocket connect
    await manager.register_connection(sid, user_id)

    # On joining a game
    await manager.join_game(sid, game_id)

    # Check if opponent is online
    if await manager.is_user_online(opponent_id):
        ...  # Send real-time update

    # On disconnect
    await manager.unregister_connection(sid)
"""

import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID

from app.db.redis import get_redis

logger = logging.getLogger(__name__)

# Redis key patterns
CONN_PREFIX = "conn:"
USER_CONN_PREFIX = "user_conn:"
GAME_CONNS_PREFIX = "game_conns:"

# Connection TTL (auto-expire stale connections)
DEFAULT_CONN_TTL_SECONDS = 3600  # 1 hour
HEARTBEAT_INTERVAL_SECONDS = 30


@dataclass
class ConnectionInfo:
    """Information about a WebSocket connection.

    Attributes:
        sid: Socket.IO session ID.
        user_id: Authenticated user's ID, stored as a string (``str(UUID)``).
        game_id: Current game ID if in a game, None otherwise.
        connected_at: When the connection was established.
        last_seen: Last heartbeat or activity timestamp.
    """

    sid: str
    user_id: str
    game_id: str | None
    connected_at: datetime
    last_seen: datetime

    def is_stale(self, threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3) -> bool:
        """Check if connection is stale (no recent activity).

        ``last_seen`` must be timezone-aware: it is compared against
        ``datetime.now(UTC)``, and subtracting a naive datetime raises TypeError.

        Args:
            threshold_seconds: Seconds since last_seen to consider stale.
                Defaults to three missed heartbeat intervals.

        Returns:
            True if connection hasn't been seen within the threshold.
        """
        elapsed = (datetime.now(UTC) - self.last_seen).total_seconds()
        return elapsed > threshold_seconds


class ConnectionManager:
    """Manages WebSocket connections with Redis-backed session tracking.

    Provides methods to:
    - Register/unregister connections
    - Track which game a connection is participating in
    - Check user online status
    - Get all connections for a game
    - Handle heartbeats and stale connection detection

    Note: individual operations issue several separate Redis commands; they
    are not transactional, so concurrent callers may observe intermediate
    states.

    Attributes:
        conn_ttl_seconds: TTL for connection records in Redis.
    """

    def __init__(self, conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS) -> None:
        """Initialize the ConnectionManager.

        Args:
            conn_ttl_seconds: How long to keep connection records in Redis.
                Applied to the connection hash, the user mapping, and the
                game participant set.
        """
        self.conn_ttl_seconds = conn_ttl_seconds

    def _conn_key(self, sid: str) -> str:
        """Generate Redis key for a connection."""
        return f"{CONN_PREFIX}{sid}"

    def _user_conn_key(self, user_id: str) -> str:
        """Generate Redis key for user-to-connection mapping."""
        return f"{USER_CONN_PREFIX}{user_id}"

    def _game_conns_key(self, game_id: str) -> str:
        """Generate Redis key for game connection set."""
        return f"{GAME_CONNS_PREFIX}{game_id}"

    # =========================================================================
    # Connection Lifecycle
    # =========================================================================

    async def register_connection(
        self,
        sid: str,
        user_id: str | UUID,
    ) -> None:
        """Register a new WebSocket connection.

        Creates connection record in Redis and maps user to this connection.
        If user had a previous connection under a different sid, that
        connection's Redis data is cleaned up and replaced.

        Args:
            sid: Socket.IO session ID.
            user_id: Authenticated user's ID (UUID or string).

        Example:
            await manager.register_connection("abc123", user.id)
        """
        user_id_str = str(user_id)
        now = datetime.now(UTC).isoformat()

        async with get_redis() as redis:
            # Check for existing connection and clean it up
            old_sid = await redis.get(self._user_conn_key(user_id_str))
            if old_sid and old_sid != sid:
                logger.info(f"Replacing old connection {old_sid} for user {user_id_str}")
                # NOTE(review): _cleanup_connection opens its own get_redis()
                # context while this one is held; presumably get_redis hands
                # out pooled connections, so nesting is safe — confirm.
                await self._cleanup_connection(old_sid)

            # Create connection record (hash)
            conn_key = self._conn_key(sid)
            await redis.hset(
                conn_key,
                mapping={
                    "user_id": user_id_str,
                    "game_id": "",
                    "connected_at": now,
                    "last_seen": now,
                },
            )
            await redis.expire(conn_key, self.conn_ttl_seconds)

            # Map user to this connection
            user_conn_key = self._user_conn_key(user_id_str)
            await redis.set(user_conn_key, sid)
            await redis.expire(user_conn_key, self.conn_ttl_seconds)

        logger.debug(f"Registered connection: sid={sid}, user_id={user_id_str}")

    async def unregister_connection(self, sid: str) -> ConnectionInfo | None:
        """Unregister a WebSocket connection and clean up all related data.

        Removes connection from:
        - Connection record
        - User-to-connection mapping (only if it still points at this sid)
        - Any game connection sets

        Args:
            sid: Socket.IO session ID.

        Returns:
            ConnectionInfo of the removed connection, or None if not found.

        Example:
            info = await manager.unregister_connection("abc123")
            if info and info.game_id:
                ...  # Notify opponent of disconnect
        """
        # Get connection info before cleanup
        info = await self.get_connection(sid)
        if info is None:
            logger.debug(f"Connection not found for unregister: {sid}")
            return None

        await self._cleanup_connection(sid)

        logger.debug(f"Unregistered connection: sid={sid}, user_id={info.user_id}")
        return info

    async def _cleanup_connection(self, sid: str) -> None:
        """Internal cleanup of a connection's Redis data.

        Does nothing if the connection record no longer exists.

        Args:
            sid: Socket.IO session ID.
        """
        async with get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Get connection data for cleanup
            data = await redis.hgetall(conn_key)
            if not data:
                return

            user_id = data.get("user_id", "")
            game_id = data.get("game_id", "")

            # Remove from game connection set if in a game
            if game_id:
                game_conns_key = self._game_conns_key(game_id)
                await redis.srem(game_conns_key, sid)

            # Remove user-to-connection mapping only if it points to this sid,
            # so a newer connection for the same user is left untouched.
            if user_id:
                user_conn_key = self._user_conn_key(user_id)
                current_sid = await redis.get(user_conn_key)
                if current_sid == sid:
                    await redis.delete(user_conn_key)

            # Delete connection record
            await redis.delete(conn_key)

    # =========================================================================
    # Game Association
    # =========================================================================

    async def join_game(self, sid: str, game_id: str) -> bool:
        """Associate a connection with a game.

        Updates the connection's game_id and adds it to the game's connection
        set. If connection was in a different game, leaves that game first.

        Args:
            sid: Socket.IO session ID.
            game_id: Game to join.

        Returns:
            True if successful, False if connection not found.

        Example:
            await manager.join_game("abc123", "game-456")
        """
        async with get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Check connection exists
            exists = await redis.exists(conn_key)
            if not exists:
                logger.warning(f"Cannot join game: connection not found {sid}")
                return False

            # Leave current game if in one
            current_game = await redis.hget(conn_key, "game_id")
            if current_game and current_game != game_id:
                old_game_conns_key = self._game_conns_key(current_game)
                await redis.srem(old_game_conns_key, sid)
                logger.debug(f"Connection {sid} left game {current_game}")

            # Update connection's game_id
            await redis.hset(conn_key, "game_id", game_id)

            # Add to game's connection set
            game_conns_key = self._game_conns_key(game_id)
            await redis.sadd(game_conns_key, sid)

            # Set TTL on game conns set (cleanup if game becomes inactive).
            # update_heartbeat keeps refreshing it while participants are live.
            await redis.expire(game_conns_key, self.conn_ttl_seconds)

        logger.debug(f"Connection {sid} joined game {game_id}")
        return True

    async def leave_game(self, sid: str) -> str | None:
        """Remove a connection from its current game.

        Args:
            sid: Socket.IO session ID.

        Returns:
            The game_id that was left, or None if not in a game (or if the
            connection does not exist).

        Example:
            game_id = await manager.leave_game("abc123")
        """
        async with get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Get current game
            game_id = await redis.hget(conn_key, "game_id")
            if not game_id:
                return None

            # Remove from game's connection set
            game_conns_key = self._game_conns_key(game_id)
            await redis.srem(game_conns_key, sid)

            # Clear game_id on connection
            await redis.hset(conn_key, "game_id", "")

        logger.debug(f"Connection {sid} left game {game_id}")
        return game_id

    # =========================================================================
    # Heartbeat / Activity
    # =========================================================================

    async def update_heartbeat(self, sid: str) -> bool:
        """Update the last_seen timestamp for a connection.

        Should be called on each heartbeat message from the client. Also
        refreshes the TTL on the connection record, the user-to-connection
        mapping, and — if the connection is in a game — the game's
        participant set, so long-running games do not lose their membership
        set while players are still connected.

        Args:
            sid: Socket.IO session ID.

        Returns:
            True if successful, False if connection not found.

        Example:
            await manager.update_heartbeat("abc123")
        """
        now = datetime.now(UTC).isoformat()

        async with get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Existence check and the writes below are separate round trips,
            # not an atomic operation.
            exists = await redis.exists(conn_key)
            if not exists:
                return False

            # Update last_seen
            await redis.hset(conn_key, "last_seen", now)

            # Refresh TTL
            await redis.expire(conn_key, self.conn_ttl_seconds)

            # Also refresh user mapping TTL
            user_id = await redis.hget(conn_key, "user_id")
            if user_id:
                user_conn_key = self._user_conn_key(user_id)
                await redis.expire(user_conn_key, self.conn_ttl_seconds)

            # Refresh the game participant set TTL as well; otherwise it
            # expires conn_ttl_seconds after the last join_game even though
            # participants are still heartbeating. EXPIRE is a no-op on a
            # missing key, so this never re-creates membership on its own.
            game_id = await redis.hget(conn_key, "game_id")
            if game_id:
                game_conns_key = self._game_conns_key(game_id)
                await redis.expire(game_conns_key, self.conn_ttl_seconds)

        return True

    # =========================================================================
    # Query Methods
    # =========================================================================

    async def get_connection(self, sid: str) -> ConnectionInfo | None:
        """Get connection info by session ID.

        Args:
            sid: Socket.IO session ID.

        Returns:
            ConnectionInfo if found, None otherwise.

        Example:
            info = await manager.get_connection("abc123")
            if info:
                print(f"User {info.user_id} connected at {info.connected_at}")
        """
        async with get_redis() as redis:
            conn_key = self._conn_key(sid)
            data = await redis.hgetall(conn_key)

            if not data:
                return None

            return ConnectionInfo(
                sid=sid,
                user_id=data.get("user_id", ""),
                # Empty string in Redis means "not in a game".
                game_id=data.get("game_id") or None,
                connected_at=datetime.fromisoformat(data["connected_at"]),
                last_seen=datetime.fromisoformat(data["last_seen"]),
            )

    async def get_user_connection(self, user_id: str | UUID) -> ConnectionInfo | None:
        """Get connection info for a user.

        Args:
            user_id: User's UUID or string ID.

        Returns:
            ConnectionInfo if user is connected, None otherwise.

        Example:
            info = await manager.get_user_connection(opponent_id)
            if info:
                ...  # User is online
        """
        user_id_str = str(user_id)

        async with get_redis() as redis:
            user_conn_key = self._user_conn_key(user_id_str)
            sid = await redis.get(user_conn_key)

            if not sid:
                return None

            return await self.get_connection(sid)

    async def is_user_online(self, user_id: str | UUID) -> bool:
        """Check if a user has an active, non-stale connection.

        Args:
            user_id: User's UUID or string ID.

        Returns:
            True if user is connected and has been seen recently,
            False otherwise.

        Example:
            if await manager.is_user_online(opponent_id):
                ...  # Send real-time notification
        """
        info = await self.get_user_connection(user_id)
        if info is None:
            return False

        # Check if connection is stale
        if info.is_stale():
            logger.debug(f"User {user_id} connection is stale")
            return False

        return True

    async def get_game_connections(self, game_id: str) -> list[ConnectionInfo]:
        """Get all connections for a game.

        Sids in the game's set whose connection record no longer exists
        (e.g. it expired via TTL without an explicit unregister) are skipped
        and pruned from the set.

        Args:
            game_id: Game ID.

        Returns:
            List of ConnectionInfo for all connected participants.

        Example:
            connections = await manager.get_game_connections("game-456")
            for conn in connections:
                ...  # Send state update to each participant
        """
        async with get_redis() as redis:
            game_conns_key = self._game_conns_key(game_id)
            sids = await redis.smembers(game_conns_key)

            connections: list[ConnectionInfo] = []
            dead_sids: list[str] = []
            for sid in sids:
                info = await self.get_connection(sid)
                if info is None:
                    dead_sids.append(sid)
                else:
                    connections.append(info)

            # Drop members whose connection record is gone so the set (and
            # get_game_connection_count) does not accumulate dead entries now
            # that heartbeats keep the set's TTL alive.
            if dead_sids:
                await redis.srem(game_conns_key, *dead_sids)
                logger.debug(
                    f"Pruned {len(dead_sids)} dead connection(s) from game {game_id}"
                )

        return connections

    async def get_game_user_sids(self, game_id: str) -> dict[str, str]:
        """Get mapping of user_id to sid for a game.

        Useful for targeted message delivery to specific players. If a user
        somehow has several sids in the game, the last one iterated wins.

        Args:
            game_id: Game ID.

        Returns:
            Dict mapping user_id to their sid.

        Example:
            user_sids = await manager.get_game_user_sids("game-456")
            player1_sid = user_sids.get(player1_id)
        """
        connections = await self.get_game_connections(game_id)
        return {conn.user_id: conn.sid for conn in connections}

    async def get_opponent_sid(
        self,
        game_id: str,
        current_user_id: str | UUID,
    ) -> str | None:
        """Get the sid of the opponent in a 2-player game.

        Returns the first connection in the game whose user differs from
        ``current_user_id``.

        Args:
            game_id: Game ID.
            current_user_id: The current user's ID.

        Returns:
            Opponent's sid if found and connected, None otherwise.

        Example:
            opponent_sid = await manager.get_opponent_sid("game-456", user_id)
            if opponent_sid:
                ...  # Send message to opponent
        """
        current_user_str = str(current_user_id)
        connections = await self.get_game_connections(game_id)

        for conn in connections:
            if conn.user_id != current_user_str:
                return conn.sid

        return None

    # =========================================================================
    # Maintenance
    # =========================================================================

    async def cleanup_stale_connections(
        self,
        threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3,
    ) -> int:
        """Clean up stale connections that haven't sent heartbeats.

        Should be called periodically by a background task. Scans all
        ``conn:*`` keys and removes those whose last_seen is older than the
        threshold.

        Args:
            threshold_seconds: Seconds since last_seen to consider stale.

        Returns:
            Number of connections cleaned up.

        Example:
            # In background task
            count = await manager.cleanup_stale_connections()
            logger.info(f"Cleaned up {count} stale connections")
        """
        count = 0

        async with get_redis() as redis:
            # Scan for all connection keys. Assumes the client decodes
            # responses to str (the rest of the module compares against str).
            async for key in redis.scan_iter(match=f"{CONN_PREFIX}*"):
                sid = key[len(CONN_PREFIX) :]
                info = await self.get_connection(sid)

                if info and info.is_stale(threshold_seconds):
                    await self._cleanup_connection(sid)
                    count += 1
                    logger.debug(f"Cleaned up stale connection: {sid}")

        if count > 0:
            logger.info(f"Cleaned up {count} stale connections")

        return count

    async def get_connection_count(self) -> int:
        """Get the total number of connection records in Redis.

        Useful for monitoring and admin dashboards. Counts every ``conn:*``
        key, including connections that are stale but not yet cleaned up.

        Returns:
            Number of active connections.
        """
        count = 0

        async with get_redis() as redis:
            async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"):
                count += 1

        return count

    async def get_game_connection_count(self, game_id: str) -> int:
        """Get the number of connections for a specific game.

        This is a cheap SCARD on the game's set. It may briefly include sids
        whose connection records have expired until get_game_connections
        prunes them.

        Args:
            game_id: Game ID.

        Returns:
            Number of connected participants.
        """
        async with get_redis() as redis:
            game_conns_key = self._game_conns_key(game_id)
            return await redis.scard(game_conns_key)


# Global singleton instance
connection_manager = ConnectionManager()