"""Socket.IO authentication middleware for WebSocket connections.

This module provides JWT-based authentication for Socket.IO connections.
It validates access tokens and attaches user information to the socket
session.

Authentication Flow:
    1. Client connects with `auth: { token: "JWT_ACCESS_TOKEN" }`
    2. Server extracts and validates the JWT
    3. If valid, user_id is stored in socket session
    4. If invalid, connection is rejected with appropriate error

Session Data:
    After successful authentication, the socket session contains:
    - user_id: str (UUID as string)
    - authenticated_at: str (ISO timestamp)

Example:
    # In connect handler:
    auth_result = await auth_handler.authenticate_connection(sid, auth)
    if not auth_result.success:
        return False  # Reject connection

    # Later, get user_id:
    session = await sio.get_session(sid, namespace="/game")
    user_id = session.get("user_id")
"""

import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID

from app.services.connection_manager import ConnectionManager, connection_manager
from app.services.jwt_service import verify_access_token

logger = logging.getLogger(__name__)

# Type alias for token verifier - takes token string, returns user_id or None
TokenVerifier = Callable[[str], UUID | None]

# Scheme prefix accepted (case-insensitively) in ``auth["authorization"]``.
_BEARER_PREFIX = "bearer "


@dataclass
class AuthResult:
    """Result of authentication attempt.

    Attributes:
        success: Whether authentication succeeded.
        user_id: User's UUID if successful, None otherwise.
        error_code: Error code for client if failed.
        error_message: Human-readable error message if failed.
    """

    success: bool
    user_id: UUID | None = None
    error_code: str | None = None
    error_message: str | None = None


class AuthHandler:
    """Handler for Socket.IO authentication with injectable dependencies.

    This class encapsulates authentication logic with dependencies injected
    via constructor, making it testable without monkey patching.

    Attributes:
        _token_verifier: Function to verify JWT tokens.
        _connection_manager: ConnectionManager for tracking connections.
    """

    def __init__(
        self,
        token_verifier: TokenVerifier | None = None,
        conn_manager: ConnectionManager | None = None,
    ) -> None:
        """Initialize AuthHandler with dependencies.

        Args:
            token_verifier: Function to verify tokens. Uses jwt_service's
                ``verify_access_token`` if not provided.
            conn_manager: ConnectionManager instance. Uses the module-level
                ``connection_manager`` if not provided.
        """
        self._token_verifier = token_verifier or verify_access_token
        self._connection_manager = conn_manager or connection_manager

    async def authenticate_connection(
        self,
        sid: str,
        auth: dict[str, object] | None,
    ) -> AuthResult:
        """Authenticate a Socket.IO connection using JWT.

        Extracts the JWT from auth data, validates it, and returns the result.
        Does NOT modify socket session - caller should handle that.

        Args:
            sid: Socket session ID (for logging).
            auth: Authentication data from connect event.

        Returns:
            AuthResult with success status and user_id, or with
            ``error_code`` "missing_token" / "invalid_token" on failure.
        """
        token = extract_token(auth)
        if token is None:
            logger.debug("Connection %s: No token provided", sid)
            return AuthResult(
                success=False,
                error_code="missing_token",
                error_message="Authentication token required",
            )

        user_id = self._token_verifier(token)
        if user_id is None:
            logger.debug("Connection %s: Invalid or expired token", sid)
            return AuthResult(
                success=False,
                error_code="invalid_token",
                error_message="Invalid or expired token",
            )

        logger.debug("Connection %s: Authenticated as user %s", sid, user_id)
        return AuthResult(
            success=True,
            user_id=user_id,
        )

    async def setup_authenticated_session(
        self,
        sio: object,
        sid: str,
        user_id: UUID,
        namespace: str = "/game",
    ) -> None:
        """Set up socket session with authenticated user data.

        Saves user_id and authentication timestamp to the socket session,
        and registers the connection with ConnectionManager.

        Args:
            sio: Socket.IO AsyncServer instance (required).
            sid: Socket session ID.
            user_id: Authenticated user's UUID.
            namespace: Socket.IO namespace.
        """
        session_data = {
            # Stored as a string so the session holds only plain JSON-like values.
            "user_id": str(user_id),
            "authenticated_at": datetime.now(UTC).isoformat(),
        }
        await sio.save_session(sid, session_data, namespace=namespace)
        await self._connection_manager.register_connection(sid, user_id)
        logger.info("Session established: sid=%s, user_id=%s", sid, user_id)

    async def cleanup_authenticated_session(
        self,
        sid: str,
        namespace: str = "/game",
    ) -> str | None:
        """Clean up session data on disconnect.

        Unregisters the connection from ConnectionManager and returns the
        user_id for any additional cleanup needed.

        Args:
            sid: Socket session ID.
            namespace: Socket.IO namespace. Currently unused; accepted for
                symmetry with ``setup_authenticated_session``.

        Returns:
            user_id if session was authenticated, None otherwise.
        """
        conn_info = await self._connection_manager.unregister_connection(sid)
        if conn_info:
            logger.info(
                "Session cleaned up: sid=%s, user_id=%s", sid, conn_info.user_id
            )
            return conn_info.user_id
        logger.debug("No session to clean up for %s", sid)
        return None


# Global singleton instance
auth_handler = AuthHandler()


def extract_token(auth: dict[str, object] | None) -> str | None:
    """Extract JWT token from Socket.IO auth data.

    Clients should send the token in the auth dict:
        socket.connect({ auth: { token: "JWT_TOKEN" } })

    Also supports, checked in this order after ``auth.token``:
        - auth.authorization: "Bearer TOKEN" (scheme is case-insensitive;
          any other non-empty value is returned as-is)
        - auth.access_token: "TOKEN"

    Args:
        auth: Authentication data from Socket.IO connect. This comes straight
            from the client, so non-dict values are tolerated and yield None.

    Returns:
        JWT token string if found, None otherwise. Never returns "".

    Example:
        token = extract_token({"token": "eyJ..."})
        token = extract_token({"authorization": "Bearer eyJ..."})
    """
    # Clients can send any JSON value (or nothing) as auth data; only a dict
    # can carry a token. Without this guard, `.get` raises AttributeError.
    if not isinstance(auth, dict):
        return None

    # Primary: auth.token
    token = auth.get("token")
    if token and isinstance(token, str):
        return token

    # Alternative: auth.authorization (Bearer token, or a bare token)
    authorization = auth.get("authorization")
    if authorization and isinstance(authorization, str):
        if authorization.lower().startswith(_BEARER_PREFIX):
            bearer_token = authorization[len(_BEARER_PREFIX):].strip()
            # A bare "Bearer " carries no credentials: fall through to the
            # remaining sources instead of handing "" to the verifier.
            if bearer_token:
                return bearer_token
        else:
            return authorization

    # Alternative: auth.access_token
    access_token = auth.get("access_token")
    if access_token and isinstance(access_token, str):
        return access_token

    return None


async def get_session_user_id(
    sio: object,
    sid: str,
    namespace: str = "/game",
) -> str | None:
    """Get the authenticated user_id from a socket session.

    Convenience function to extract user_id from session data.

    Args:
        sio: Socket.IO AsyncServer instance.
        sid: Socket session ID.
        namespace: Socket.IO namespace.

    Returns:
        user_id string if authenticated, None otherwise (including when the
        session cannot be read at all).

    Example:
        user_id = await get_session_user_id(sio, sid)
        if not user_id:
            await sio.emit("error", {"message": "Not authenticated"}, to=sid)
            return
    """
    try:
        session = await sio.get_session(sid, namespace=namespace)
        return session.get("user_id") if session else None
    except Exception:
        # Best effort: any failure reading the session is treated as
        # "unauthenticated", but keep a trace so real errors stay visible.
        logger.debug("Could not read session for %s", sid, exc_info=True)
        return None


async def require_auth(
    sio: object,
    sid: str,
    namespace: str = "/game",
) -> str | None:
    """Require authentication for an event handler.

    Returns the user_id if authenticated, None if not.
    Logs a warning if authentication is missing.

    Args:
        sio: Socket.IO AsyncServer instance.
        sid: Socket session ID.
        namespace: Socket.IO namespace.

    Returns:
        user_id string if authenticated, None otherwise.

    Example:
        @sio.on("game:action", namespace="/game")
        async def on_action(sid, data):
            user_id = await require_auth(sio, sid)
            if not user_id:
                return {"error": "Not authenticated"}
            # ... handle action
    """
    user_id = await get_session_user_id(sio, sid, namespace)
    if user_id is None:
        logger.warning("Unauthenticated event from %s", sid)
    return user_id