""" WebSocket Connection Manager Manages WebSocket connections, session tracking, and room broadcasting. Includes session expiration for cleaning up zombie connections. """ import logging from dataclasses import dataclass, field from uuid import UUID import pendulum import socketio from pendulum import DateTime from app.config import get_settings logger = logging.getLogger(f"{__name__}.ConnectionManager") @dataclass class SessionInfo: """ Tracks metadata for a WebSocket session. Used for: - Identifying session owner (user_id) - Tracking session lifetime (connected_at) - Detecting zombie connections (last_activity) - Managing game room membership (games) """ user_id: str | None connected_at: DateTime last_activity: DateTime games: set[str] = field(default_factory=set) ip_address: str | None = None def inactive_seconds(self) -> float: """Return seconds since last activity.""" return (pendulum.now("UTC") - self.last_activity).total_seconds() class ConnectionManager: """ Manages WebSocket connections and rooms. Features: - Session lifecycle management (connect/disconnect) - Activity tracking for zombie detection - Game room management (join/leave/broadcast) - Session expiration for cleanup - Connection statistics for health monitoring """ def __init__(self, sio: socketio.AsyncServer): self.sio = sio self._sessions: dict[str, SessionInfo] = {} # sid -> SessionInfo self._user_sessions: dict[str, set[str]] = {} # user_id -> set of sids self.game_rooms: dict[str, set[str]] = {} # game_id -> set of sids @property def user_sessions(self) -> dict[str, str | None]: """ Backward-compatible property: returns sid -> user_id mapping. Used by existing tests and code that expects simple session tracking. """ return {sid: info.user_id for sid, info in self._sessions.items()} async def connect( self, sid: str, user_id: str, ip_address: str | None = None ) -> None: """ Register a new connection with session tracking. Args: sid: Socket.io session ID user_id: Authenticated user ID ip_address: Client IP address (optional, for logging) """ now = pendulum.now("UTC") self._sessions[sid] = SessionInfo( user_id=user_id, connected_at=now, last_activity=now, games=set(), ip_address=ip_address, ) # Track user's multiple sessions (e.g., multiple browser tabs) if user_id not in self._user_sessions: self._user_sessions[user_id] = set() self._user_sessions[user_id].add(sid) logger.info(f"User {user_id} connected with session {sid} from {ip_address}") async def disconnect(self, sid: str) -> None: """ Handle disconnection and cleanup session. Removes session from all tracking structures and game rooms. """ session = self._sessions.pop(sid, None) if session: # Remove from user tracking if session.user_id and session.user_id in self._user_sessions: self._user_sessions[session.user_id].discard(sid) if not self._user_sessions[session.user_id]: del self._user_sessions[session.user_id] # Remove from all game rooms for game_id in list(session.games): if game_id in self.game_rooms: self.game_rooms[game_id].discard(sid) await self.broadcast_to_game( game_id, "user_disconnected", {"user_id": session.user_id} ) duration = (pendulum.now("UTC") - session.connected_at).total_seconds() logger.info( f"User {session.user_id} disconnected (session {sid}, " f"duration: {duration:.0f}s)" ) else: logger.debug(f"Unknown session {sid} disconnected") async def update_activity(self, sid: str) -> None: """ Update last activity timestamp for session. Call this on any meaningful user action to prevent the session from being marked as zombie. """ if sid in self._sessions: self._sessions[sid].last_activity = pendulum.now("UTC") def get_session(self, sid: str) -> SessionInfo | None: """Get session info for a connection.""" return self._sessions.get(sid) def get_user_id(self, sid: str) -> str | None: """Get user ID for a session (convenience method).""" session = self._sessions.get(sid) return session.user_id if session else None async def join_game(self, sid: str, game_id: str, role: str) -> None: """ Add user to game room. Args: sid: Socket.io session ID game_id: Game UUID as string role: User role in game (player, spectator) """ await self.sio.enter_room(sid, game_id) if game_id not in self.game_rooms: self.game_rooms[game_id] = set() self.game_rooms[game_id].add(sid) # Update session's game tracking if sid in self._sessions: self._sessions[sid].games.add(game_id) user_id = self.get_user_id(sid) logger.info(f"User {user_id} joined game {game_id} as {role}") await self.broadcast_to_game( game_id, "user_connected", {"user_id": user_id, "role": role} ) # Update activity await self.update_activity(sid) async def leave_game(self, sid: str, game_id: str) -> None: """Remove user from game room.""" await self.sio.leave_room(sid, game_id) if game_id in self.game_rooms: self.game_rooms[game_id].discard(sid) # Update session's game tracking if sid in self._sessions: self._sessions[sid].games.discard(game_id) user_id = self.get_user_id(sid) logger.info(f"User {user_id} left game {game_id}") async def broadcast_to_game(self, game_id: str, event: str, data: dict) -> None: """Broadcast event to all users in game room.""" await self.sio.emit(event, data, room=game_id) logger.debug(f"Broadcast {event} to game {game_id}") async def emit_to_user(self, sid: str, event: str, data: dict) -> None: """Emit event to specific user.""" await self.sio.emit(event, data, room=sid) def get_game_participants(self, game_id: str) -> set[str]: """Get all session IDs in game room.""" return self.game_rooms.get(game_id, set()) async def expire_inactive_sessions(self, timeout_seconds: int | None = None) -> list[str]: """ Expire sessions with no activity beyond timeout. This is called periodically by a background task to clean up zombie connections that weren't properly disconnected. Args: timeout_seconds: Override default timeout (uses config if None) Returns: List of expired session IDs """ if timeout_seconds is None: settings = get_settings() # Use connection timeout as inactivity threshold (default 60s) # This is separate from Socket.io's ping_timeout which handles transport-level issues # This handles application-level inactivity (no events for extended period) timeout_seconds = settings.ws_connection_timeout * 5 # 5 min default expired = [] for sid, session in list(self._sessions.items()): inactive_secs = session.inactive_seconds() if inactive_secs > timeout_seconds: expired.append(sid) logger.warning( f"Expiring inactive session {sid} (user={session.user_id}, " f"inactive {inactive_secs:.0f}s)" ) for sid in expired: await self.disconnect(sid) # Force Socket.io to close the connection try: await self.sio.disconnect(sid) except Exception as e: logger.debug(f"Error disconnecting expired session {sid}: {e}") if expired: logger.info(f"Expired {len(expired)} inactive sessions") return expired def get_stats(self) -> dict: """ Return connection statistics for health monitoring. Includes: - Total active sessions - Unique connected users - Active game rooms - Per-game participant counts - Session age statistics """ now = pendulum.now("UTC") # Calculate session age stats session_ages = [] inactive_counts = {"<1m": 0, "1-5m": 0, "5-15m": 0, ">15m": 0} for session in self._sessions.values(): age = (now - session.connected_at).total_seconds() session_ages.append(age) inactive = session.inactive_seconds() if inactive < 60: inactive_counts["<1m"] += 1 elif inactive < 300: inactive_counts["1-5m"] += 1 elif inactive < 900: inactive_counts["5-15m"] += 1 else: inactive_counts[">15m"] += 1 return { "total_sessions": len(self._sessions), "unique_users": len(self._user_sessions), "active_game_rooms": len([r for r in self.game_rooms.values() if r]), "sessions_per_game": { gid: len(sids) for gid, sids in self.game_rooms.items() if sids }, "oldest_session_seconds": max(session_ages) if session_ages else 0, "avg_session_seconds": sum(session_ages) / len(session_ages) if session_ages else 0, "inactivity_distribution": inactive_counts, }