- ConnectionManager: Add redis_factory constructor parameter - GameService: Add engine_factory constructor parameter - AuthHandler: New class replacing standalone functions with token_verifier and conn_manager injection - Update all tests to use constructor DI instead of patch() - Update CLAUDE.md with factory injection patterns - Update services README with new patterns - Add socketio README documenting AuthHandler and events Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
594 lines
19 KiB
Python
594 lines
19 KiB
Python
"""WebSocket connection management with Redis-backed session tracking.
|
|
|
|
This module manages WebSocket connection state using Redis for persistence
|
|
and cross-instance coordination. It tracks:
|
|
- Individual connection sessions (sid -> user_id, game_id, timestamps)
|
|
- User to connection mapping (user_id -> sid)
|
|
- Game participants (game_id -> set of sids)
|
|
|
|
Key Patterns:
|
|
conn:{sid} - Hash with connection details (user_id, game_id, connected_at, last_seen)
|
|
user_conn:{user_id} - String with active sid
|
|
game_conns:{game_id} - Set of sids for game participants
|
|
|
|
Lifecycle:
|
|
1. on_connect: Register connection, map user to sid
|
|
2. on_join_game: Add sid to game's connection set
|
|
3. on_heartbeat: Update last_seen timestamp
|
|
4. on_disconnect: Remove from all tracking structures
|
|
|
|
Example:
|
|
manager = ConnectionManager()
|
|
|
|
# On WebSocket connect
|
|
await manager.register_connection(sid, user_id)
|
|
|
|
# On joining a game
|
|
await manager.join_game(sid, game_id)
|
|
|
|
# Check if opponent is online
|
|
if await manager.is_user_online(opponent_id):
|
|
# Send real-time update
|
|
|
|
# On disconnect
|
|
await manager.unregister_connection(sid)
|
|
"""
|
|
|
|
import logging
from collections.abc import AsyncIterator, Callable
from contextlib import AbstractAsyncContextManager
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from uuid import UUID

from app.db.redis import get_redis

if TYPE_CHECKING:
    from redis.asyncio import Redis
|
|
|
|
logger = logging.getLogger(__name__)

# Type alias for the Redis factory: a zero-argument callable whose result is
# entered with ``async with`` and yields a Redis client. ConnectionManager only
# ever uses it as ``async with self._get_redis() as redis``, so the result must
# be an async context manager (not an async iterator).
RedisFactory = Callable[[], AbstractAsyncContextManager["Redis"]]

# Redis key patterns
CONN_PREFIX = "conn:"  # conn:{sid} -> hash of connection details
USER_CONN_PREFIX = "user_conn:"  # user_conn:{user_id} -> active sid
GAME_CONNS_PREFIX = "game_conns:"  # game_conns:{game_id} -> set of sids

# Connection TTL (auto-expire stale connections)
DEFAULT_CONN_TTL_SECONDS = 3600  # 1 hour

# Nominal heartbeat interval; ConnectionInfo.is_stale and
# ConnectionManager.cleanup_stale_connections default to three times this value.
HEARTBEAT_INTERVAL_SECONDS = 30
|
|
|
|
|
|
@dataclass
class ConnectionInfo:
    """Snapshot of a single WebSocket connection as stored in Redis.

    Attributes:
        sid: Socket.IO session ID.
        user_id: Authenticated user's ID, kept in string form.
        game_id: ID of the game this connection has joined, or None when it is
            not in a game.
        connected_at: Timestamp at which the connection was registered.
        last_seen: Timestamp of the most recent heartbeat or activity.
    """

    sid: str
    user_id: str
    game_id: str | None
    connected_at: datetime
    last_seen: datetime

    def is_stale(self, threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3) -> bool:
        """Report whether this connection has been quiet for too long.

        Args:
            threshold_seconds: Largest acceptable age of ``last_seen``, in seconds.

        Returns:
            True when strictly more than ``threshold_seconds`` have elapsed since
            ``last_seen``, otherwise False.
        """
        age = datetime.now(UTC) - self.last_seen
        idle_seconds = age.total_seconds()
        return idle_seconds > threshold_seconds
|
|
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections with Redis-backed session tracking.

    Provides methods to:
    - Register/unregister connections
    - Track which game a connection is participating in
    - Check user online status
    - Get all connections for a game
    - Handle heartbeats and stale connection detection

    Every method acquires its own Redis client through the configured factory
    (``async with self._get_redis() as redis``); no client is kept between
    calls. Multi-step updates are issued as separate commands and are not
    atomic.

    Attributes:
        conn_ttl_seconds: TTL for connection records in Redis. The same TTL is
            applied to user->sid mappings and to game connection sets.
    """

    def __init__(
        self,
        conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS,
        redis_factory: RedisFactory | None = None,
    ) -> None:
        """Initialize the ConnectionManager.

        Args:
            conn_ttl_seconds: How long to keep connection records in Redis.
            redis_factory: Optional factory for Redis connections. If not provided,
                uses the default get_redis from app.db.redis. Useful for testing.
        """
        self.conn_ttl_seconds = conn_ttl_seconds
        self._get_redis = redis_factory if redis_factory is not None else get_redis

    def _conn_key(self, sid: str) -> str:
        """Generate Redis key for a connection."""
        return f"{CONN_PREFIX}{sid}"

    def _user_conn_key(self, user_id: str) -> str:
        """Generate Redis key for user-to-connection mapping."""
        return f"{USER_CONN_PREFIX}{user_id}"

    def _game_conns_key(self, game_id: str) -> str:
        """Generate Redis key for game connection set."""
        return f"{GAME_CONNS_PREFIX}{game_id}"

    # =========================================================================
    # Connection Lifecycle
    # =========================================================================

    async def register_connection(
        self,
        sid: str,
        user_id: str | UUID,
    ) -> None:
        """Register a new WebSocket connection.

        Creates connection record in Redis and maps user to this connection.
        If user had a previous connection, it will be replaced.

        Args:
            sid: Socket.IO session ID.
            user_id: Authenticated user's ID (UUID or string).

        Example:
            await manager.register_connection("abc123", user.id)
        """
        user_id_str = str(user_id)
        now = datetime.now(UTC).isoformat()

        async with self._get_redis() as redis:
            # Only one connection per user is tracked: tear down the previous
            # one so it does not linger in a game's connection set.
            old_sid = await redis.get(self._user_conn_key(user_id_str))
            if old_sid and old_sid != sid:
                logger.info("Replacing old connection %s for user %s", old_sid, user_id_str)
                await self._cleanup_connection(old_sid)

            # Create connection record (hash); a single HSET with a mapping
            # writes all fields at once.
            conn_key = self._conn_key(sid)
            await redis.hset(
                conn_key,
                mapping={
                    "user_id": user_id_str,
                    "game_id": "",
                    "connected_at": now,
                    "last_seen": now,
                },
            )
            await redis.expire(conn_key, self.conn_ttl_seconds)

            # Map user to this connection
            user_conn_key = self._user_conn_key(user_id_str)
            await redis.set(user_conn_key, sid)
            await redis.expire(user_conn_key, self.conn_ttl_seconds)

        logger.debug("Registered connection: sid=%s, user_id=%s", sid, user_id_str)

    async def unregister_connection(self, sid: str) -> ConnectionInfo | None:
        """Unregister a WebSocket connection and clean up all related data.

        Removes connection from:
        - Connection record
        - User-to-connection mapping
        - Any game connection sets

        Args:
            sid: Socket.IO session ID.

        Returns:
            ConnectionInfo of the removed connection, or None if not found.

        Example:
            info = await manager.unregister_connection("abc123")
            if info and info.game_id:
                # Notify opponent of disconnect
        """
        # Get connection info before cleanup so callers can act on it.
        info = await self.get_connection(sid)
        if info is None:
            logger.debug("Connection not found for unregister: %s", sid)
            return None

        await self._cleanup_connection(sid)
        logger.debug("Unregistered connection: sid=%s, user_id=%s", sid, info.user_id)
        return info

    async def _cleanup_connection(self, sid: str) -> bool:
        """Internal cleanup of a connection's Redis data.

        Removes the sid from its game's connection set, deletes the user->sid
        mapping only if it still points at this sid, then deletes the
        connection hash.

        Args:
            sid: Socket.IO session ID.

        Returns:
            True if a connection record existed and was removed, False if there
            was nothing to clean up.
        """
        async with self._get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Get connection data for cleanup
            data = await redis.hgetall(conn_key)
            if not data:
                return False

            user_id = data.get("user_id", "")
            game_id = data.get("game_id", "")

            # Remove from game connection set if in a game
            if game_id:
                await redis.srem(self._game_conns_key(game_id), sid)

            # Remove the user mapping only if it points to this sid; a newer
            # connection for the same user must keep its mapping.
            if user_id:
                user_conn_key = self._user_conn_key(user_id)
                current_sid = await redis.get(user_conn_key)
                if current_sid == sid:
                    await redis.delete(user_conn_key)

            # Delete connection record
            await redis.delete(conn_key)

        return True

    # =========================================================================
    # Game Association
    # =========================================================================

    async def join_game(self, sid: str, game_id: str) -> bool:
        """Associate a connection with a game.

        Updates the connection's game_id and adds it to the game's connection set.
        If connection was in a different game, leaves that game first.

        Args:
            sid: Socket.IO session ID.
            game_id: Game to join.

        Returns:
            True if successful, False if connection not found.

        Example:
            await manager.join_game("abc123", "game-456")
        """
        async with self._get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Check connection exists
            if not await redis.exists(conn_key):
                logger.warning("Cannot join game: connection not found %s", sid)
                return False

            # Leave current game if in one
            current_game = await redis.hget(conn_key, "game_id")
            if current_game and current_game != game_id:
                await redis.srem(self._game_conns_key(current_game), sid)
                logger.debug("Connection %s left game %s", sid, current_game)

            # Update connection's game_id
            await redis.hset(conn_key, "game_id", game_id)

            # Add to game's connection set
            game_conns_key = self._game_conns_key(game_id)
            await redis.sadd(game_conns_key, sid)

            # Set TTL on game conns set (cleanup if game becomes inactive);
            # update_heartbeat keeps refreshing it while participants are alive.
            await redis.expire(game_conns_key, self.conn_ttl_seconds)

        logger.debug("Connection %s joined game %s", sid, game_id)
        return True

    async def leave_game(self, sid: str) -> str | None:
        """Remove a connection from its current game.

        Args:
            sid: Socket.IO session ID.

        Returns:
            The game_id that was left, or None if not in a game.

        Example:
            game_id = await manager.leave_game("abc123")
        """
        async with self._get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Get current game
            game_id = await redis.hget(conn_key, "game_id")
            if not game_id:
                return None

            # Remove from game's connection set
            await redis.srem(self._game_conns_key(game_id), sid)

            # Clear game_id on connection
            await redis.hset(conn_key, "game_id", "")

        logger.debug("Connection %s left game %s", sid, game_id)
        return game_id

    # =========================================================================
    # Heartbeat / Activity
    # =========================================================================

    async def update_heartbeat(self, sid: str) -> bool:
        """Update the last_seen timestamp for a connection.

        Should be called on each heartbeat message from the client. Also
        refreshes the TTL on the connection record, on the user->sid mapping
        and, when the connection is in a game, on that game's connection set.

        Args:
            sid: Socket.IO session ID.

        Returns:
            True if successful, False if connection not found.

        Example:
            await manager.update_heartbeat("abc123")
        """
        now = datetime.now(UTC).isoformat()

        async with self._get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Not atomic: the key may expire between this check and hset(), in
            # which case hset() recreates a partial hash. get_connection treats
            # such records as absent.
            if not await redis.exists(conn_key):
                return False

            # Update last_seen and refresh TTL
            await redis.hset(conn_key, "last_seen", now)
            await redis.expire(conn_key, self.conn_ttl_seconds)

            # Also refresh user mapping TTL
            user_id = await redis.hget(conn_key, "user_id")
            if user_id:
                await redis.expire(self._user_conn_key(user_id), self.conn_ttl_seconds)

            # Keep the game's connection set alive while a participant is
            # heartbeating; otherwise it expires conn_ttl_seconds after the last
            # join_game even though players are still connected.
            game_id = await redis.hget(conn_key, "game_id")
            if game_id:
                await redis.expire(self._game_conns_key(game_id), self.conn_ttl_seconds)

        return True

    # =========================================================================
    # Query Methods
    # =========================================================================

    async def get_connection(self, sid: str) -> ConnectionInfo | None:
        """Get connection info by session ID.

        Args:
            sid: Socket.IO session ID.

        Returns:
            ConnectionInfo if found, None otherwise. A record whose
            connected_at/last_seen fields are missing or unparseable is logged
            and reported as None rather than raising.

        Example:
            info = await manager.get_connection("abc123")
            if info:
                print(f"User {info.user_id} connected at {info.connected_at}")
        """
        async with self._get_redis() as redis:
            data = await redis.hgetall(self._conn_key(sid))

        if not data:
            return None

        try:
            connected_at = datetime.fromisoformat(data["connected_at"])
            last_seen = datetime.fromisoformat(data["last_seen"])
        except (KeyError, TypeError, ValueError):
            # A partial hash can be left behind when update_heartbeat/join_game
            # write to a key that expired after their exists() check. Raising
            # here would abort callers such as cleanup_stale_connections.
            logger.warning("Malformed connection record for sid=%s", sid)
            return None

        return ConnectionInfo(
            sid=sid,
            user_id=data.get("user_id", ""),
            game_id=data.get("game_id") or None,
            connected_at=connected_at,
            last_seen=last_seen,
        )

    async def get_user_connection(self, user_id: str | UUID) -> ConnectionInfo | None:
        """Get connection info for a user.

        Args:
            user_id: User's UUID or string ID.

        Returns:
            ConnectionInfo if user is connected, None otherwise.

        Example:
            info = await manager.get_user_connection(opponent_id)
            if info:
                # User is online
        """
        async with self._get_redis() as redis:
            sid = await redis.get(self._user_conn_key(str(user_id)))

        if not sid:
            return None

        # Fetched after releasing the client above so two Redis contexts are
        # not held at once.
        return await self.get_connection(sid)

    async def is_user_online(self, user_id: str | UUID) -> bool:
        """Check if a user has an active connection.

        Args:
            user_id: User's UUID or string ID.

        Returns:
            True if user is connected and not stale, False otherwise.

        Example:
            if await manager.is_user_online(opponent_id):
                # Send real-time notification
        """
        info = await self.get_user_connection(user_id)
        if info is None:
            return False

        # Check if connection is stale
        if info.is_stale():
            logger.debug("User %s connection is stale", user_id)
            return False

        return True

    async def get_game_connections(self, game_id: str) -> list[ConnectionInfo]:
        """Get all connections for a game.

        Sids in the game's set whose connection record no longer exists (or is
        malformed) are skipped.

        Args:
            game_id: Game ID.

        Returns:
            List of ConnectionInfo for all connected participants.

        Example:
            connections = await manager.get_game_connections("game-456")
            for conn in connections:
                # Send state update to each participant
        """
        async with self._get_redis() as redis:
            sids = await redis.smembers(self._game_conns_key(game_id))

        connections = []
        for sid in sids:
            info = await self.get_connection(sid)
            if info is not None:
                connections.append(info)

        return connections

    async def get_game_user_sids(self, game_id: str) -> dict[str, str]:
        """Get mapping of user_id to sid for a game.

        Useful for targeted message delivery to specific players.

        Args:
            game_id: Game ID.

        Returns:
            Dict mapping user_id to their sid.

        Example:
            user_sids = await manager.get_game_user_sids("game-456")
            player1_sid = user_sids.get(player1_id)
        """
        connections = await self.get_game_connections(game_id)
        return {conn.user_id: conn.sid for conn in connections}

    async def get_opponent_sid(
        self,
        game_id: str,
        current_user_id: str | UUID,
    ) -> str | None:
        """Get the sid of the opponent in a 2-player game.

        Args:
            game_id: Game ID.
            current_user_id: The current user's ID.

        Returns:
            Opponent's sid if found and connected, None otherwise. If several
            other users are connected, the first one found is returned (set
            iteration order is arbitrary).

        Example:
            opponent_sid = await manager.get_opponent_sid("game-456", user_id)
            if opponent_sid:
                # Send message to opponent
        """
        current_user_str = str(current_user_id)
        connections = await self.get_game_connections(game_id)

        for conn in connections:
            if conn.user_id != current_user_str:
                return conn.sid

        return None

    # =========================================================================
    # Maintenance
    # =========================================================================

    async def cleanup_stale_connections(
        self,
        threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3,
    ) -> int:
        """Clean up stale connections that haven't sent heartbeats.

        Should be called periodically by a background task. Records that
        exist but cannot be parsed are removed as well, so a single corrupt
        hash cannot block or outlive the sweep.

        Args:
            threshold_seconds: Seconds since last_seen to consider stale.

        Returns:
            Number of connection records actually removed.

        Example:
            # In background task
            count = await manager.cleanup_stale_connections()
            logger.info(f"Cleaned up {count} stale connections")
        """
        # Collect sids first so the scanning client is not held open across
        # the per-connection calls below (each opens its own Redis context).
        async with self._get_redis() as redis:
            sids = [
                key.removeprefix(CONN_PREFIX)
                async for key in redis.scan_iter(match=f"{CONN_PREFIX}*")
            ]

        count = 0
        for sid in sids:
            info = await self.get_connection(sid)
            if info is not None and not info.is_stale(threshold_seconds):
                continue

            # info is None either because the key vanished since the scan
            # (cleanup is then a no-op returning False) or because the record
            # is malformed (it gets removed).
            if await self._cleanup_connection(sid):
                count += 1
                logger.debug("Cleaned up stale connection: %s", sid)

        if count > 0:
            logger.info("Cleaned up %s stale connections", count)

        return count

    async def get_connection_count(self) -> int:
        """Get the total number of active connections.

        Useful for monitoring and admin dashboards. Counts ``conn:*`` keys via
        SCAN, so the result is approximate while keys are being added/removed.

        Returns:
            Number of active connections.
        """
        count = 0
        async with self._get_redis() as redis:
            async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"):
                count += 1
        return count

    async def get_game_connection_count(self, game_id: str) -> int:
        """Get the number of connections for a specific game.

        Args:
            game_id: Game ID.

        Returns:
            Number of sids in the game's connection set.
        """
        async with self._get_redis() as redis:
            return await redis.scard(self._game_conns_key(game_id))
|
|
|
|
|
|
# Shared module-level instance using the default TTL and the default
# get_redis factory (no redis_factory is passed). Code that needs a different
# Redis source, such as tests, can build its own
# ConnectionManager(redis_factory=...) instead.
connection_manager: ConnectionManager = ConnectionManager()
|