"""Tests for ConnectionManager service. This module tests WebSocket connection tracking with Redis. Since these are unit tests, we inject a mock Redis factory to test the ConnectionManager logic without requiring a real Redis instance. Uses dependency injection pattern - no monkey patching required. """ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from datetime import UTC, datetime, timedelta from unittest.mock import AsyncMock from uuid import uuid4 import pytest from app.services.connection_manager import ( CONN_PREFIX, GAME_CONNS_PREFIX, HEARTBEAT_INTERVAL_SECONDS, USER_CONN_PREFIX, ConnectionInfo, ConnectionManager, ) @pytest.fixture def mock_redis() -> AsyncMock: """Create a mock Redis client. This mock provides all the Redis operations used by ConnectionManager with sensible defaults that can be overridden per-test. """ redis = AsyncMock() redis.hset = AsyncMock() redis.hget = AsyncMock() redis.hgetall = AsyncMock(return_value={}) redis.set = AsyncMock() redis.get = AsyncMock(return_value=None) redis.delete = AsyncMock() redis.exists = AsyncMock(return_value=False) redis.expire = AsyncMock() redis.sadd = AsyncMock() redis.srem = AsyncMock() redis.smembers = AsyncMock(return_value=set()) redis.scard = AsyncMock(return_value=0) return redis @pytest.fixture def manager(mock_redis: AsyncMock) -> ConnectionManager: """Create a ConnectionManager with injected mock Redis. The mock Redis is injected via the redis_factory parameter, eliminating the need for monkey patching in tests. """ @asynccontextmanager async def mock_redis_factory() -> AsyncIterator[AsyncMock]: yield mock_redis return ConnectionManager(conn_ttl_seconds=3600, redis_factory=mock_redis_factory) class TestConnectionInfoDataclass: """Tests for the ConnectionInfo dataclass.""" def test_is_stale_returns_false_for_recent_connection(self) -> None: """Test that recent connections are not marked as stale. Connections that have been seen within the threshold should be considered active, not stale. """ now = datetime.now(UTC) info = ConnectionInfo( sid="test-sid", user_id="user-123", game_id=None, connected_at=now, last_seen=now, ) assert info.is_stale() is False def test_is_stale_returns_true_for_old_connection(self) -> None: """Test that old connections are marked as stale. Connections that haven't been seen for longer than the threshold should be considered stale and eligible for cleanup. """ now = datetime.now(UTC) old_time = now - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4) info = ConnectionInfo( sid="test-sid", user_id="user-123", game_id=None, connected_at=old_time, last_seen=old_time, ) assert info.is_stale() is True def test_is_stale_with_custom_threshold(self) -> None: """Test is_stale with a custom threshold. The threshold parameter allows callers to customize what counts as stale for different use cases. """ now = datetime.now(UTC) last_seen = now - timedelta(seconds=60) info = ConnectionInfo( sid="test-sid", user_id="user-123", game_id=None, connected_at=last_seen, last_seen=last_seen, ) # 60 seconds old, with 30 second threshold = stale assert info.is_stale(threshold_seconds=30) is True # 60 seconds old, with 120 second threshold = not stale assert info.is_stale(threshold_seconds=120) is False class TestRegisterConnection: """Tests for connection registration.""" @pytest.mark.asyncio async def test_register_connection_creates_records( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that registering a connection creates all necessary Redis records. Registration should create: 1. Connection hash with user_id, game_id, timestamps 2. User-to-connection mapping """ sid = "test-sid-123" user_id = str(uuid4()) await manager.register_connection(sid, user_id) # Verify connection hash was created conn_key = f"{CONN_PREFIX}{sid}" mock_redis.hset.assert_called() hset_call = mock_redis.hset.call_args assert hset_call.args[0] == conn_key mapping = hset_call.kwargs["mapping"] assert mapping["user_id"] == user_id assert mapping["game_id"] == "" # Verify user-to-connection mapping was created user_conn_key = f"{USER_CONN_PREFIX}{user_id}" mock_redis.set.assert_called_with(user_conn_key, sid) @pytest.mark.asyncio async def test_register_connection_replaces_old_connection( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that registering a new connection replaces the old one. When a user reconnects, their old connection should be cleaned up to prevent stale connection data from lingering. """ old_sid = "old-sid" new_sid = "new-sid" user_id = str(uuid4()) # Mock: user has existing connection mock_redis.get.return_value = old_sid mock_redis.hgetall.return_value = { "user_id": user_id, "game_id": "", "connected_at": datetime.now(UTC).isoformat(), "last_seen": datetime.now(UTC).isoformat(), } await manager.register_connection(new_sid, user_id) # Verify old connection was cleaned up (delete called for old conn key) old_conn_key = f"{CONN_PREFIX}{old_sid}" delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list] assert old_conn_key in delete_calls @pytest.mark.asyncio async def test_register_connection_accepts_uuid( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that register_connection accepts UUID objects. The user_id can be passed as either a string or UUID object for convenience. """ sid = "test-sid" user_uuid = uuid4() await manager.register_connection(sid, user_uuid) # Verify user_id was converted to string hset_call = mock_redis.hset.call_args mapping = hset_call.kwargs["mapping"] assert mapping["user_id"] == str(user_uuid) class TestUnregisterConnection: """Tests for connection unregistration.""" @pytest.mark.asyncio async def test_unregister_returns_none_for_unknown_sid( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that unregistering unknown connection returns None. Unregistering a non-existent connection should be a no-op and return None to indicate nothing was cleaned up. """ mock_redis.hgetall.return_value = {} result = await manager.unregister_connection("unknown-sid") assert result is None @pytest.mark.asyncio async def test_unregister_cleans_up_all_data( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that unregistering cleans up all Redis data. Unregistration should remove: 1. Connection hash 2. User-to-connection mapping 3. Remove from game connection set (if in a game) """ sid = "test-sid" user_id = "user-123" game_id = "game-456" now = datetime.now(UTC) mock_redis.hgetall.return_value = { "user_id": user_id, "game_id": game_id, "connected_at": now.isoformat(), "last_seen": now.isoformat(), } mock_redis.get.return_value = sid # User's current connection result = await manager.unregister_connection(sid) assert result is not None assert result.user_id == user_id assert result.game_id == game_id # Verify cleanup calls mock_redis.delete.assert_called() mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) class TestGameAssociation: """Tests for game association methods.""" @pytest.mark.asyncio async def test_join_game_adds_to_game_set( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that joining a game adds connection to game's set. The connection should be tracked in the game's connection set for broadcasting updates to all participants. """ sid = "test-sid" game_id = "game-123" mock_redis.exists.return_value = True mock_redis.hget.return_value = "" # Not in a game yet result = await manager.join_game(sid, game_id) assert result is True mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id) @pytest.mark.asyncio async def test_join_game_returns_false_for_unknown_connection( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that joining a game fails for unknown connections. Non-existent connections should not be able to join games. """ mock_redis.exists.return_value = False result = await manager.join_game("unknown-sid", "game-123") assert result is False @pytest.mark.asyncio async def test_join_game_leaves_previous_game( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that joining a new game leaves the previous game. A connection can only be in one game at a time, so joining a new game should automatically leave any previous game. """ sid = "test-sid" old_game = "old-game" new_game = "new-game" mock_redis.exists.return_value = True mock_redis.hget.return_value = old_game await manager.join_game(sid, new_game) # Should remove from old game mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid) # Should add to new game mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid) @pytest.mark.asyncio async def test_leave_game_removes_from_set( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that leaving a game removes connection from set. The connection should be removed from the game's set and its game_id should be cleared. """ sid = "test-sid" game_id = "game-123" mock_redis.hget.return_value = game_id result = await manager.leave_game(sid) assert result == game_id mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "") @pytest.mark.asyncio async def test_leave_game_returns_none_when_not_in_game( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that leaving when not in a game returns None. If the connection isn't in a game, leave_game should return None to indicate no game was left. """ mock_redis.hget.return_value = "" result = await manager.leave_game("test-sid") assert result is None class TestHeartbeat: """Tests for heartbeat/activity tracking.""" @pytest.mark.asyncio async def test_update_heartbeat_refreshes_last_seen( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that heartbeat updates the last_seen timestamp. Heartbeats keep connections alive by updating timestamps and refreshing TTLs on Redis keys. """ sid = "test-sid" mock_redis.exists.return_value = True mock_redis.hget.return_value = "user-123" result = await manager.update_heartbeat(sid) assert result is True mock_redis.hset.assert_called() mock_redis.expire.assert_called() @pytest.mark.asyncio async def test_update_heartbeat_returns_false_for_unknown( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that heartbeat fails for unknown connections. Unknown connections should not be able to send heartbeats, returning False to indicate failure. """ mock_redis.exists.return_value = False result = await manager.update_heartbeat("unknown-sid") assert result is False class TestQueryMethods: """Tests for connection query methods.""" @pytest.mark.asyncio async def test_get_connection_returns_info( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that get_connection returns ConnectionInfo. The returned info should contain all the stored connection data. """ sid = "test-sid" now = datetime.now(UTC) mock_redis.hgetall.return_value = { "user_id": "user-123", "game_id": "game-456", "connected_at": now.isoformat(), "last_seen": now.isoformat(), } result = await manager.get_connection(sid) assert result is not None assert result.sid == sid assert result.user_id == "user-123" assert result.game_id == "game-456" @pytest.mark.asyncio async def test_get_connection_returns_none_for_unknown( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that get_connection returns None for unknown sid. Non-existent connections should return None rather than raising. """ mock_redis.hgetall.return_value = {} result = await manager.get_connection("unknown-sid") assert result is None @pytest.mark.asyncio async def test_get_connection_returns_none_for_corrupted_data( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that get_connection returns None for corrupted records. If a connection record is missing required fields (user_id, connected_at, last_seen), it indicates data corruption. The method should return None and log a warning rather than returning a ConnectionInfo with invalid/empty data. """ # Record missing user_id mock_redis.hgetall.return_value = { "game_id": "game-123", "connected_at": "2024-01-01T00:00:00+00:00", "last_seen": "2024-01-01T00:00:00+00:00", } result = await manager.get_connection("corrupted-sid") assert result is None @pytest.mark.asyncio async def test_is_user_online_returns_true_for_connected_user( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that is_user_online returns True for connected users. Users with active, non-stale connections should be considered online. """ user_id = "user-123" now = datetime.now(UTC) mock_redis.get.return_value = "test-sid" mock_redis.hgetall.return_value = { "user_id": user_id, "game_id": "", "connected_at": now.isoformat(), "last_seen": now.isoformat(), } result = await manager.is_user_online(user_id) assert result is True @pytest.mark.asyncio async def test_is_user_online_returns_false_for_stale_connection( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that is_user_online returns False for stale connections. Users with stale connections (no recent heartbeat) should be considered offline even if their connection record exists. """ user_id = "user-123" old_time = datetime.now(UTC) - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4) mock_redis.get.return_value = "test-sid" mock_redis.hgetall.return_value = { "user_id": user_id, "game_id": "", "connected_at": old_time.isoformat(), "last_seen": old_time.isoformat(), } result = await manager.is_user_online(user_id) assert result is False @pytest.mark.asyncio async def test_get_game_connections_returns_all_participants( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that get_game_connections returns all game participants. Should return ConnectionInfo for each sid in the game's connection set. """ game_id = "game-123" now = datetime.now(UTC) mock_redis.smembers.return_value = {"sid-1", "sid-2"} # Use a function to return the right data based on the key queried def hgetall_by_key(key: str) -> dict[str, str]: if key == f"{CONN_PREFIX}sid-1": return { "user_id": "user-1", "game_id": game_id, "connected_at": now.isoformat(), "last_seen": now.isoformat(), } elif key == f"{CONN_PREFIX}sid-2": return { "user_id": "user-2", "game_id": game_id, "connected_at": now.isoformat(), "last_seen": now.isoformat(), } return {} mock_redis.hgetall.side_effect = hgetall_by_key result = await manager.get_game_connections(game_id) assert len(result) == 2 user_ids = {conn.user_id for conn in result} assert user_ids == {"user-1", "user-2"} @pytest.mark.asyncio async def test_get_opponent_sid_returns_other_player( self, manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: """Test that get_opponent_sid returns the other player's sid. In a 2-player game, this should return the sid of the player who is not the current user. """ game_id = "game-123" current_user = "user-1" opponent_user = "user-2" opponent_sid = "sid-2" now = datetime.now(UTC) mock_redis.smembers.return_value = {"sid-1", "sid-2"} # Use a function to return the right data based on the key queried # This avoids dependency on set iteration order def hgetall_by_key(key: str) -> dict[str, str]: if key == f"{CONN_PREFIX}sid-1": return { "user_id": current_user, "game_id": game_id, "connected_at": now.isoformat(), "last_seen": now.isoformat(), } elif key == f"{CONN_PREFIX}sid-2": return { "user_id": opponent_user, "game_id": game_id, "connected_at": now.isoformat(), "last_seen": now.isoformat(), } return {} mock_redis.hgetall.side_effect = hgetall_by_key result = await manager.get_opponent_sid(game_id, current_user) assert result == opponent_sid class TestKeyGeneration: """Tests for Redis key generation methods.""" def test_conn_key_format(self) -> None: """Test that connection keys have correct format. Keys should follow the pattern conn:{sid} for easy identification and pattern matching. """ manager = ConnectionManager() key = manager._conn_key("test-sid-123") assert key == "conn:test-sid-123" def test_user_conn_key_format(self) -> None: """Test that user connection keys have correct format. Keys should follow the pattern user_conn:{user_id}. """ manager = ConnectionManager() key = manager._user_conn_key("user-456") assert key == "user_conn:user-456" def test_game_conns_key_format(self) -> None: """Test that game connections keys have correct format. Keys should follow the pattern game_conns:{game_id}. """ manager = ConnectionManager() key = manager._game_conns_key("game-789") assert key == "game_conns:game-789"