"""Tests for ConnectionManager service.

This module tests WebSocket connection tracking with Redis. Since these are
unit tests, we inject a mock Redis factory to test the ConnectionManager logic
without requiring a real Redis instance.

Uses dependency injection pattern - no monkey patching required.
"""

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock
from uuid import uuid4

import pytest

from app.services.connection_manager import (
    CONN_PREFIX,
    GAME_CONNS_PREFIX,
    HEARTBEAT_INTERVAL_SECONDS,
    USER_CONN_PREFIX,
    ConnectionInfo,
    ConnectionManager,
)


def _conn_record(
    user_id: str,
    game_id: str = "",
    at: datetime | None = None,
) -> dict[str, str]:
    """Build a connection hash in the shape the tests feed to mocked HGETALL.

    Args:
        user_id: Value for the ``user_id`` field.
        game_id: Value for the ``game_id`` field. An empty string means
            "not in a game"; ``spectating:{game_id}`` marks a spectator.
        at: Timestamp used for both ``connected_at`` and ``last_seen``.
            Defaults to the current UTC time.

    Returns:
        A fresh dict with ``user_id``, ``game_id``, ``connected_at`` and
        ``last_seen`` keys, timestamps serialized with ``isoformat()``.
    """
    timestamp = (at if at is not None else datetime.now(UTC)).isoformat()
    return {
        "user_id": user_id,
        "game_id": game_id,
        "connected_at": timestamp,
        "last_seen": timestamp,
    }


@pytest.fixture
def mock_redis() -> AsyncMock:
    """Create a mock Redis client.

    This mock provides all the Redis operations used by ConnectionManager
    with sensible defaults that can be overridden per-test.

    Every operation is configured explicitly: an unconfigured AsyncMock
    method awaits to a truthy mock object, which would make "missing" reads
    (e.g. SISMEMBER) look like hits and hide bugs.
    """
    redis = AsyncMock()
    redis.hset = AsyncMock()
    redis.hget = AsyncMock()
    redis.hgetall = AsyncMock(return_value={})
    redis.set = AsyncMock()
    redis.get = AsyncMock(return_value=None)
    redis.delete = AsyncMock()
    redis.exists = AsyncMock(return_value=False)
    redis.expire = AsyncMock()
    redis.sadd = AsyncMock()
    redis.srem = AsyncMock()
    redis.smembers = AsyncMock(return_value=set())
    redis.scard = AsyncMock(return_value=0)
    # Used by is_spectating; default to "not a member" like a real empty set.
    redis.sismember = AsyncMock(return_value=False)
    return redis


@pytest.fixture
def manager(mock_redis: AsyncMock) -> ConnectionManager:
    """Create a ConnectionManager with injected mock Redis.

    The mock Redis is injected via the redis_factory parameter, eliminating
    the need for monkey patching in tests.
    """

    @asynccontextmanager
    async def mock_redis_factory() -> AsyncIterator[AsyncMock]:
        yield mock_redis

    return ConnectionManager(conn_ttl_seconds=3600, redis_factory=mock_redis_factory)


class TestConnectionInfoDataclass:
    """Tests for the ConnectionInfo dataclass."""

    def test_is_stale_returns_false_for_recent_connection(self) -> None:
        """Test that recent connections are not marked as stale.

        Connections that have been seen within the threshold should be
        considered active, not stale.
        """
        now = datetime.now(UTC)
        info = ConnectionInfo(
            sid="test-sid",
            user_id="user-123",
            game_id=None,
            connected_at=now,
            last_seen=now,
        )

        assert info.is_stale() is False

    def test_is_stale_returns_true_for_old_connection(self) -> None:
        """Test that old connections are marked as stale.

        Connections that haven't been seen for longer than the threshold
        should be considered stale and eligible for cleanup.
        """
        now = datetime.now(UTC)
        old_time = now - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4)
        info = ConnectionInfo(
            sid="test-sid",
            user_id="user-123",
            game_id=None,
            connected_at=old_time,
            last_seen=old_time,
        )

        assert info.is_stale() is True

    def test_is_stale_with_custom_threshold(self) -> None:
        """Test is_stale with a custom threshold.

        The threshold parameter allows callers to customize what counts as
        stale for different use cases.
        """
        now = datetime.now(UTC)
        last_seen = now - timedelta(seconds=60)
        info = ConnectionInfo(
            sid="test-sid",
            user_id="user-123",
            game_id=None,
            connected_at=last_seen,
            last_seen=last_seen,
        )

        # 60 seconds old, with 30 second threshold = stale
        assert info.is_stale(threshold_seconds=30) is True
        # 60 seconds old, with 120 second threshold = not stale
        assert info.is_stale(threshold_seconds=120) is False


class TestRegisterConnection:
    """Tests for connection registration."""

    @pytest.mark.asyncio
    async def test_register_connection_creates_records(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that registering a connection creates all necessary Redis records.

        Registration should create:
        1. Connection hash with user_id, game_id, timestamps
        2. User-to-connection mapping
        """
        sid = "test-sid-123"
        user_id = str(uuid4())

        await manager.register_connection(sid, user_id)

        # Verify connection hash was created
        conn_key = f"{CONN_PREFIX}{sid}"
        mock_redis.hset.assert_called()
        hset_call = mock_redis.hset.call_args
        assert hset_call.args[0] == conn_key
        mapping = hset_call.kwargs["mapping"]
        assert mapping["user_id"] == user_id
        assert mapping["game_id"] == ""

        # Verify user-to-connection mapping was created
        user_conn_key = f"{USER_CONN_PREFIX}{user_id}"
        mock_redis.set.assert_called_with(user_conn_key, sid)

    @pytest.mark.asyncio
    async def test_register_connection_replaces_old_connection(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that registering a new connection replaces the old one.

        When a user reconnects, their old connection should be cleaned up to
        prevent stale connection data from lingering.
        """
        old_sid = "old-sid"
        new_sid = "new-sid"
        user_id = str(uuid4())

        # Mock: user has existing connection
        mock_redis.get.return_value = old_sid
        mock_redis.hgetall.return_value = _conn_record(user_id)

        await manager.register_connection(new_sid, user_id)

        # Verify old connection was cleaned up (delete called for old conn key)
        old_conn_key = f"{CONN_PREFIX}{old_sid}"
        delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list]
        assert old_conn_key in delete_calls

    @pytest.mark.asyncio
    async def test_register_connection_accepts_uuid(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that register_connection accepts UUID objects.

        The user_id can be passed as either a string or UUID object for
        convenience.
        """
        sid = "test-sid"
        user_uuid = uuid4()

        await manager.register_connection(sid, user_uuid)

        # Verify user_id was converted to string
        hset_call = mock_redis.hset.call_args
        mapping = hset_call.kwargs["mapping"]
        assert mapping["user_id"] == str(user_uuid)


class TestUnregisterConnection:
    """Tests for connection unregistration."""

    @pytest.mark.asyncio
    async def test_unregister_returns_none_for_unknown_sid(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that unregistering unknown connection returns None.

        Unregistering a non-existent connection should be a no-op and return
        None to indicate nothing was cleaned up.
        """
        mock_redis.hgetall.return_value = {}

        result = await manager.unregister_connection("unknown-sid")

        assert result is None

    @pytest.mark.asyncio
    async def test_unregister_cleans_up_all_data(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that unregistering cleans up all Redis data.

        Unregistration should remove:
        1. Connection hash
        2. User-to-connection mapping
        3. Remove from game connection set (if in a game)
        """
        sid = "test-sid"
        user_id = "user-123"
        game_id = "game-456"

        mock_redis.hgetall.return_value = _conn_record(user_id, game_id)
        mock_redis.get.return_value = sid  # User's current connection

        result = await manager.unregister_connection(sid)

        assert result is not None
        assert result.user_id == user_id
        assert result.game_id == game_id

        # Collect keys across all DELETE calls so the assertion holds whether
        # keys are deleted one per call or several in a single call.
        deleted_keys = {
            key for call in mock_redis.delete.call_args_list for key in call.args
        }
        assert f"{CONN_PREFIX}{sid}" in deleted_keys
        assert f"{USER_CONN_PREFIX}{user_id}" in deleted_keys
        mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)


class TestGameAssociation:
    """Tests for game association methods."""

    @pytest.mark.asyncio
    async def test_join_game_adds_to_game_set(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that joining a game adds connection to game's set.

        The connection should be tracked in the game's connection set for
        broadcasting updates to all participants.
        """
        sid = "test-sid"
        game_id = "game-123"
        mock_redis.exists.return_value = True
        mock_redis.hget.return_value = ""  # Not in a game yet

        result = await manager.join_game(sid, game_id)

        assert result is True
        mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
        mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id)

    @pytest.mark.asyncio
    async def test_join_game_returns_false_for_unknown_connection(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that joining a game fails for unknown connections.

        Non-existent connections should not be able to join games.
        """
        mock_redis.exists.return_value = False

        result = await manager.join_game("unknown-sid", "game-123")

        assert result is False

    @pytest.mark.asyncio
    async def test_join_game_leaves_previous_game(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that joining a new game leaves the previous game.

        A connection can only be in one game at a time, so joining a new game
        should automatically leave any previous game.
        """
        sid = "test-sid"
        old_game = "old-game"
        new_game = "new-game"
        mock_redis.exists.return_value = True
        mock_redis.hget.return_value = old_game

        await manager.join_game(sid, new_game)

        # Should remove from old game
        mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid)
        # Should add to new game
        mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid)

    @pytest.mark.asyncio
    async def test_leave_game_removes_from_set(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that leaving a game removes connection from set.

        The connection should be removed from the game's set and its game_id
        should be cleared.
        """
        sid = "test-sid"
        game_id = "game-123"
        mock_redis.hget.return_value = game_id

        result = await manager.leave_game(sid)

        assert result == game_id
        mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
        mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")

    @pytest.mark.asyncio
    async def test_leave_game_returns_none_when_not_in_game(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that leaving when not in a game returns None.

        If the connection isn't in a game, leave_game should return None to
        indicate no game was left.
        """
        mock_redis.hget.return_value = ""

        result = await manager.leave_game("test-sid")

        assert result is None


class TestHeartbeat:
    """Tests for heartbeat/activity tracking."""

    @pytest.mark.asyncio
    async def test_update_heartbeat_refreshes_last_seen(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that heartbeat updates the last_seen timestamp.

        Heartbeats keep connections alive by updating timestamps and
        refreshing TTLs on Redis keys.
        """
        sid = "test-sid"
        mock_redis.exists.return_value = True
        mock_redis.hget.return_value = "user-123"

        result = await manager.update_heartbeat(sid)

        assert result is True
        mock_redis.hset.assert_called()
        mock_redis.expire.assert_called()

    @pytest.mark.asyncio
    async def test_update_heartbeat_returns_false_for_unknown(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that heartbeat fails for unknown connections.

        Unknown connections should not be able to send heartbeats, returning
        False to indicate failure.
        """
        mock_redis.exists.return_value = False

        result = await manager.update_heartbeat("unknown-sid")

        assert result is False


class TestQueryMethods:
    """Tests for connection query methods."""

    @pytest.mark.asyncio
    async def test_get_connection_returns_info(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_connection returns ConnectionInfo.

        The returned info should contain all the stored connection data.
        """
        sid = "test-sid"
        mock_redis.hgetall.return_value = _conn_record("user-123", "game-456")

        result = await manager.get_connection(sid)

        assert result is not None
        assert result.sid == sid
        assert result.user_id == "user-123"
        assert result.game_id == "game-456"

    @pytest.mark.asyncio
    async def test_get_connection_returns_none_for_unknown(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_connection returns None for unknown sid.

        Non-existent connections should return None rather than raising.
        """
        mock_redis.hgetall.return_value = {}

        result = await manager.get_connection("unknown-sid")

        assert result is None

    @pytest.mark.asyncio
    async def test_get_connection_returns_none_for_corrupted_data(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_connection returns None for corrupted records.

        If a connection record is missing required fields (user_id,
        connected_at, last_seen), it indicates data corruption. The method
        should return None and log a warning rather than returning a
        ConnectionInfo with invalid/empty data.
        """
        # Record missing user_id (built by hand: the helper always sets it)
        mock_redis.hgetall.return_value = {
            "game_id": "game-123",
            "connected_at": "2024-01-01T00:00:00+00:00",
            "last_seen": "2024-01-01T00:00:00+00:00",
        }

        result = await manager.get_connection("corrupted-sid")

        assert result is None

    @pytest.mark.asyncio
    async def test_is_user_online_returns_true_for_connected_user(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that is_user_online returns True for connected users.

        Users with active, non-stale connections should be considered online.
        """
        user_id = "user-123"
        mock_redis.get.return_value = "test-sid"
        mock_redis.hgetall.return_value = _conn_record(user_id)

        result = await manager.is_user_online(user_id)

        assert result is True

    @pytest.mark.asyncio
    async def test_is_user_online_returns_false_for_stale_connection(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that is_user_online returns False for stale connections.

        Users with stale connections (no recent heartbeat) should be
        considered offline even if their connection record exists.
        """
        user_id = "user-123"
        old_time = datetime.now(UTC) - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4)
        mock_redis.get.return_value = "test-sid"
        mock_redis.hgetall.return_value = _conn_record(user_id, at=old_time)

        result = await manager.is_user_online(user_id)

        assert result is False

    @pytest.mark.asyncio
    async def test_get_game_connections_returns_all_participants(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_game_connections returns all game participants.

        Should return ConnectionInfo for each sid in the game's connection set.
        """
        game_id = "game-123"
        mock_redis.smembers.return_value = {"sid-1", "sid-2"}

        # Return the right record based on the key queried; unknown keys get
        # an empty hash, matching HGETALL on a missing key.
        records = {
            f"{CONN_PREFIX}sid-1": _conn_record("user-1", game_id),
            f"{CONN_PREFIX}sid-2": _conn_record("user-2", game_id),
        }
        mock_redis.hgetall.side_effect = lambda key: dict(records.get(key, {}))

        result = await manager.get_game_connections(game_id)

        assert len(result) == 2
        user_ids = {conn.user_id for conn in result}
        assert user_ids == {"user-1", "user-2"}

    @pytest.mark.asyncio
    async def test_get_opponent_sid_returns_other_player(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_opponent_sid returns the other player's sid.

        In a 2-player game, this should return the sid of the player who is
        not the current user.
        """
        game_id = "game-123"
        current_user = "user-1"
        opponent_user = "user-2"
        opponent_sid = "sid-2"
        mock_redis.smembers.return_value = {"sid-1", "sid-2"}

        # Return the right record based on the key queried.
        # This avoids dependency on set iteration order.
        records = {
            f"{CONN_PREFIX}sid-1": _conn_record(current_user, game_id),
            f"{CONN_PREFIX}sid-2": _conn_record(opponent_user, game_id),
        }
        mock_redis.hgetall.side_effect = lambda key: dict(records.get(key, {}))

        result = await manager.get_opponent_sid(game_id, current_user)

        assert result == opponent_sid

    @pytest.mark.asyncio
    async def test_get_user_active_game_returns_game_id(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_user_active_game returns the game_id from connection.

        This method looks up the user's connection and returns their current
        game_id if they are in a game.
        """
        user_id = "user-123"
        game_id = "game-456"

        # Mock: user has a connection with a game
        mock_redis.get.return_value = "test-sid"
        mock_redis.hgetall.return_value = _conn_record(user_id, game_id)

        result = await manager.get_user_active_game(user_id)

        assert result == game_id

    @pytest.mark.asyncio
    async def test_get_user_active_game_returns_none_when_not_in_game(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_user_active_game returns None when user has no game.

        If the user is connected but not in a game, their game_id will be
        None (or empty string in Redis), so we return None.
        """
        user_id = "user-123"
        mock_redis.get.return_value = "test-sid"
        # Empty string means not in a game
        mock_redis.hgetall.return_value = _conn_record(user_id, game_id="")

        result = await manager.get_user_active_game(user_id)

        # Empty string game_id becomes None in ConnectionInfo
        assert result is None

    @pytest.mark.asyncio
    async def test_get_user_active_game_returns_none_when_not_connected(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_user_active_game returns None when user is offline.

        If the user has no active connection, there's no game to return.
        """
        mock_redis.get.return_value = None  # No connection for user

        result = await manager.get_user_active_game("user-123")

        assert result is None

    @pytest.mark.asyncio
    async def test_get_user_active_game_accepts_uuid(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_user_active_game accepts UUID objects.

        The method should work with both string and UUID user IDs for
        convenience.
        """
        user_uuid = uuid4()
        game_id = "game-789"
        mock_redis.get.return_value = "test-sid"
        mock_redis.hgetall.return_value = _conn_record(str(user_uuid), game_id)

        result = await manager.get_user_active_game(user_uuid)

        assert result == game_id


class TestKeyGeneration:
    """Tests for Redis key generation methods."""

    def test_conn_key_format(self) -> None:
        """Test that connection keys have correct format.

        Keys should follow the pattern conn:{sid} for easy identification and
        pattern matching.
        """
        manager = ConnectionManager()
        key = manager._conn_key("test-sid-123")
        assert key == "conn:test-sid-123"

    def test_user_conn_key_format(self) -> None:
        """Test that user connection keys have correct format.

        Keys should follow the pattern user_conn:{user_id}.
        """
        manager = ConnectionManager()
        key = manager._user_conn_key("user-456")
        assert key == "user_conn:user-456"

    def test_game_conns_key_format(self) -> None:
        """Test that game connections keys have correct format.

        Keys should follow the pattern game_conns:{game_id}.
        """
        manager = ConnectionManager()
        key = manager._game_conns_key("game-789")
        assert key == "game_conns:game-789"

    def test_spectators_key_format(self) -> None:
        """Test that spectators keys have correct format.

        Keys should follow the pattern spectators:{game_id}.
        """
        manager = ConnectionManager()
        key = manager._spectators_key("game-789")
        assert key == "spectators:game-789"


class TestSpectatorManagement:
    """Tests for spectator-related methods.

    Spectators are tracked separately from game participants using a
    dedicated Redis set per game.
    """

    @pytest.mark.asyncio
    async def test_register_spectator_success(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test registering a spectator adds them to the spectators set.

        When a user starts spectating a game, they should be added to the
        spectators:{game_id} set and their connection's game_id should be
        updated to indicate spectating.
        """
        sid = "spectator-sid"
        user_id = "user-123"
        game_id = "game-456"
        mock_redis.exists.return_value = True

        result = await manager.register_spectator(sid, user_id, game_id)

        assert result is True
        # Should update connection's game_id to indicate spectating
        mock_redis.hset.assert_called_with(
            f"{CONN_PREFIX}{sid}",
            "game_id",
            f"spectating:{game_id}",
        )
        # Should add to spectators set
        mock_redis.sadd.assert_called_with(f"spectators:{game_id}", sid)
        # Should set TTL on spectators set
        mock_redis.expire.assert_called()

    @pytest.mark.asyncio
    async def test_register_spectator_connection_not_found(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test registering spectator fails if connection doesn't exist.

        A user must have an active connection before they can spectate.
        """
        mock_redis.exists.return_value = False

        result = await manager.register_spectator("unknown-sid", "user-123", "game-456")

        assert result is False
        mock_redis.sadd.assert_not_called()

    @pytest.mark.asyncio
    async def test_unregister_spectator_success(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test unregistering a spectator removes them from the set.

        When a spectator leaves, they should be removed from the spectators
        set and their connection's game_id should be cleared.
        """
        sid = "spectator-sid"
        game_id = "game-456"
        mock_redis.srem.return_value = 1  # 1 member removed
        mock_redis.hget.return_value = f"spectating:{game_id}"

        result = await manager.unregister_spectator(sid, game_id)

        assert result is True
        mock_redis.srem.assert_called_with(f"spectators:{game_id}", sid)
        mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")

    @pytest.mark.asyncio
    async def test_unregister_spectator_not_in_set(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test unregistering non-spectator returns False.

        If the sid is not in the spectators set, unregister should return
        False to indicate nothing was removed.
        """
        mock_redis.srem.return_value = 0  # No members removed
        mock_redis.hget.return_value = ""

        result = await manager.unregister_spectator("unknown-sid", "game-456")

        assert result is False

    @pytest.mark.asyncio
    async def test_get_spectator_count(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test get_spectator_count returns the set cardinality.

        Should use Redis SCARD to get the number of spectators efficiently.
        """
        mock_redis.scard.return_value = 5

        result = await manager.get_spectator_count("game-456")

        assert result == 5
        mock_redis.scard.assert_called_with("spectators:game-456")

    @pytest.mark.asyncio
    async def test_get_game_spectators(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test get_game_spectators returns all spectator sids.

        Should return all sids from the spectators set for the game.
        """
        expected_sids = {"sid-1", "sid-2", "sid-3"}
        mock_redis.smembers.return_value = expected_sids

        result = await manager.get_game_spectators("game-456")

        assert set(result) == expected_sids
        mock_redis.smembers.assert_called_with("spectators:game-456")

    @pytest.mark.asyncio
    async def test_is_spectating_returns_true(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test is_spectating returns True when sid is spectating.

        Should use Redis SISMEMBER for efficient membership check.
        """
        mock_redis.sismember.return_value = True

        result = await manager.is_spectating("spectator-sid", "game-456")

        assert result is True
        mock_redis.sismember.assert_called_with("spectators:game-456", "spectator-sid")

    @pytest.mark.asyncio
    async def test_is_spectating_returns_false(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test is_spectating returns False when not spectating.

        Should return False for sids not in the spectators set.
        """
        mock_redis.sismember.return_value = False

        result = await manager.is_spectating("player-sid", "game-456")

        assert result is False


class TestCleanupWithSpectators:
    """Tests for connection cleanup including spectator state."""

    @pytest.mark.asyncio
    async def test_cleanup_removes_spectator_from_set(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that cleanup removes spectator from spectators set.

        When a spectating connection is cleaned up (disconnect), the sid
        should be removed from the spectators set.
        """
        sid = "spectator-sid"
        game_id = "game-456"
        mock_redis.hgetall.return_value = _conn_record(
            "user-123",
            f"spectating:{game_id}",
            at=datetime(2024, 1, 1, tzinfo=UTC),
        )
        mock_redis.get.return_value = sid

        await manager._cleanup_connection(sid)

        # Should remove from spectators set
        mock_redis.srem.assert_called_with(f"spectators:{game_id}", sid)

    @pytest.mark.asyncio
    async def test_cleanup_player_removes_from_game_conns(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that cleanup removes player from game_conns set.

        When a playing (not spectating) connection is cleaned up, the sid
        should be removed from game_conns, not spectators.
        """
        sid = "player-sid"
        game_id = "game-456"
        mock_redis.hgetall.return_value = _conn_record(
            "user-123",
            game_id,  # Regular game_id, not spectating:
            at=datetime(2024, 1, 1, tzinfo=UTC),
        )
        mock_redis.get.return_value = sid

        await manager._cleanup_connection(sid)

        # Should remove from game_conns, not spectators
        mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)