"""Tests for Socket.IO authentication middleware. This module tests JWT-based authentication for WebSocket connections, including token extraction, validation, and session management. """ from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 import pytest from app.socketio.auth import ( AuthHandler, AuthResult, extract_token, get_session_user_id, require_auth, ) class TestExtractToken: """Tests for token extraction from Socket.IO auth data.""" def test_extract_token_from_token_field(self) -> None: """Test extracting token from the primary 'token' field. The standard way for clients to pass the JWT is via auth.token. This is the recommended format for Socket.IO clients. """ auth = {"token": "my-jwt-token"} result = extract_token(auth) assert result == "my-jwt-token" def test_extract_token_from_authorization_bearer(self) -> None: """Test extracting token from Bearer authorization header format. Some clients may pass the token as a Bearer token for consistency with HTTP API authentication patterns. """ auth = {"authorization": "Bearer my-jwt-token"} result = extract_token(auth) assert result == "my-jwt-token" def test_extract_token_from_authorization_without_bearer(self) -> None: """Test extracting token from authorization without Bearer prefix. If the client provides just the token in authorization field, we should still accept it. """ auth = {"authorization": "my-jwt-token"} result = extract_token(auth) assert result == "my-jwt-token" def test_extract_token_from_access_token_field(self) -> None: """Test extracting token from access_token field. Alternative field name for OAuth-style clients. """ auth = {"access_token": "my-jwt-token"} result = extract_token(auth) assert result == "my-jwt-token" def test_extract_token_returns_none_for_none_auth(self) -> None: """Test that None auth data returns None token. Clients that don't provide any auth should get None, triggering the authentication failure path. """ result = extract_token(None) assert result is None def test_extract_token_returns_none_for_empty_auth(self) -> None: """Test that empty auth dict returns None token. An empty auth object should be treated as unauthenticated. """ result = extract_token({}) assert result is None def test_extract_token_returns_none_for_non_string_token(self) -> None: """Test that non-string token values are rejected. Only string tokens are valid - reject numbers, objects, etc. """ result = extract_token({"token": 12345}) assert result is None result = extract_token({"token": {"nested": "value"}}) assert result is None def test_extract_token_prefers_token_field(self) -> None: """Test that 'token' field takes precedence over alternatives. If multiple token fields are present, we should use the primary 'token' field. """ auth = { "token": "primary-token", "authorization": "Bearer secondary-token", "access_token": "tertiary-token", } result = extract_token(auth) assert result == "primary-token" class TestAuthenticateConnection: """Tests for connection authentication via AuthHandler.""" @pytest.fixture def mock_token_verifier(self) -> MagicMock: """Create a mock token verifier function.""" return MagicMock() @pytest.fixture def mock_connection_manager(self) -> AsyncMock: """Create a mock ConnectionManager.""" cm = AsyncMock() cm.register_connection = AsyncMock() cm.unregister_connection = AsyncMock() return cm @pytest.fixture def auth_handler( self, mock_token_verifier: MagicMock, mock_connection_manager: AsyncMock, ) -> AuthHandler: """Create AuthHandler with injected mocks.""" return AuthHandler( token_verifier=mock_token_verifier, conn_manager=mock_connection_manager, ) async def test_authenticate_success_with_valid_token( self, auth_handler: AuthHandler, mock_token_verifier: MagicMock, ) -> None: """Test successful authentication with a valid JWT. A valid access token should result in AuthResult with success=True and the user_id from the token. """ user_id = uuid4() mock_token_verifier.return_value = user_id result = await auth_handler.authenticate_connection("test-sid", {"token": "valid-token"}) assert result.success is True assert result.user_id == user_id assert result.error_code is None mock_token_verifier.assert_called_once_with("valid-token") async def test_authenticate_fails_with_missing_token( self, auth_handler: AuthHandler, ) -> None: """Test authentication failure when no token is provided. Connections without any auth data should fail with a 'missing_token' error code. """ result = await auth_handler.authenticate_connection("test-sid", None) assert result.success is False assert result.user_id is None assert result.error_code == "missing_token" assert "required" in result.error_message.lower() async def test_authenticate_fails_with_empty_auth( self, auth_handler: AuthHandler, ) -> None: """Test authentication failure with empty auth object. An auth object without any token fields should fail. """ result = await auth_handler.authenticate_connection("test-sid", {}) assert result.success is False assert result.error_code == "missing_token" async def test_authenticate_fails_with_invalid_token( self, auth_handler: AuthHandler, mock_token_verifier: MagicMock, ) -> None: """Test authentication failure with invalid/expired JWT. When verify_access_token returns None (invalid token), we should fail with 'invalid_token' error. """ mock_token_verifier.return_value = None result = await auth_handler.authenticate_connection("test-sid", {"token": "invalid-token"}) assert result.success is False assert result.user_id is None assert result.error_code == "invalid_token" assert ( "invalid" in result.error_message.lower() or "expired" in result.error_message.lower() ) async def test_authenticate_extracts_token_from_bearer( self, auth_handler: AuthHandler, mock_token_verifier: MagicMock, ) -> None: """Test that authentication works with Bearer format. The authenticate function should handle Bearer token format through the extract_token function. """ user_id = uuid4() mock_token_verifier.return_value = user_id result = await auth_handler.authenticate_connection( "test-sid", {"authorization": "Bearer my-token"} ) assert result.success is True mock_token_verifier.assert_called_once_with("my-token") class TestSetupAuthenticatedSession: """Tests for session setup after authentication.""" @pytest.fixture def mock_connection_manager(self) -> AsyncMock: """Create a mock ConnectionManager.""" cm = AsyncMock() cm.register_connection = AsyncMock() return cm @pytest.fixture def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler: """Create AuthHandler with injected mock.""" return AuthHandler(conn_manager=mock_connection_manager) async def test_setup_saves_session_data( self, auth_handler: AuthHandler, ) -> None: """Test that session setup saves user_id and timestamp. After authentication, the socket session should contain the user_id and authentication timestamp. """ user_id = uuid4() mock_sio = AsyncMock() await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id) mock_sio.save_session.assert_called_once() call_args = mock_sio.save_session.call_args assert call_args.args[0] == "test-sid" session_data = call_args.args[1] assert session_data["user_id"] == str(user_id) assert "authenticated_at" in session_data async def test_setup_registers_with_connection_manager( self, auth_handler: AuthHandler, mock_connection_manager: AsyncMock, ) -> None: """Test that session setup registers with ConnectionManager. The connection should be tracked in ConnectionManager for presence detection and game association. """ user_id = uuid4() mock_sio = AsyncMock() await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id) mock_connection_manager.register_connection.assert_called_once_with("test-sid", user_id) class TestCleanupAuthenticatedSession: """Tests for session cleanup on disconnect.""" @pytest.fixture def mock_connection_manager(self) -> AsyncMock: """Create a mock ConnectionManager.""" return AsyncMock() @pytest.fixture def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler: """Create AuthHandler with injected mock.""" return AuthHandler(conn_manager=mock_connection_manager) async def test_cleanup_unregisters_connection( self, auth_handler: AuthHandler, mock_connection_manager: AsyncMock, ) -> None: """Test that cleanup unregisters from ConnectionManager. On disconnect, the connection should be removed from ConnectionManager to update presence tracking. """ mock_conn_info = MagicMock() mock_conn_info.user_id = "user-123" mock_connection_manager.unregister_connection = AsyncMock(return_value=mock_conn_info) result = await auth_handler.cleanup_authenticated_session("test-sid") assert result == "user-123" mock_connection_manager.unregister_connection.assert_called_once_with("test-sid") async def test_cleanup_returns_none_for_unknown_session( self, auth_handler: AuthHandler, mock_connection_manager: AsyncMock, ) -> None: """Test cleanup returns None for non-existent sessions. If the connection wasn't registered (e.g., auth failed), cleanup should return None gracefully. """ mock_connection_manager.unregister_connection = AsyncMock(return_value=None) result = await auth_handler.cleanup_authenticated_session("unknown-sid") assert result is None class TestGetSessionUserId: """Tests for session user_id retrieval.""" async def test_get_session_user_id_returns_id(self) -> None: """Test retrieving user_id from authenticated session. For authenticated sessions, get_session_user_id should return the stored user_id string. """ mock_sio = AsyncMock() mock_sio.get_session = AsyncMock( return_value={"user_id": "user-123", "authenticated_at": "2024-01-01"} ) result = await get_session_user_id(mock_sio, "test-sid") assert result == "user-123" mock_sio.get_session.assert_called_once_with("test-sid", namespace="/game") async def test_get_session_user_id_returns_none_for_missing(self) -> None: """Test that missing session returns None. If no session exists for the sid, we should return None rather than raising an error. """ mock_sio = AsyncMock() mock_sio.get_session = AsyncMock(return_value=None) result = await get_session_user_id(mock_sio, "test-sid") assert result is None async def test_get_session_user_id_handles_exception(self) -> None: """Test that exceptions are caught and return None. If get_session raises an exception, we should catch it and return None to avoid breaking the event handler. """ mock_sio = AsyncMock() mock_sio.get_session = AsyncMock(side_effect=Exception("Session error")) result = await get_session_user_id(mock_sio, "test-sid") assert result is None class TestRequireAuth: """Tests for the require_auth helper.""" async def test_require_auth_returns_user_id_for_authenticated(self) -> None: """Test that require_auth returns user_id for valid sessions. Authenticated sessions should return the user_id for use in event handlers. """ mock_sio = AsyncMock() mock_sio.get_session = AsyncMock(return_value={"user_id": "user-123"}) result = await require_auth(mock_sio, "test-sid") assert result == "user-123" async def test_require_auth_returns_none_for_unauthenticated(self) -> None: """Test that require_auth returns None for unauthenticated sessions. Unauthenticated events should get None, allowing handlers to return an error response. """ mock_sio = AsyncMock() mock_sio.get_session = AsyncMock(return_value=None) result = await require_auth(mock_sio, "test-sid") assert result is None class TestAuthResultDataclass: """Tests for the AuthResult dataclass.""" def test_auth_result_success(self) -> None: """Test creating a successful AuthResult. Success results should have user_id and no error fields. """ user_id = uuid4() result = AuthResult(success=True, user_id=user_id) assert result.success is True assert result.user_id == user_id assert result.error_code is None assert result.error_message is None def test_auth_result_failure(self) -> None: """Test creating a failed AuthResult. Failure results should have error code/message and no user_id. """ result = AuthResult( success=False, error_code="invalid_token", error_message="Token expired", ) assert result.success is False assert result.user_id is None assert result.error_code == "invalid_token" assert result.error_message == "Token expired"