- ConnectionManager: Add redis_factory constructor parameter
- GameService: Add engine_factory constructor parameter
- AuthHandler: New class replacing standalone functions, with token_verifier and conn_manager injection
- Update all tests to use constructor DI instead of patch()
- Update CLAUDE.md with factory injection patterns
- Update services README with new patterns
- Add socketio README documenting AuthHandler and events

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
436 lines
14 KiB
Python
436 lines
14 KiB
Python
"""Tests for Socket.IO authentication middleware.
|
|
|
|
This module tests JWT-based authentication for WebSocket connections,
|
|
including token extraction, validation, and session management.
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from app.socketio.auth import (
|
|
AuthHandler,
|
|
AuthResult,
|
|
extract_token,
|
|
get_session_user_id,
|
|
require_auth,
|
|
)
|
|
|
|
|
|
class TestExtractToken:
    """Tests for token extraction from Socket.IO auth data.

    ``extract_token`` receives the raw ``auth`` payload sent by the client on
    connect and returns the JWT string, or ``None`` when no usable string
    token is present.
    """

    def test_extract_token_from_token_field(self) -> None:
        """Test extracting token from the primary 'token' field.

        The standard way for clients to pass the JWT is via auth.token.
        This is the recommended format for Socket.IO clients.
        """
        auth = {"token": "my-jwt-token"}
        result = extract_token(auth)
        assert result == "my-jwt-token"

    def test_extract_token_from_authorization_bearer(self) -> None:
        """Test extracting token from Bearer authorization header format.

        Some clients may pass the token as a Bearer token for consistency
        with HTTP API authentication patterns.
        """
        auth = {"authorization": "Bearer my-jwt-token"}
        result = extract_token(auth)
        assert result == "my-jwt-token"

    def test_extract_token_from_authorization_without_bearer(self) -> None:
        """Test extracting token from authorization without Bearer prefix.

        If the client provides just the token in the authorization field,
        it should still be accepted.
        """
        auth = {"authorization": "my-jwt-token"}
        result = extract_token(auth)
        assert result == "my-jwt-token"

    def test_extract_token_from_access_token_field(self) -> None:
        """Test extracting token from access_token field.

        Alternative field name for OAuth-style clients.
        """
        auth = {"access_token": "my-jwt-token"}
        result = extract_token(auth)
        assert result == "my-jwt-token"

    def test_extract_token_returns_none_for_none_auth(self) -> None:
        """Test that None auth data returns None token.

        Clients that don't provide any auth should get None,
        triggering the authentication failure path.
        """
        result = extract_token(None)
        assert result is None

    def test_extract_token_returns_none_for_empty_auth(self) -> None:
        """Test that empty auth dict returns None token.

        An empty auth object should be treated as unauthenticated.
        """
        result = extract_token({})
        assert result is None

    def test_extract_token_returns_none_for_non_string_token(self) -> None:
        """Test that non-string token values are rejected.

        Only string tokens are valid - reject numbers, objects, etc.
        """
        result = extract_token({"token": 12345})
        assert result is None

        result = extract_token({"token": {"nested": "value"}})
        assert result is None

    def test_extract_token_prefers_token_field(self) -> None:
        """Test that 'token' field takes precedence over alternatives.

        If multiple token fields are present, the primary 'token' field
        should be used.
        """
        auth = {
            "token": "primary-token",
            "authorization": "Bearer secondary-token",
            "access_token": "tertiary-token",
        }
        result = extract_token(auth)
        assert result == "primary-token"
|
|
|
|
|
|
class TestAuthenticateConnection:
    """Tests for connection authentication via AuthHandler.

    The handler is constructed with an injected token verifier and a mock
    ConnectionManager, so these tests never perform real token verification
    or connection tracking.
    """

    @pytest.fixture
    def mock_token_verifier(self) -> MagicMock:
        """Create a mock token verifier function."""
        return MagicMock()

    @pytest.fixture
    def mock_connection_manager(self) -> AsyncMock:
        """Create a mock ConnectionManager."""
        cm = AsyncMock()
        cm.register_connection = AsyncMock()
        cm.unregister_connection = AsyncMock()
        return cm

    @pytest.fixture
    def auth_handler(
        self,
        mock_token_verifier: MagicMock,
        mock_connection_manager: AsyncMock,
    ) -> AuthHandler:
        """Create AuthHandler with injected mocks."""
        return AuthHandler(
            token_verifier=mock_token_verifier,
            conn_manager=mock_connection_manager,
        )

    async def test_authenticate_success_with_valid_token(
        self,
        auth_handler: AuthHandler,
        mock_token_verifier: MagicMock,
    ) -> None:
        """Test successful authentication with a valid JWT.

        A valid access token should result in AuthResult with
        success=True and the user_id from the token.
        """
        user_id = uuid4()
        mock_token_verifier.return_value = user_id

        result = await auth_handler.authenticate_connection("test-sid", {"token": "valid-token"})

        assert result.success is True
        assert result.user_id == user_id
        assert result.error_code is None
        mock_token_verifier.assert_called_once_with("valid-token")

    async def test_authenticate_fails_with_missing_token(
        self,
        auth_handler: AuthHandler,
        mock_token_verifier: MagicMock,
    ) -> None:
        """Test authentication failure when no token is provided.

        Connections without any auth data should fail with
        a 'missing_token' error code.
        """
        result = await auth_handler.authenticate_connection("test-sid", None)

        assert result.success is False
        assert result.user_id is None
        assert result.error_code == "missing_token"
        assert "required" in result.error_message.lower()
        # With no token extracted there is nothing to verify; the verifier
        # must not be reached at all.
        mock_token_verifier.assert_not_called()

    async def test_authenticate_fails_with_empty_auth(
        self,
        auth_handler: AuthHandler,
        mock_token_verifier: MagicMock,
    ) -> None:
        """Test authentication failure with empty auth object.

        An auth object without any token fields should fail.
        """
        result = await auth_handler.authenticate_connection("test-sid", {})

        assert result.success is False
        assert result.error_code == "missing_token"
        mock_token_verifier.assert_not_called()

    async def test_authenticate_fails_with_invalid_token(
        self,
        auth_handler: AuthHandler,
        mock_token_verifier: MagicMock,
    ) -> None:
        """Test authentication failure with invalid/expired JWT.

        When the injected token verifier returns None (invalid or expired
        token), authentication should fail with an 'invalid_token' error.
        """
        mock_token_verifier.return_value = None

        result = await auth_handler.authenticate_connection("test-sid", {"token": "invalid-token"})

        assert result.success is False
        assert result.user_id is None
        assert result.error_code == "invalid_token"
        assert (
            "invalid" in result.error_message.lower() or "expired" in result.error_message.lower()
        )
        # The extracted token must actually have been handed to the verifier.
        mock_token_verifier.assert_called_once_with("invalid-token")

    async def test_authenticate_extracts_token_from_bearer(
        self,
        auth_handler: AuthHandler,
        mock_token_verifier: MagicMock,
    ) -> None:
        """Test that authentication works with Bearer format.

        The authenticate method should handle the Bearer token format
        through the extract_token function.
        """
        user_id = uuid4()
        mock_token_verifier.return_value = user_id

        result = await auth_handler.authenticate_connection(
            "test-sid", {"authorization": "Bearer my-token"}
        )

        assert result.success is True
        mock_token_verifier.assert_called_once_with("my-token")
|
|
|
|
|
|
class TestSetupAuthenticatedSession:
    """Tests for session setup after authentication.

    Uses ``assert_awaited*`` on async mocks: ``assert_called*`` would also
    pass if the coroutine were created but never awaited.
    """

    @pytest.fixture
    def mock_connection_manager(self) -> AsyncMock:
        """Create a mock ConnectionManager."""
        cm = AsyncMock()
        cm.register_connection = AsyncMock()
        return cm

    @pytest.fixture
    def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler:
        """Create AuthHandler with injected mock."""
        return AuthHandler(conn_manager=mock_connection_manager)

    async def test_setup_saves_session_data(
        self,
        auth_handler: AuthHandler,
    ) -> None:
        """Test that session setup saves user_id and timestamp.

        After authentication, the socket session should contain
        the user_id and authentication timestamp.
        """
        user_id = uuid4()
        mock_sio = AsyncMock()

        await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id)

        mock_sio.save_session.assert_awaited_once()
        call_args = mock_sio.save_session.call_args
        assert call_args.args[0] == "test-sid"

        session_data = call_args.args[1]
        assert session_data["user_id"] == str(user_id)
        assert "authenticated_at" in session_data

    async def test_setup_registers_with_connection_manager(
        self,
        auth_handler: AuthHandler,
        mock_connection_manager: AsyncMock,
    ) -> None:
        """Test that session setup registers with ConnectionManager.

        The connection should be tracked in ConnectionManager for
        presence detection and game association.
        """
        user_id = uuid4()
        mock_sio = AsyncMock()

        await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id)

        mock_connection_manager.register_connection.assert_awaited_once_with("test-sid", user_id)
|
|
|
|
|
|
class TestCleanupAuthenticatedSession:
    """Tests for session cleanup on disconnect.

    ``cleanup_authenticated_session`` returns the ``user_id`` of the
    connection info that ConnectionManager.unregister_connection returns,
    or None when that call returns None.
    """

    @pytest.fixture
    def mock_connection_manager(self) -> AsyncMock:
        """Create a mock ConnectionManager."""
        return AsyncMock()

    @pytest.fixture
    def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler:
        """Create AuthHandler with injected mock."""
        return AuthHandler(conn_manager=mock_connection_manager)

    async def test_cleanup_unregisters_connection(
        self,
        auth_handler: AuthHandler,
        mock_connection_manager: AsyncMock,
    ) -> None:
        """Test that cleanup unregisters from ConnectionManager.

        On disconnect, the connection should be removed from
        ConnectionManager to update presence tracking.
        """
        mock_conn_info = MagicMock()
        mock_conn_info.user_id = "user-123"
        mock_connection_manager.unregister_connection = AsyncMock(return_value=mock_conn_info)

        result = await auth_handler.cleanup_authenticated_session("test-sid")

        assert result == "user-123"
        mock_connection_manager.unregister_connection.assert_awaited_once_with("test-sid")

    async def test_cleanup_returns_none_for_unknown_session(
        self,
        auth_handler: AuthHandler,
        mock_connection_manager: AsyncMock,
    ) -> None:
        """Test cleanup returns None for non-existent sessions.

        If the connection wasn't registered (e.g., auth failed),
        cleanup should return None gracefully.
        """
        mock_connection_manager.unregister_connection = AsyncMock(return_value=None)

        result = await auth_handler.cleanup_authenticated_session("unknown-sid")

        assert result is None
        # Unregistration is still attempted for the given sid; only the
        # result is empty.
        mock_connection_manager.unregister_connection.assert_awaited_once_with("unknown-sid")
|
|
|
|
|
|
class TestGetSessionUserId:
    """Tests for session user_id retrieval via ``get_session_user_id``."""

    async def test_get_session_user_id_returns_id(self) -> None:
        """Test retrieving user_id from authenticated session.

        For authenticated sessions, get_session_user_id should
        return the stored user_id string.
        """
        mock_sio = AsyncMock()
        mock_sio.get_session = AsyncMock(
            return_value={"user_id": "user-123", "authenticated_at": "2024-01-01"}
        )

        result = await get_session_user_id(mock_sio, "test-sid")

        assert result == "user-123"
        # The session is looked up in the "/game" namespace.
        mock_sio.get_session.assert_awaited_once_with("test-sid", namespace="/game")

    async def test_get_session_user_id_returns_none_for_missing(self) -> None:
        """Test that missing session returns None.

        If no session exists for the sid, None should be returned
        rather than an error being raised.
        """
        mock_sio = AsyncMock()
        mock_sio.get_session = AsyncMock(return_value=None)

        result = await get_session_user_id(mock_sio, "test-sid")

        assert result is None

    async def test_get_session_user_id_handles_exception(self) -> None:
        """Test that exceptions are caught and return None.

        If get_session raises an exception, it should be caught and None
        returned so the event handler is not broken.
        """
        mock_sio = AsyncMock()
        mock_sio.get_session = AsyncMock(side_effect=Exception("Session error"))

        result = await get_session_user_id(mock_sio, "test-sid")

        assert result is None
|
|
|
|
|
|
class TestRequireAuth:
    """Tests for the require_auth helper.

    ``require_auth`` returns the session's user_id, or None when the
    session is missing, so handlers can reply with an error.
    """

    async def test_require_auth_returns_user_id_for_authenticated(self) -> None:
        """Test that require_auth returns user_id for valid sessions.

        Authenticated sessions should return the user_id for use
        in event handlers.
        """
        mock_sio = AsyncMock()
        mock_sio.get_session = AsyncMock(return_value={"user_id": "user-123"})

        result = await require_auth(mock_sio, "test-sid")

        assert result == "user-123"

    async def test_require_auth_returns_none_for_unauthenticated(self) -> None:
        """Test that require_auth returns None for unauthenticated sessions.

        Unauthenticated events should get None, allowing handlers
        to return an error response.
        """
        mock_sio = AsyncMock()
        mock_sio.get_session = AsyncMock(return_value=None)

        result = await require_auth(mock_sio, "test-sid")

        assert result is None
|
|
|
|
|
|
class TestAuthResultDataclass:
    """Tests for the AuthResult dataclass.

    Covers which fields default to None when they are omitted at
    construction.
    """

    def test_auth_result_success(self) -> None:
        """Test creating a successful AuthResult.

        Success results should have user_id and no error fields.
        """
        user_id = uuid4()
        result = AuthResult(success=True, user_id=user_id)

        assert result.success is True
        assert result.user_id == user_id
        assert result.error_code is None
        assert result.error_message is None

    def test_auth_result_failure(self) -> None:
        """Test creating a failed AuthResult.

        Failure results should have error code/message and no user_id.
        """
        result = AuthResult(
            success=False,
            error_code="invalid_token",
            error_message="Token expired",
        )

        assert result.success is False
        assert result.user_id is None
        assert result.error_code == "invalid_token"
        assert result.error_message == "Token expired"
|