mantimon-tcg/backend/tests/socketio/test_auth.py
Cal Corum f512c7b2b3 Refactor to dependency injection pattern - no monkey patching
- ConnectionManager: Add redis_factory constructor parameter
- GameService: Add engine_factory constructor parameter
- AuthHandler: New class replacing standalone functions with
  token_verifier and conn_manager injection
- Update all tests to use constructor DI instead of patch()
- Update CLAUDE.md with factory injection patterns
- Update services README with new patterns
- Add socketio README documenting AuthHandler and events

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 22:54:57 -06:00

436 lines
14 KiB
Python

"""Tests for Socket.IO authentication middleware.
This module tests JWT-based authentication for WebSocket connections,
including token extraction, validation, and session management.
"""
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.socketio.auth import (
AuthHandler,
AuthResult,
extract_token,
get_session_user_id,
require_auth,
)
class TestExtractToken:
    """Checks for pulling a JWT out of the Socket.IO auth payload."""

    def test_extract_token_from_token_field(self) -> None:
        """The primary 'token' field is returned unchanged.

        Clients normally pass the JWT as auth.token; this is the
        recommended format for Socket.IO clients.
        """
        payload = {"token": "my-jwt-token"}

        assert extract_token(payload) == "my-jwt-token"

    def test_extract_token_from_authorization_bearer(self) -> None:
        """A 'Bearer <token>' authorization value yields the bare token.

        Some clients mirror the HTTP API convention and send the token
        in Bearer form.
        """
        payload = {"authorization": "Bearer my-jwt-token"}

        assert extract_token(payload) == "my-jwt-token"

    def test_extract_token_from_authorization_without_bearer(self) -> None:
        """An authorization value without the Bearer prefix is accepted.

        When the client puts only the raw token in the authorization
        field, it is returned unchanged.
        """
        payload = {"authorization": "my-jwt-token"}

        assert extract_token(payload) == "my-jwt-token"

    def test_extract_token_from_access_token_field(self) -> None:
        """The 'access_token' field is honoured as an alternative.

        This is the field name used by OAuth-style clients.
        """
        payload = {"access_token": "my-jwt-token"}

        assert extract_token(payload) == "my-jwt-token"

    def test_extract_token_returns_none_for_none_auth(self) -> None:
        """Passing None as the auth data produces None.

        A client that sends no auth at all gets no token, which leads
        into the authentication failure path.
        """
        assert extract_token(None) is None

    def test_extract_token_returns_none_for_empty_auth(self) -> None:
        """An empty auth dict produces None.

        An auth object with no fields counts as unauthenticated.
        """
        assert extract_token({}) is None

    def test_extract_token_returns_none_for_non_string_token(self) -> None:
        """Token values that are not strings are rejected.

        Only strings are valid tokens; an integer and a nested dict
        must both produce None.
        """
        for bad_value in (12345, {"nested": "value"}):
            assert extract_token({"token": bad_value}) is None

    def test_extract_token_prefers_token_field(self) -> None:
        """The 'token' field wins when several token fields are present.

        With 'token', 'authorization' and 'access_token' all supplied,
        the value of the primary 'token' field is the one returned.
        """
        payload = {
            "token": "primary-token",
            "authorization": "Bearer secondary-token",
            "access_token": "tertiary-token",
        }

        assert extract_token(payload) == "primary-token"
class TestAuthenticateConnection:
    """Checks for AuthHandler.authenticate_connection with injected mocks."""

    @pytest.fixture
    def mock_token_verifier(self) -> MagicMock:
        """Provide a stand-in for the token verifier callable."""
        verifier = MagicMock()
        return verifier

    @pytest.fixture
    def mock_connection_manager(self) -> AsyncMock:
        """Provide a stand-in ConnectionManager with async register/unregister."""
        manager = AsyncMock()
        manager.unregister_connection = AsyncMock()
        manager.register_connection = AsyncMock()
        return manager

    @pytest.fixture
    def auth_handler(
        self,
        mock_token_verifier: MagicMock,
        mock_connection_manager: AsyncMock,
    ) -> AuthHandler:
        """Build the AuthHandler under test from the two mock dependencies."""
        handler = AuthHandler(
            token_verifier=mock_token_verifier,
            conn_manager=mock_connection_manager,
        )
        return handler

    async def test_authenticate_success_with_valid_token(
        self,
        auth_handler: AuthHandler,
        mock_token_verifier: MagicMock,
    ) -> None:
        """A valid JWT authenticates successfully.

        When the verifier returns a user id, the AuthResult must report
        success, carry that user id, have no error code, and the
        verifier must have been called exactly once with the token.
        """
        expected_user = uuid4()
        mock_token_verifier.return_value = expected_user
        auth_payload = {"token": "valid-token"}

        result = await auth_handler.authenticate_connection("test-sid", auth_payload)

        assert result.success is True
        assert result.user_id == expected_user
        assert result.error_code is None
        mock_token_verifier.assert_called_once_with("valid-token")

    async def test_authenticate_fails_with_missing_token(
        self,
        auth_handler: AuthHandler,
    ) -> None:
        """Missing auth data fails with the 'missing_token' code.

        A connection that supplies no auth at all must be rejected,
        and the error message must mention that a token is required.
        """
        result = await auth_handler.authenticate_connection("test-sid", None)

        assert result.success is False
        assert result.user_id is None
        assert result.error_code == "missing_token"
        assert "required" in result.error_message.lower()

    async def test_authenticate_fails_with_empty_auth(
        self,
        auth_handler: AuthHandler,
    ) -> None:
        """An empty auth object fails with the 'missing_token' code.

        An auth dict that contains none of the token fields is treated
        the same as no token.
        """
        result = await auth_handler.authenticate_connection("test-sid", {})

        assert result.success is False
        assert result.error_code == "missing_token"

    async def test_authenticate_fails_with_invalid_token(
        self,
        auth_handler: AuthHandler,
        mock_token_verifier: MagicMock,
    ) -> None:
        """An invalid or expired JWT fails with the 'invalid_token' code.

        The verifier signals a bad token by returning None; the result
        must then be a failure whose message mentions "invalid" or
        "expired".
        """
        mock_token_verifier.return_value = None
        auth_payload = {"token": "invalid-token"}

        result = await auth_handler.authenticate_connection("test-sid", auth_payload)

        assert result.success is False
        assert result.user_id is None
        assert result.error_code == "invalid_token"
        message = result.error_message.lower()
        assert "invalid" in message or "expired" in message

    async def test_authenticate_extracts_token_from_bearer(
        self,
        auth_handler: AuthHandler,
        mock_token_verifier: MagicMock,
    ) -> None:
        """Bearer-formatted authorization values authenticate correctly.

        The verifier must receive the bare token with the "Bearer "
        prefix already removed.
        """
        mock_token_verifier.return_value = uuid4()
        auth_payload = {"authorization": "Bearer my-token"}

        result = await auth_handler.authenticate_connection("test-sid", auth_payload)

        assert result.success is True
        mock_token_verifier.assert_called_once_with("my-token")
class TestSetupAuthenticatedSession:
    """Tests for session setup after authentication.

    Both collaborators (the Socket.IO server and the ConnectionManager)
    are AsyncMocks. The assertions use the ``assert_awaited*`` family
    rather than ``assert_called*``: a call assertion on an AsyncMock
    passes as soon as the coroutine is created, even if it is never
    awaited, so it would not catch a missing ``await`` in the handler.
    """

    @pytest.fixture
    def mock_connection_manager(self) -> AsyncMock:
        """Create a mock ConnectionManager."""
        cm = AsyncMock()
        cm.register_connection = AsyncMock()
        return cm

    @pytest.fixture
    def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler:
        """Create AuthHandler with injected mock."""
        return AuthHandler(conn_manager=mock_connection_manager)

    async def test_setup_saves_session_data(
        self,
        auth_handler: AuthHandler,
    ) -> None:
        """Test that session setup saves user_id and timestamp.

        After authentication, the socket session should contain
        the user_id (as a string) and an authentication timestamp,
        stored under the connection's sid.
        """
        user_id = uuid4()
        mock_sio = AsyncMock()

        await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id)

        # Require an actual await, not merely a call that built a coroutine.
        mock_sio.save_session.assert_awaited_once()
        await_args = mock_sio.save_session.await_args
        assert await_args.args[0] == "test-sid"
        session_data = await_args.args[1]
        assert session_data["user_id"] == str(user_id)
        assert "authenticated_at" in session_data

    async def test_setup_registers_with_connection_manager(
        self,
        auth_handler: AuthHandler,
        mock_connection_manager: AsyncMock,
    ) -> None:
        """Test that session setup registers with ConnectionManager.

        The connection should be tracked in ConnectionManager for
        presence detection and game association.
        """
        user_id = uuid4()
        mock_sio = AsyncMock()

        await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id)

        # Registration is async; verify it was awaited with the sid and user id.
        mock_connection_manager.register_connection.assert_awaited_once_with(
            "test-sid", user_id
        )
class TestCleanupAuthenticatedSession:
    """Checks for AuthHandler.cleanup_authenticated_session on disconnect."""

    @pytest.fixture
    def mock_connection_manager(self) -> AsyncMock:
        """Provide a stand-in ConnectionManager."""
        manager = AsyncMock()
        return manager

    @pytest.fixture
    def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler:
        """Build the AuthHandler under test around the mock manager."""
        handler = AuthHandler(conn_manager=mock_connection_manager)
        return handler

    async def test_cleanup_unregisters_connection(
        self,
        auth_handler: AuthHandler,
        mock_connection_manager: AsyncMock,
    ) -> None:
        """Cleanup unregisters the sid and returns the stored user id.

        The ConnectionManager must be asked to unregister "test-sid",
        and the user_id on the connection info it returns is what
        cleanup hands back.
        """
        conn_info = MagicMock()
        conn_info.user_id = "user-123"
        unregister = AsyncMock(return_value=conn_info)
        mock_connection_manager.unregister_connection = unregister

        returned_user = await auth_handler.cleanup_authenticated_session("test-sid")

        assert returned_user == "user-123"
        mock_connection_manager.unregister_connection.assert_called_once_with("test-sid")

    async def test_cleanup_returns_none_for_unknown_session(
        self,
        auth_handler: AuthHandler,
        mock_connection_manager: AsyncMock,
    ) -> None:
        """Cleanup of an unregistered sid returns None.

        When the ConnectionManager has nothing to unregister (for
        example because authentication failed), cleanup must return
        None instead of raising.
        """
        unregister = AsyncMock(return_value=None)
        mock_connection_manager.unregister_connection = unregister

        returned_user = await auth_handler.cleanup_authenticated_session("unknown-sid")

        assert returned_user is None
class TestGetSessionUserId:
    """Checks for looking up the user_id stored in a socket session."""

    async def test_get_session_user_id_returns_id(self) -> None:
        """An authenticated session yields its stored user_id string.

        The session must be fetched for the given sid in the "/game"
        namespace.
        """
        stored_session = {"user_id": "user-123", "authenticated_at": "2024-01-01"}
        sio = AsyncMock()
        sio.get_session = AsyncMock(return_value=stored_session)

        found = await get_session_user_id(sio, "test-sid")

        assert found == "user-123"
        sio.get_session.assert_called_once_with("test-sid", namespace="/game")

    async def test_get_session_user_id_returns_none_for_missing(self) -> None:
        """A sid with no session yields None.

        When get_session returns None, the lookup must return None
        instead of raising.
        """
        sio = AsyncMock()
        sio.get_session = AsyncMock(return_value=None)

        found = await get_session_user_id(sio, "test-sid")

        assert found is None

    async def test_get_session_user_id_handles_exception(self) -> None:
        """An exception raised by get_session is swallowed and yields None.

        This keeps a failing session lookup from breaking the calling
        event handler.
        """
        sio = AsyncMock()
        sio.get_session = AsyncMock(side_effect=Exception("Session error"))

        found = await get_session_user_id(sio, "test-sid")

        assert found is None
class TestRequireAuth:
    """Checks for the require_auth helper used by event handlers."""

    async def test_require_auth_returns_user_id_for_authenticated(self) -> None:
        """A session holding a user_id makes require_auth return it.

        Event handlers use the returned user_id for the authenticated
        connection.
        """
        sio = AsyncMock()
        sio.get_session = AsyncMock(return_value={"user_id": "user-123"})

        user = await require_auth(sio, "test-sid")

        assert user == "user-123"

    async def test_require_auth_returns_none_for_unauthenticated(self) -> None:
        """With no session available, require_auth returns None.

        The None result lets a handler respond with an error.
        """
        sio = AsyncMock()
        sio.get_session = AsyncMock(return_value=None)

        user = await require_auth(sio, "test-sid")

        assert user is None
class TestAuthResultDataclass:
    """Checks for constructing AuthResult values."""

    def test_auth_result_success(self) -> None:
        """A success result carries the user_id and no error fields.

        error_code and error_message must be None when they are not
        supplied.
        """
        expected_user = uuid4()

        outcome = AuthResult(success=True, user_id=expected_user)

        assert outcome.success is True
        assert outcome.user_id == expected_user
        assert outcome.error_code is None
        assert outcome.error_message is None

    def test_auth_result_failure(self) -> None:
        """A failure result carries the error code and message, and no user_id.

        user_id must be None when it is not supplied.
        """
        failure_fields = {
            "success": False,
            "error_code": "invalid_token",
            "error_message": "Token expired",
        }

        outcome = AuthResult(**failure_fields)

        assert outcome.success is False
        assert outcome.user_id is None
        assert outcome.error_code == "invalid_token"
        assert outcome.error_message == "Token expired"