Fix audit warnings about empty string defaults hiding data corruption: 1. get_connection_info(): Validate required fields (user_id, connected_at, last_seen) exist before creating ConnectionInfo. Return None and log warning for corrupted records instead of returning invalid data. 2. unregister_connection(): Log warning if user_id is missing during cleanup. Existing code safely handles this (skips cleanup), but now we have visibility into data corruption. Test added for corrupted data case. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
656 lines
20 KiB
Python
656 lines
20 KiB
Python
"""Tests for ConnectionManager service.

This module tests WebSocket connection tracking with Redis. Since these are
unit tests, we inject a mock Redis factory to test the ConnectionManager logic
without requiring a real Redis instance.

Uses the dependency injection pattern - no monkey patching required.
"""

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock
from uuid import uuid4

import pytest

from app.services.connection_manager import (
    CONN_PREFIX,
    GAME_CONNS_PREFIX,
    HEARTBEAT_INTERVAL_SECONDS,
    USER_CONN_PREFIX,
    ConnectionInfo,
    ConnectionManager,
)


@pytest.fixture
def mock_redis() -> AsyncMock:
    """Create a mock Redis client.

    This mock provides all the Redis operations used by ConnectionManager
    with sensible defaults that can be overridden per-test.

    Read-style operations default to "nothing stored": ``hgetall`` returns
    an empty dict, ``get`` returns None, ``exists`` returns False,
    ``smembers`` returns an empty set and ``scard`` returns 0. The remaining
    operations (``hset``, ``hget``, ``set``, ``delete``, ``expire``,
    ``sadd``, ``srem``) have no configured return value and exist so tests
    can set one or assert on their calls.
    """
    redis = AsyncMock()
    redis.hset = AsyncMock()
    redis.hget = AsyncMock()
    redis.hgetall = AsyncMock(return_value={})
    redis.set = AsyncMock()
    redis.get = AsyncMock(return_value=None)
    redis.delete = AsyncMock()
    redis.exists = AsyncMock(return_value=False)
    redis.expire = AsyncMock()
    redis.sadd = AsyncMock()
    redis.srem = AsyncMock()
    redis.smembers = AsyncMock(return_value=set())
    redis.scard = AsyncMock(return_value=0)
    return redis


@pytest.fixture
def manager(mock_redis: AsyncMock) -> ConnectionManager:
    """Create a ConnectionManager with injected mock Redis.

    The mock Redis is injected via the redis_factory parameter,
    eliminating the need for monkey patching in tests. Every context
    entered from the factory yields the same ``mock_redis`` object, so a
    test can configure it before calling the manager and inspect its
    calls afterwards.
    """

    @asynccontextmanager
    async def mock_redis_factory() -> AsyncIterator[AsyncMock]:
        yield mock_redis

    return ConnectionManager(conn_ttl_seconds=3600, redis_factory=mock_redis_factory)


class TestConnectionInfoDataclass:
    """Tests for the ConnectionInfo dataclass."""

    def test_is_stale_returns_false_for_recent_connection(self) -> None:
        """Test that recent connections are not marked as stale.

        Connections that have been seen within the threshold should be
        considered active, not stale.
        """
        now = datetime.now(UTC)
        info = ConnectionInfo(
            sid="test-sid",
            user_id="user-123",
            game_id=None,
            connected_at=now,
            last_seen=now,
        )

        assert info.is_stale() is False

    def test_is_stale_returns_true_for_old_connection(self) -> None:
        """Test that old connections are marked as stale.

        Connections that haven't been seen for longer than the threshold
        should be considered stale and eligible for cleanup.
        """
        now = datetime.now(UTC)
        # Four missed heartbeat intervals; this test relies on that age
        # exceeding is_stale()'s default threshold.
        old_time = now - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4)
        info = ConnectionInfo(
            sid="test-sid",
            user_id="user-123",
            game_id=None,
            connected_at=old_time,
            last_seen=old_time,
        )

        assert info.is_stale() is True

    def test_is_stale_with_custom_threshold(self) -> None:
        """Test is_stale with a custom threshold.

        The threshold parameter allows callers to customize what counts
        as stale for different use cases.
        """
        now = datetime.now(UTC)
        last_seen = now - timedelta(seconds=60)
        info = ConnectionInfo(
            sid="test-sid",
            user_id="user-123",
            game_id=None,
            connected_at=last_seen,
            last_seen=last_seen,
        )

        # 60 seconds old, with 30 second threshold = stale
        assert info.is_stale(threshold_seconds=30) is True

        # 60 seconds old, with 120 second threshold = not stale
        assert info.is_stale(threshold_seconds=120) is False


class TestRegisterConnection:
    """Tests for connection registration."""

    @pytest.mark.asyncio
    async def test_register_connection_creates_records(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that registering a connection creates all necessary Redis records.

        Registration should create:
        1. Connection hash with user_id, game_id, timestamps
        2. User-to-connection mapping
        """
        sid = "test-sid-123"
        user_id = str(uuid4())

        await manager.register_connection(sid, user_id)

        # Verify connection hash was created
        conn_key = f"{CONN_PREFIX}{sid}"
        mock_redis.hset.assert_called()
        hset_call = mock_redis.hset.call_args
        assert hset_call.args[0] == conn_key
        mapping = hset_call.kwargs["mapping"]
        assert mapping["user_id"] == user_id
        # A freshly registered connection is not in a game; stored as "".
        assert mapping["game_id"] == ""

        # Verify user-to-connection mapping was created
        user_conn_key = f"{USER_CONN_PREFIX}{user_id}"
        mock_redis.set.assert_called_with(user_conn_key, sid)

    @pytest.mark.asyncio
    async def test_register_connection_replaces_old_connection(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that registering a new connection replaces the old one.

        When a user reconnects, their old connection should be cleaned up
        to prevent stale connection data from lingering.
        """
        old_sid = "old-sid"
        new_sid = "new-sid"
        user_id = str(uuid4())

        # Mock: user has existing connection
        mock_redis.get.return_value = old_sid
        mock_redis.hgetall.return_value = {
            "user_id": user_id,
            "game_id": "",
            "connected_at": datetime.now(UTC).isoformat(),
            "last_seen": datetime.now(UTC).isoformat(),
        }

        await manager.register_connection(new_sid, user_id)

        # Verify old connection was cleaned up (delete called for old conn key).
        # Flatten every positional argument so the check holds whether keys
        # are deleted one per call or several in a single delete() call.
        old_conn_key = f"{CONN_PREFIX}{old_sid}"
        deleted_keys = [
            key for call in mock_redis.delete.call_args_list for key in call.args
        ]
        assert old_conn_key in deleted_keys

    @pytest.mark.asyncio
    async def test_register_connection_accepts_uuid(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that register_connection accepts UUID objects.

        The user_id can be passed as either a string or UUID object
        for convenience.
        """
        sid = "test-sid"
        user_uuid = uuid4()

        await manager.register_connection(sid, user_uuid)

        # Verify user_id was converted to string
        hset_call = mock_redis.hset.call_args
        mapping = hset_call.kwargs["mapping"]
        assert mapping["user_id"] == str(user_uuid)


class TestUnregisterConnection:
    """Tests for connection unregistration."""

    @pytest.mark.asyncio
    async def test_unregister_returns_none_for_unknown_sid(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that unregistering unknown connection returns None.

        Unregistering a non-existent connection should be a no-op
        and return None to indicate nothing was cleaned up.
        """
        mock_redis.hgetall.return_value = {}

        result = await manager.unregister_connection("unknown-sid")

        assert result is None

    @pytest.mark.asyncio
    async def test_unregister_cleans_up_all_data(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that unregistering cleans up all Redis data.

        Unregistration should remove:
        1. Connection hash
        2. User-to-connection mapping
        3. Remove from game connection set (if in a game)
        """
        sid = "test-sid"
        user_id = "user-123"
        game_id = "game-456"
        now = datetime.now(UTC)

        mock_redis.hgetall.return_value = {
            "user_id": user_id,
            "game_id": game_id,
            "connected_at": now.isoformat(),
            "last_seen": now.isoformat(),
        }
        mock_redis.get.return_value = sid  # User's current connection

        result = await manager.unregister_connection(sid)

        assert result is not None
        assert result.user_id == user_id
        assert result.game_id == game_id

        # Verify the specific keys were deleted, not merely that delete() ran.
        # Flatten all positional args so single- and multi-key deletes both count.
        deleted_keys = [
            key for call in mock_redis.delete.call_args_list for key in call.args
        ]
        assert f"{CONN_PREFIX}{sid}" in deleted_keys
        assert f"{USER_CONN_PREFIX}{user_id}" in deleted_keys
        mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)

    @pytest.mark.asyncio
    async def test_unregister_skips_user_cleanup_for_corrupted_record(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that a record missing user_id never touches user mappings.

        A connection hash without user_id indicates data corruption. Cleanup
        must not raise, and must not delete any user-to-connection mapping,
        because there is no user_id to build that key from.
        """
        now = datetime.now(UTC)

        # Record missing user_id
        mock_redis.hgetall.return_value = {
            "game_id": "game-456",
            "connected_at": now.isoformat(),
            "last_seen": now.isoformat(),
        }

        await manager.unregister_connection("corrupted-sid")

        deleted_keys = [
            key for call in mock_redis.delete.call_args_list for key in call.args
        ]
        assert not any(str(key).startswith(USER_CONN_PREFIX) for key in deleted_keys)


class TestGameAssociation:
    """Tests for game association methods."""

    @pytest.mark.asyncio
    async def test_join_game_adds_to_game_set(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that joining a game adds connection to game's set.

        The connection should be tracked in the game's connection set
        for broadcasting updates to all participants.
        """
        sid = "test-sid"
        game_id = "game-123"

        mock_redis.exists.return_value = True
        mock_redis.hget.return_value = ""  # Not in a game yet

        result = await manager.join_game(sid, game_id)

        assert result is True
        mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
        mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id)

    @pytest.mark.asyncio
    async def test_join_game_returns_false_for_unknown_connection(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that joining a game fails for unknown connections.

        Non-existent connections should not be able to join games.
        """
        mock_redis.exists.return_value = False

        result = await manager.join_game("unknown-sid", "game-123")

        assert result is False

    @pytest.mark.asyncio
    async def test_join_game_leaves_previous_game(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that joining a new game leaves the previous game.

        A connection can only be in one game at a time, so joining
        a new game should automatically leave any previous game.
        """
        sid = "test-sid"
        old_game = "old-game"
        new_game = "new-game"

        mock_redis.exists.return_value = True
        mock_redis.hget.return_value = old_game

        await manager.join_game(sid, new_game)

        # Should remove from old game
        mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid)
        # Should add to new game
        mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid)

    @pytest.mark.asyncio
    async def test_leave_game_removes_from_set(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that leaving a game removes connection from set.

        The connection should be removed from the game's set and
        its game_id should be cleared.
        """
        sid = "test-sid"
        game_id = "game-123"

        mock_redis.hget.return_value = game_id

        result = await manager.leave_game(sid)

        assert result == game_id
        mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
        # Clearing is done by writing an empty string, not deleting the field.
        mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")

    @pytest.mark.asyncio
    async def test_leave_game_returns_none_when_not_in_game(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that leaving when not in a game returns None.

        If the connection isn't in a game, leave_game should return
        None to indicate no game was left.
        """
        mock_redis.hget.return_value = ""

        result = await manager.leave_game("test-sid")

        assert result is None


class TestHeartbeat:
    """Tests for heartbeat/activity tracking."""

    @pytest.mark.asyncio
    async def test_update_heartbeat_refreshes_last_seen(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that heartbeat updates the last_seen timestamp.

        Heartbeats keep connections alive by updating timestamps
        and refreshing TTLs on Redis keys.
        """
        sid = "test-sid"
        mock_redis.exists.return_value = True
        mock_redis.hget.return_value = "user-123"

        result = await manager.update_heartbeat(sid)

        assert result is True
        mock_redis.hset.assert_called()
        mock_redis.expire.assert_called()

    @pytest.mark.asyncio
    async def test_update_heartbeat_returns_false_for_unknown(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that heartbeat fails for unknown connections.

        Unknown connections should not be able to send heartbeats,
        returning False to indicate failure.
        """
        mock_redis.exists.return_value = False

        result = await manager.update_heartbeat("unknown-sid")

        assert result is False


class TestQueryMethods:
    """Tests for connection query methods."""

    @pytest.mark.asyncio
    async def test_get_connection_returns_info(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_connection returns ConnectionInfo.

        The returned info should contain all the stored connection data.
        """
        sid = "test-sid"
        now = datetime.now(UTC)

        mock_redis.hgetall.return_value = {
            "user_id": "user-123",
            "game_id": "game-456",
            "connected_at": now.isoformat(),
            "last_seen": now.isoformat(),
        }

        result = await manager.get_connection(sid)

        assert result is not None
        assert result.sid == sid
        assert result.user_id == "user-123"
        assert result.game_id == "game-456"

    @pytest.mark.asyncio
    async def test_get_connection_returns_none_for_unknown(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_connection returns None for unknown sid.

        Non-existent connections should return None rather than raising.
        """
        mock_redis.hgetall.return_value = {}

        result = await manager.get_connection("unknown-sid")

        assert result is None

    @pytest.mark.asyncio
    @pytest.mark.parametrize("missing_field", ["user_id", "connected_at", "last_seen"])
    async def test_get_connection_returns_none_for_corrupted_data(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
        missing_field: str,
    ) -> None:
        """Test that get_connection returns None for corrupted records.

        If a connection record is missing required fields (user_id,
        connected_at, last_seen), it indicates data corruption. The
        method should return None and log a warning rather than
        returning a ConnectionInfo with invalid/empty data.

        Each required field is dropped in turn so every validation branch
        is exercised, not just the user_id one.
        """
        record = {
            "user_id": "user-123",
            "game_id": "game-123",
            "connected_at": "2024-01-01T00:00:00+00:00",
            "last_seen": "2024-01-01T00:00:00+00:00",
        }
        del record[missing_field]
        mock_redis.hgetall.return_value = record

        result = await manager.get_connection("corrupted-sid")

        assert result is None

    @pytest.mark.asyncio
    async def test_is_user_online_returns_true_for_connected_user(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that is_user_online returns True for connected users.

        Users with active, non-stale connections should be considered online.
        """
        user_id = "user-123"
        now = datetime.now(UTC)

        mock_redis.get.return_value = "test-sid"
        mock_redis.hgetall.return_value = {
            "user_id": user_id,
            "game_id": "",
            "connected_at": now.isoformat(),
            "last_seen": now.isoformat(),
        }

        result = await manager.is_user_online(user_id)

        assert result is True

    @pytest.mark.asyncio
    async def test_is_user_online_returns_false_for_stale_connection(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that is_user_online returns False for stale connections.

        Users with stale connections (no recent heartbeat) should be
        considered offline even if their connection record exists.
        """
        user_id = "user-123"
        old_time = datetime.now(UTC) - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4)

        mock_redis.get.return_value = "test-sid"
        mock_redis.hgetall.return_value = {
            "user_id": user_id,
            "game_id": "",
            "connected_at": old_time.isoformat(),
            "last_seen": old_time.isoformat(),
        }

        result = await manager.is_user_online(user_id)

        assert result is False

    @pytest.mark.asyncio
    async def test_get_game_connections_returns_all_participants(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_game_connections returns all game participants.

        Should return ConnectionInfo for each sid in the game's connection set.
        """
        game_id = "game-123"
        now = datetime.now(UTC)

        mock_redis.smembers.return_value = {"sid-1", "sid-2"}

        records = {
            f"{CONN_PREFIX}sid-1": {
                "user_id": "user-1",
                "game_id": game_id,
                "connected_at": now.isoformat(),
                "last_seen": now.isoformat(),
            },
            f"{CONN_PREFIX}sid-2": {
                "user_id": "user-2",
                "game_id": game_id,
                "connected_at": now.isoformat(),
                "last_seen": now.isoformat(),
            },
        }
        # Answer by queried key (empty dict for unknown keys) so the test
        # does not depend on set iteration order; copy to keep records pristine.
        mock_redis.hgetall.side_effect = lambda key: dict(records.get(key, {}))

        result = await manager.get_game_connections(game_id)

        assert len(result) == 2
        user_ids = {conn.user_id for conn in result}
        assert user_ids == {"user-1", "user-2"}

    @pytest.mark.asyncio
    async def test_get_opponent_sid_returns_other_player(
        self,
        manager: ConnectionManager,
        mock_redis: AsyncMock,
    ) -> None:
        """Test that get_opponent_sid returns the other player's sid.

        In a 2-player game, this should return the sid of the player
        who is not the current user.
        """
        game_id = "game-123"
        current_user = "user-1"
        opponent_user = "user-2"
        opponent_sid = "sid-2"
        now = datetime.now(UTC)

        mock_redis.smembers.return_value = {"sid-1", "sid-2"}

        records = {
            f"{CONN_PREFIX}sid-1": {
                "user_id": current_user,
                "game_id": game_id,
                "connected_at": now.isoformat(),
                "last_seen": now.isoformat(),
            },
            f"{CONN_PREFIX}{opponent_sid}": {
                "user_id": opponent_user,
                "game_id": game_id,
                "connected_at": now.isoformat(),
                "last_seen": now.isoformat(),
            },
        }
        # Answer by queried key (empty dict for unknown keys) so the test
        # does not depend on set iteration order; copy to keep records pristine.
        mock_redis.hgetall.side_effect = lambda key: dict(records.get(key, {}))

        result = await manager.get_opponent_sid(game_id, current_user)

        assert result == opponent_sid


class TestKeyGeneration:
    """Tests for Redis key generation methods.

    The expected keys are spelled out literally (rather than built from the
    *_PREFIX constants) so an accidental change to a prefix is caught.
    """

    def test_conn_key_format(self) -> None:
        """Test that connection keys have correct format.

        Keys should follow the pattern conn:{sid} for easy identification
        and pattern matching.
        """
        manager = ConnectionManager()
        key = manager._conn_key("test-sid-123")
        assert key == "conn:test-sid-123"

    def test_user_conn_key_format(self) -> None:
        """Test that user connection keys have correct format.

        Keys should follow the pattern user_conn:{user_id}.
        """
        manager = ConnectionManager()
        key = manager._user_conn_key("user-456")
        assert key == "user_conn:user-456"

    def test_game_conns_key_format(self) -> None:
        """Test that game connections keys have correct format.

        Keys should follow the pattern game_conns:{game_id}.
        """
        manager = ConnectionManager()
        key = manager._game_conns_key("game-789")
        assert key == "game_conns:game-789"