mantimon-tcg/backend/app/services/connection_manager.py
Cal Corum 0c810e5b30 Add Phase 4 WebSocket infrastructure (WS-001 through GS-001)
WebSocket Message Schemas (WS-002):
- Add Pydantic models for all client/server WebSocket messages
- Implement discriminated unions for message type parsing
- Include JoinGame, Action, Resign, Heartbeat client messages
- Include GameState, ActionResult, Error, TurnStart server messages

Connection Manager (WS-003):
- Add Redis-backed WebSocket connection tracking
- Implement user-to-sid mapping with TTL management
- Support game room association and opponent lookup
- Add heartbeat tracking for connection health

Socket.IO Authentication (WS-004):
- Add JWT-based authentication middleware
- Support token extraction from multiple formats
- Implement session setup with ConnectionManager integration
- Add require_auth helper for event handlers

Socket.IO Server Setup (WS-001):
- Configure AsyncServer with ASGI mode
- Register /game namespace with event handlers
- Integrate with FastAPI via ASGIApp wrapper
- Configure CORS from application settings

Game Service (GS-001):
- Add stateless GameService for game lifecycle orchestration
- Create engine per-operation using rules from GameState
- Implement action-based RNG seeding for deterministic replay
- Add rng_seed field to GameState for replay support

Architecture verified:
- Core module independence (no forbidden imports)
- Config from request pattern (rules in GameState)
- Dependency injection (constructor deps, method config)
- All 1090 tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 22:21:20 -06:00

579 lines
18 KiB
Python

"""WebSocket connection management with Redis-backed session tracking.

This module manages WebSocket connection state using Redis for persistence
and cross-instance coordination. It tracks:

- Individual connection sessions (sid -> user_id, game_id, timestamps)
- User to connection mapping (user_id -> sid); only one sid is kept per
  user, so registering a new connection replaces the previous one
- Game participants (game_id -> set of sids)

Key Patterns:
    conn:{sid}            - Hash with connection details
                            (user_id, game_id, connected_at, last_seen)
    user_conn:{user_id}   - String with the user's active sid
    game_conns:{game_id}  - Set of sids for game participants

Keys are written with a TTL (DEFAULT_CONN_TTL_SECONDS by default) so that
stale connection records auto-expire.

Lifecycle:
    1. on_connect: Register connection, map user to sid
    2. on_join_game: Add sid to game's connection set
    3. on_heartbeat: Update last_seen timestamp
    4. on_disconnect: Remove from all tracking structures

Example:
    manager = ConnectionManager()

    # On WebSocket connect
    await manager.register_connection(sid, user_id)

    # On joining a game
    await manager.join_game(sid, game_id)

    # Check if opponent is online
    if await manager.is_user_online(opponent_id):
        ...  # Send real-time update

    # On disconnect
    await manager.unregister_connection(sid)
"""
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID
from app.db.redis import get_redis
logger = logging.getLogger(__name__)

# Redis key namespaces. Each prefix is concatenated with an identifier
# (sid, user_id or game_id) to build the full key.
CONN_PREFIX = "conn:"  # hash: one record per connection
USER_CONN_PREFIX = "user_conn:"  # string: user_id -> active sid
GAME_CONNS_PREFIX = "game_conns:"  # set: sids participating in a game

# Expected client heartbeat cadence; staleness thresholds are multiples of it.
HEARTBEAT_INTERVAL_SECONDS = 30
# Connection TTL so stale records auto-expire from Redis.
DEFAULT_CONN_TTL_SECONDS = 60 * 60  # 1 hour
@dataclass
class ConnectionInfo:
    """Snapshot of a single WebSocket connection as read back from Redis.

    Attributes:
        sid: Socket.IO session ID.
        user_id: Authenticated user's ID, kept as a string.
        game_id: ID of the game this connection has joined, or None when it
            is not in a game.
        connected_at: Timestamp recorded when the connection was registered.
        last_seen: Timestamp of the most recent heartbeat or activity.
    """

    sid: str
    user_id: str
    game_id: str | None
    connected_at: datetime
    last_seen: datetime

    def is_stale(self, threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3) -> bool:
        """Report whether this connection has been silent for too long.

        Args:
            threshold_seconds: Largest tolerated gap, in seconds, between
                ``last_seen`` and the current UTC time.

        Returns:
            True when strictly more than ``threshold_seconds`` have passed
            since ``last_seen``, False otherwise.
        """
        silence = datetime.now(UTC) - self.last_seen
        seconds_silent = silence.total_seconds()
        return seconds_silent > threshold_seconds
class ConnectionManager:
    """Manages WebSocket connections with Redis-backed session tracking.

    Provides methods to:
    - Register/unregister connections
    - Track which game a connection is participating in
    - Check user online status
    - Get all connections for a game
    - Handle heartbeats and stale connection detection

    Redis keys written by this class are given a TTL of ``conn_ttl_seconds``.
    ``update_heartbeat`` refreshes the TTLs of every record tied to a
    connection (its hash, its user mapping and its game's connection set), so
    those records only lapse once heartbeats stop arriving.

    Attributes:
        conn_ttl_seconds: TTL for connection records in Redis.
    """

    def __init__(self, conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS) -> None:
        """Initialize the ConnectionManager.

        Args:
            conn_ttl_seconds: How long to keep connection records in Redis.
        """
        self.conn_ttl_seconds = conn_ttl_seconds

    def _conn_key(self, sid: str) -> str:
        """Generate Redis key for a connection."""
        return f"{CONN_PREFIX}{sid}"

    def _user_conn_key(self, user_id: str) -> str:
        """Generate Redis key for user-to-connection mapping."""
        return f"{USER_CONN_PREFIX}{user_id}"

    def _game_conns_key(self, game_id: str) -> str:
        """Generate Redis key for game connection set."""
        return f"{GAME_CONNS_PREFIX}{game_id}"

    # =========================================================================
    # Connection Lifecycle
    # =========================================================================

    async def register_connection(
        self,
        sid: str,
        user_id: str | UUID,
    ) -> None:
        """Register a new WebSocket connection.

        Creates connection record in Redis and maps user to this connection.
        If user had a previous connection (a different sid), that connection's
        records are cleaned up and the user mapping is replaced.

        Args:
            sid: Socket.IO session ID.
            user_id: Authenticated user's ID (UUID or string).

        Example:
            await manager.register_connection("abc123", user.id)
        """
        user_id_str = str(user_id)
        now = datetime.now(UTC).isoformat()
        async with get_redis() as redis:
            user_conn_key = self._user_conn_key(user_id_str)

            # Check for existing connection and clean it up
            old_sid = await redis.get(user_conn_key)
            if old_sid and old_sid != sid:
                logger.info(f"Replacing old connection {old_sid} for user {user_id_str}")
                await self._cleanup_connection(old_sid)

            # Create connection record (hash); game_id stays "" until join_game.
            conn_key = self._conn_key(sid)
            await redis.hset(
                conn_key,
                mapping={
                    "user_id": user_id_str,
                    "game_id": "",
                    "connected_at": now,
                    "last_seen": now,
                },
            )
            await redis.expire(conn_key, self.conn_ttl_seconds)

            # Map user to this connection. SET with ex= writes the value and
            # its TTL in one command, so the mapping never exists without one.
            await redis.set(user_conn_key, sid, ex=self.conn_ttl_seconds)

        logger.debug(f"Registered connection: sid={sid}, user_id={user_id_str}")

    async def unregister_connection(self, sid: str) -> ConnectionInfo | None:
        """Unregister a WebSocket connection and clean up all related data.

        Removes connection from:
        - Connection record
        - User-to-connection mapping (only if it still points to this sid)
        - Its game connection set, if any

        Args:
            sid: Socket.IO session ID.

        Returns:
            ConnectionInfo of the removed connection, or None if not found.

        Example:
            info = await manager.unregister_connection("abc123")
            if info and info.game_id:
                ...  # Notify opponent of disconnect
        """
        # Read the record first so the caller learns which user/game it was.
        info = await self.get_connection(sid)
        if info is None:
            logger.debug(f"Connection not found for unregister: {sid}")
            return None

        await self._cleanup_connection(sid)
        logger.debug(f"Unregistered connection: sid={sid}, user_id={info.user_id}")
        return info

    async def _cleanup_connection(self, sid: str) -> None:
        """Internal cleanup of a connection's Redis data.

        Does nothing if the connection record no longer exists.

        Args:
            sid: Socket.IO session ID.
        """
        async with get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Get connection data for cleanup
            data = await redis.hgetall(conn_key)
            if not data:
                return

            user_id = data.get("user_id", "")
            game_id = data.get("game_id", "")

            # Remove from game connection set if in a game
            if game_id:
                await redis.srem(self._game_conns_key(game_id), sid)

            # Remove the user mapping only if it still points to this sid; a
            # newer connection for the same user may already have replaced it.
            if user_id:
                user_conn_key = self._user_conn_key(user_id)
                current_sid = await redis.get(user_conn_key)
                if current_sid == sid:
                    await redis.delete(user_conn_key)

            # Delete connection record
            await redis.delete(conn_key)

    # =========================================================================
    # Game Association
    # =========================================================================

    async def join_game(self, sid: str, game_id: str) -> bool:
        """Associate a connection with a game.

        Updates the connection's game_id and adds it to the game's connection
        set. If the connection was in a different game, it leaves that game
        first. The connection record's TTL is refreshed as a side effect.

        Args:
            sid: Socket.IO session ID.
            game_id: Game to join.

        Returns:
            True if successful, False if connection not found.

        Example:
            await manager.join_game("abc123", "game-456")
        """
        async with get_redis() as redis:
            conn_key = self._conn_key(sid)

            # EXPIRE returns False for a missing key, so it serves as the
            # existence check. Once it succeeds the key has a fresh TTL and
            # cannot lapse before the HSET below. A separate EXISTS followed by
            # HSET could recreate an already-expired key as a partial hash
            # (no user_id/connected_at, no TTL) that get_connection cannot parse.
            if not await redis.expire(conn_key, self.conn_ttl_seconds):
                logger.warning(f"Cannot join game: connection not found {sid}")
                return False

            # Leave current game if in one
            current_game = await redis.hget(conn_key, "game_id")
            if current_game and current_game != game_id:
                await redis.srem(self._game_conns_key(current_game), sid)
                logger.debug(f"Connection {sid} left game {current_game}")

            # Update connection's game_id
            await redis.hset(conn_key, "game_id", game_id)

            # Add to game's connection set
            game_conns_key = self._game_conns_key(game_id)
            await redis.sadd(game_conns_key, sid)

            # TTL on the set cleans it up if the game becomes inactive;
            # update_heartbeat keeps it alive while participants are active.
            await redis.expire(game_conns_key, self.conn_ttl_seconds)

        logger.debug(f"Connection {sid} joined game {game_id}")
        return True

    async def leave_game(self, sid: str) -> str | None:
        """Remove a connection from its current game.

        Args:
            sid: Socket.IO session ID.

        Returns:
            The game_id that was left, or None if not in a game (or the
            connection does not exist).

        Example:
            game_id = await manager.leave_game("abc123")
        """
        async with get_redis() as redis:
            conn_key = self._conn_key(sid)

            # Get current game ("" or missing means not in a game)
            game_id = await redis.hget(conn_key, "game_id")
            if not game_id:
                return None

            # Remove from game's connection set
            await redis.srem(self._game_conns_key(game_id), sid)

            # Clear game_id on connection
            await redis.hset(conn_key, "game_id", "")

        logger.debug(f"Connection {sid} left game {game_id}")
        return game_id

    # =========================================================================
    # Heartbeat / Activity
    # =========================================================================

    async def update_heartbeat(self, sid: str) -> bool:
        """Update the last_seen timestamp for a connection.

        Should be called on each heartbeat message from the client. Also
        refreshes the TTL on the connection record, the user-to-connection
        mapping, and the connection set of the game the connection is in.

        Args:
            sid: Socket.IO session ID.

        Returns:
            True if successful, False if connection not found.

        Example:
            await manager.update_heartbeat("abc123")
        """
        now = datetime.now(UTC).isoformat()
        async with get_redis() as redis:
            conn_key = self._conn_key(sid)

            # EXPIRE doubles as the existence check and refreshes the TTL in
            # one command. Checking EXISTS first and then writing could
            # resurrect an expired key as a partial, TTL-less hash.
            if not await redis.expire(conn_key, self.conn_ttl_seconds):
                return False

            # Update last_seen
            await redis.hset(conn_key, "last_seen", now)

            data = await redis.hgetall(conn_key)
            user_id = data.get("user_id", "")
            game_id = data.get("game_id", "")

            # Also refresh user mapping TTL
            if user_id:
                await redis.expire(self._user_conn_key(user_id), self.conn_ttl_seconds)

            # Keep the game's connection set alive while its members are
            # active. Without this the set expires conn_ttl_seconds after the
            # last join_game even though every participant is heartbeating.
            if game_id:
                await redis.expire(self._game_conns_key(game_id), self.conn_ttl_seconds)

        return True

    # =========================================================================
    # Query Methods
    # =========================================================================

    async def get_connection(self, sid: str) -> ConnectionInfo | None:
        """Get connection info by session ID.

        Args:
            sid: Socket.IO session ID.

        Returns:
            ConnectionInfo if found, None otherwise.

        Raises:
            KeyError: If the stored hash lacks connected_at or last_seen.

        Example:
            info = await manager.get_connection("abc123")
            if info:
                print(f"User {info.user_id} connected at {info.connected_at}")
        """
        async with get_redis() as redis:
            data = await redis.hgetall(self._conn_key(sid))

        if not data:
            return None

        return ConnectionInfo(
            sid=sid,
            user_id=data.get("user_id", ""),
            # Stored as "" when not in a game; expose that as None.
            game_id=data.get("game_id") or None,
            connected_at=datetime.fromisoformat(data["connected_at"]),
            last_seen=datetime.fromisoformat(data["last_seen"]),
        )

    async def get_user_connection(self, user_id: str | UUID) -> ConnectionInfo | None:
        """Get connection info for a user.

        Args:
            user_id: User's UUID or string ID.

        Returns:
            ConnectionInfo if user is connected, None otherwise.

        Example:
            info = await manager.get_user_connection(opponent_id)
            if info:
                ...  # User is online
        """
        async with get_redis() as redis:
            sid = await redis.get(self._user_conn_key(str(user_id)))

        if not sid:
            return None

        # Looked up after releasing the Redis context above; get_connection
        # acquires its own.
        return await self.get_connection(sid)

    async def is_user_online(self, user_id: str | UUID) -> bool:
        """Check if a user has an active, non-stale connection.

        Args:
            user_id: User's UUID or string ID.

        Returns:
            True if user is connected and not stale, False otherwise.

        Example:
            if await manager.is_user_online(opponent_id):
                ...  # Send real-time notification
        """
        info = await self.get_user_connection(user_id)
        if info is None:
            return False

        # Check if connection is stale
        if info.is_stale():
            logger.debug(f"User {user_id} connection is stale")
            return False

        return True

    async def get_game_connections(self, game_id: str) -> list[ConnectionInfo]:
        """Get all connections for a game.

        Sids in the game's set whose connection record no longer exists are
        skipped.

        Args:
            game_id: Game ID.

        Returns:
            List of ConnectionInfo for all connected participants.

        Example:
            connections = await manager.get_game_connections("game-456")
            for conn in connections:
                ...  # Send state update to each participant
        """
        async with get_redis() as redis:
            sids = await redis.smembers(self._game_conns_key(game_id))

        connections = []
        for sid in sids:
            info = await self.get_connection(sid)
            if info is not None:
                connections.append(info)
        return connections

    async def get_game_user_sids(self, game_id: str) -> dict[str, str]:
        """Get mapping of user_id to sid for a game.

        Useful for targeted message delivery to specific players.

        Args:
            game_id: Game ID.

        Returns:
            Dict mapping user_id to their sid.

        Example:
            user_sids = await manager.get_game_user_sids("game-456")
            player1_sid = user_sids.get(player1_id)
        """
        connections = await self.get_game_connections(game_id)
        return {conn.user_id: conn.sid for conn in connections}

    async def get_opponent_sid(
        self,
        game_id: str,
        current_user_id: str | UUID,
    ) -> str | None:
        """Get the sid of the opponent in a 2-player game.

        Returns the first connection in the game whose user is not
        ``current_user_id``.

        Args:
            game_id: Game ID.
            current_user_id: The current user's ID.

        Returns:
            Opponent's sid if found and connected, None otherwise.

        Example:
            opponent_sid = await manager.get_opponent_sid("game-456", user_id)
            if opponent_sid:
                ...  # Send message to opponent
        """
        current_user_str = str(current_user_id)
        connections = await self.get_game_connections(game_id)
        for conn in connections:
            if conn.user_id != current_user_str:
                return conn.sid
        return None

    # =========================================================================
    # Maintenance
    # =========================================================================

    async def cleanup_stale_connections(
        self,
        threshold_seconds: int = HEARTBEAT_INTERVAL_SECONDS * 3,
    ) -> int:
        """Clean up stale connections that haven't sent heartbeats.

        Should be called periodically by a background task.

        Args:
            threshold_seconds: Seconds since last_seen to consider stale.

        Returns:
            Number of connections cleaned up.

        Example:
            # In background task
            count = await manager.cleanup_stale_connections()
            logger.info(f"Cleaned up {count} stale connections")
        """
        count = 0
        async with get_redis() as redis:
            # Scan for all connection keys (SCAN avoids blocking like KEYS).
            async for key in redis.scan_iter(match=f"{CONN_PREFIX}*"):
                sid = key[len(CONN_PREFIX) :]
                info = await self.get_connection(sid)
                if info and info.is_stale(threshold_seconds):
                    await self._cleanup_connection(sid)
                    count += 1
                    logger.debug(f"Cleaned up stale connection: {sid}")

        if count > 0:
            logger.info(f"Cleaned up {count} stale connections")

        return count

    async def get_connection_count(self) -> int:
        """Get the total number of active connections.

        Useful for monitoring and admin dashboards.

        Returns:
            Number of connection records currently in Redis.
        """
        count = 0
        async with get_redis() as redis:
            async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"):
                count += 1
        return count

    async def get_game_connection_count(self, game_id: str) -> int:
        """Get the number of connections for a specific game.

        Args:
            game_id: Game ID.

        Returns:
            Number of sids in the game's connection set.
        """
        async with get_redis() as redis:
            return await redis.scard(self._game_conns_key(game_id))
# Global singleton instance
connection_manager = ConnectionManager()