diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index e2fb625..2ce79f2 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -78,11 +78,11 @@ def validate_deck(cards, config: DeckConfig | None = None): ## Dependency Injection for Services -**Services use constructor-based dependency injection for repositories and other services.** +**Services use constructor-based dependency injection for repositories, services, and external resources.** -Services still use DI for data access, but config comes from method parameters (request). +Services use DI for data access and external dependencies, but config comes from method parameters (request). -### Required Pattern +### Required Pattern: Services and Repositories ```python class DeckService: @@ -109,12 +109,78 @@ class DeckService: ... ``` +### Required Pattern: Factory Injection for External Resources + +For external resources (Redis, databases, engines), inject factory functions rather than instances: + +```python +from collections.abc import AsyncIterator, Callable + +# Type aliases for factories +RedisFactory = Callable[[], AsyncIterator["Redis"]] +EngineFactory = Callable[[GameState], GameEngine] +TokenVerifier = Callable[[str], UUID | None] + +class ConnectionManager: + """External resources injected as factories.""" + + def __init__( + self, + conn_ttl_seconds: int = 3600, + redis_factory: RedisFactory | None = None, # Factory, not instance + ) -> None: + self.conn_ttl_seconds = conn_ttl_seconds + # Fall back to global only if not provided + self._get_redis = redis_factory if redis_factory is not None else get_redis + + async def register_connection(self, sid: str, user_id: UUID) -> None: + async with self._get_redis() as redis: # Use injected factory + await redis.hset(...) +``` + +```python +class GameService: + """Engine created via injected factory for testability.""" + + def __init__( + self, + state_manager: GameStateManager | None = None, + engine_factory: EngineFactory | None = None, # Factory for engine creation + ) -> None: + self._state_manager = state_manager or game_state_manager + self._engine_factory = engine_factory or self._default_engine_factory + + async def execute_action(self, game_id: str, player_id: str, action: Action): + state = await self.get_game_state(game_id) + engine = self._engine_factory(state) # Create via factory + result = await engine.execute_action(state, player_id, action) +``` + +### Global Singletons with DI Support + +Create global instances for production, but always support injection for testing: + +```python +# Global singleton for production use +connection_manager = ConnectionManager() + +# Tests inject mocks via constructor - no patching needed +def test_register_connection(mock_redis): + @asynccontextmanager + async def mock_redis_factory(): + yield mock_redis + + manager = ConnectionManager(redis_factory=mock_redis_factory) + # Test with injected mock... +``` + ### Why This Matters 1. **Testability**: Dependencies can be mocked without patching globals 2. **Offline Fork**: Services can be swapped for local implementations 3. **Explicit Dependencies**: Constructor shows all requirements 4. **Stateless Operations**: Config comes from request, not server state +5. **No Monkey Patching**: Tests use constructor injection, not `patch()` --- @@ -253,6 +319,55 @@ def test_paralyzed_pokemon_cannot_attack(): | `tests/core/` | Core engine tests | No | | `tests/api/` | API endpoint tests | Yes | +### No Monkey Patching - Use Dependency Injection + +**Never use `patch()` or `monkeypatch` for unit tests.** Instead, inject mocks via constructor. + +```python +# WRONG - Monkey patching +async def test_execute_action(): + with patch.object(game_service, "_create_engine") as mock: + mock.return_value = mock_engine + result = await game_service.execute_action(...) + +# CORRECT - Dependency injection +@pytest.fixture +def mock_engine(): + engine = MagicMock() + engine.execute_action = AsyncMock(return_value=ActionResult(success=True)) + return engine + +@pytest.fixture +def game_service(mock_state_manager, mock_engine): + return GameService( + state_manager=mock_state_manager, + engine_factory=lambda state: mock_engine, # Inject via constructor + ) + +async def test_execute_action(game_service, mock_engine): + result = await game_service.execute_action(...) + mock_engine.execute_action.assert_called_once() +``` + +### Factory Fixtures for External Resources + +```python +@pytest.fixture +def mock_redis(): + redis = AsyncMock() + redis.hset = AsyncMock() + redis.hget = AsyncMock(return_value=None) + return redis + +@pytest.fixture +def connection_manager(mock_redis): + @asynccontextmanager + async def mock_redis_factory(): + yield mock_redis + + return ConnectionManager(redis_factory=mock_redis_factory) +``` + ### Use Seeded RNG for Determinism ```python @@ -276,6 +391,12 @@ app/ │ ├── config.py # RulesConfig and sub-configs │ └── engine.py # GameEngine orchestrator ├── services/ # Business logic (uses repositories) +│ ├── game_service.py # Game orchestration (uses engine_factory DI) +│ ├── connection_manager.py # WebSocket tracking (uses redis_factory DI) +│ └── ... +├── socketio/ # WebSocket real-time communication +│ ├── server.py # Socket.IO server setup and handlers +│ └── auth.py # AuthHandler class (uses token_verifier DI) ├── repositories/ # Data access layer │ ├── protocols.py # Repository protocols (interfaces) │ └── postgres/ # PostgreSQL implementations @@ -322,6 +443,9 @@ app/ | Hardcoded magic numbers | Use values from config parameter | | Tests without docstrings | Always explain what and why | | Unit tests in `tests/services/` | Use `tests/unit/` for no-DB tests | +| Using `patch()` in unit tests | Inject mocks via constructor DI | +| `async with get_redis()` in method body | Inject `redis_factory` via constructor | +| Importing dependencies inside functions | Import at module level, inject via constructor | --- @@ -330,3 +454,5 @@ app/ - `app/core/AGENTS.md` - Detailed core engine guidelines - `app/core/README.md` - Core module documentation - `app/core/effects/README.md` - Effect system documentation +- `app/services/README.md` - Services layer patterns +- `app/socketio/README.md` - WebSocket module documentation diff --git a/backend/app/services/README.md b/backend/app/services/README.md index 9bd0fa2..6512a56 100644 --- a/backend/app/services/README.md +++ b/backend/app/services/README.md @@ -19,6 +19,8 @@ Business logic layer between API endpoints and data access. | `DeckValidator` | Pure deck validation functions | CardService (for lookups) | | `UserService` | User profile management | Direct DB access | | `GameStateManager` | Game state (Redis + Postgres) | Redis, AsyncSession | +| `GameService` | Game orchestration + actions | GameStateManager, CardService, engine_factory | +| `ConnectionManager` | WebSocket connection tracking | redis_factory | | `JWTService` | Token creation/verification | Settings | | `TokenStore` | Refresh token storage | Redis | @@ -93,6 +95,50 @@ Benefits: - **Offline Fork**: Swap PostgresRepository for LocalRepository - **Decoupling**: Services don't know about SQLAlchemy +### Factory Injection for External Resources + +For external resources (Redis, GameEngine), inject factory functions rather than instances: + +```python +from collections.abc import AsyncIterator, Callable + +# Type aliases for factories +RedisFactory = Callable[[], AsyncIterator["Redis"]] +EngineFactory = Callable[[GameState], GameEngine] + +class ConnectionManager: + """External resources injected as factories.""" + + def __init__( + self, + conn_ttl_seconds: int = 3600, + redis_factory: RedisFactory | None = None, + ) -> None: + self.conn_ttl_seconds = conn_ttl_seconds + self._get_redis = redis_factory if redis_factory is not None else get_redis + +class GameService: + """Engine created via injected factory for testability.""" + + def __init__( + self, + state_manager: GameStateManager | None = None, + engine_factory: EngineFactory | None = None, + ) -> None: + self._state_manager = state_manager or game_state_manager + self._engine_factory = engine_factory or self._default_engine_factory +``` + +Global singletons are created for production, but constructors always support injection: + +```python +# Production singleton +connection_manager = ConnectionManager() + +# Tests inject mocks - no patching needed +manager = ConnectionManager(redis_factory=mock_redis_factory) +``` + ### Pure Validation Functions Validation logic is extracted into pure functions for reuse: @@ -183,6 +229,27 @@ await manager.persist_to_db(game_id, game_state) | `tests/unit/services/` | Pure unit tests, mocked deps | No | | `tests/services/` | Integration tests | Yes (testcontainers) | +**No monkey patching** - Always use constructor DI for mocks: + +```python +# WRONG - monkey patching +async def test_something(): + with patch.object(service, "_method") as mock: + ... + +# CORRECT - constructor injection +@pytest.fixture +def mock_redis(): + return AsyncMock() + +@pytest.fixture +def connection_manager(mock_redis): + @asynccontextmanager + async def mock_redis_factory(): + yield mock_redis + return ConnectionManager(redis_factory=mock_redis_factory) +``` + ### Unit Test Example ```python diff --git a/backend/app/services/connection_manager.py b/backend/app/services/connection_manager.py index 185d3a8..e8823f2 100644 --- a/backend/app/services/connection_manager.py +++ b/backend/app/services/connection_manager.py @@ -35,14 +35,22 @@ Example: """ import logging +from collections.abc import AsyncIterator, Callable from dataclasses import dataclass from datetime import UTC, datetime +from typing import TYPE_CHECKING from uuid import UUID from app.db.redis import get_redis +if TYPE_CHECKING: + from redis.asyncio import Redis + logger = logging.getLogger(__name__) +# Type alias for redis factory - a callable that returns an async context manager +RedisFactory = Callable[[], AsyncIterator["Redis"]] + # Redis key patterns CONN_PREFIX = "conn:" USER_CONN_PREFIX = "user_conn:" @@ -98,13 +106,20 @@ class ConnectionManager: conn_ttl_seconds: TTL for connection records in Redis. """ - def __init__(self, conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS) -> None: + def __init__( + self, + conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS, + redis_factory: RedisFactory | None = None, + ) -> None: """Initialize the ConnectionManager. Args: conn_ttl_seconds: How long to keep connection records in Redis. + redis_factory: Optional factory for Redis connections. If not provided, + uses the default get_redis from app.db.redis. Useful for testing. """ self.conn_ttl_seconds = conn_ttl_seconds + self._get_redis = redis_factory if redis_factory is not None else get_redis def _conn_key(self, sid: str) -> str: """Generate Redis key for a connection.""" @@ -142,7 +157,7 @@ class ConnectionManager: user_id_str = str(user_id) now = datetime.now(UTC).isoformat() - async with get_redis() as redis: + async with self._get_redis() as redis: # Check for existing connection and clean it up old_sid = await redis.get(self._user_conn_key(user_id_str)) if old_sid and old_sid != sid: @@ -204,7 +219,7 @@ class ConnectionManager: Args: sid: Socket.IO session ID. """ - async with get_redis() as redis: + async with self._get_redis() as redis: conn_key = self._conn_key(sid) # Get connection data for cleanup @@ -250,7 +265,7 @@ class ConnectionManager: Example: await manager.join_game("abc123", "game-456") """ - async with get_redis() as redis: + async with self._get_redis() as redis: conn_key = self._conn_key(sid) # Check connection exists @@ -291,7 +306,7 @@ class ConnectionManager: Example: game_id = await manager.leave_game("abc123") """ - async with get_redis() as redis: + async with self._get_redis() as redis: conn_key = self._conn_key(sid) # Get current game @@ -330,7 +345,7 @@ class ConnectionManager: """ now = datetime.now(UTC).isoformat() - async with get_redis() as redis: + async with self._get_redis() as redis: conn_key = self._conn_key(sid) # Check exists and update atomically @@ -370,7 +385,7 @@ class ConnectionManager: if info: print(f"User {info.user_id} connected at {info.connected_at}") """ - async with get_redis() as redis: + async with self._get_redis() as redis: conn_key = self._conn_key(sid) data = await redis.hgetall(conn_key) @@ -401,7 +416,7 @@ class ConnectionManager: """ user_id_str = str(user_id) - async with get_redis() as redis: + async with self._get_redis() as redis: user_conn_key = self._user_conn_key(user_id_str) sid = await redis.get(user_conn_key) @@ -448,7 +463,7 @@ class ConnectionManager: for conn in connections: # Send state update to each participant """ - async with get_redis() as redis: + async with self._get_redis() as redis: game_conns_key = self._game_conns_key(game_id) sids = await redis.smembers(game_conns_key) @@ -530,7 +545,7 @@ class ConnectionManager: logger.info(f"Cleaned up {count} stale connections") """ count = 0 - async with get_redis() as redis: + async with self._get_redis() as redis: # Scan for all connection keys async for key in redis.scan_iter(match=f"{CONN_PREFIX}*"): sid = key[len(CONN_PREFIX) :] @@ -555,7 +570,7 @@ class ConnectionManager: Number of active connections. """ count = 0 - async with get_redis() as redis: + async with self._get_redis() as redis: async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"): count += 1 return count @@ -569,7 +584,7 @@ class ConnectionManager: Returns: Number of connected participants. """ - async with get_redis() as redis: + async with self._get_redis() as redis: game_conns_key = self._game_conns_key(game_id) return await redis.scard(game_conns_key) diff --git a/backend/app/services/game_service.py b/backend/app/services/game_service.py index 5578ce4..105b8f1 100644 --- a/backend/app/services/game_service.py +++ b/backend/app/services/game_service.py @@ -36,6 +36,7 @@ Example: """ import logging +from collections.abc import Callable from dataclasses import dataclass, field from typing import Any from uuid import UUID @@ -51,6 +52,9 @@ from app.services.game_state_manager import GameStateManager, game_state_manager logger = logging.getLogger(__name__) +# Type alias for engine factory - takes GameState, returns GameEngine +EngineFactory = Callable[[GameState], GameEngine] + # ============================================================================= # Exceptions @@ -180,12 +184,14 @@ class GameService: Attributes: _state_manager: GameStateManager for persistence. _card_service: CardService for card definitions. + _engine_factory: Factory function to create GameEngine instances. """ def __init__( self, state_manager: GameStateManager | None = None, card_service: CardService | None = None, + engine_factory: EngineFactory | None = None, ) -> None: """Initialize the GameService. @@ -196,12 +202,16 @@ class GameService: Args: state_manager: GameStateManager instance. Uses global if not provided. card_service: CardService instance. Uses global if not provided. + engine_factory: Optional factory for creating GameEngine instances. + If not provided, uses the default _default_engine_factory method. + Useful for testing with mock engines. """ self._state_manager = state_manager or game_state_manager self._card_service = card_service or get_card_service() + self._engine_factory = engine_factory or self._default_engine_factory - def _create_engine_for_game(self, game: GameState) -> GameEngine: - """Create a GameEngine configured for a specific game's rules. + def _default_engine_factory(self, game: GameState) -> GameEngine: + """Default factory for creating a GameEngine from game state. The engine is created on-demand using the rules stored in the game state. This ensures each game uses its own configuration. @@ -412,8 +422,8 @@ class GameService: if not isinstance(action, ResignAction) and state.current_player_id != player_id: raise NotPlayerTurnError(game_id, player_id, state.current_player_id) - # Create engine with this game's rules - engine = self._create_engine_for_game(state) + # Create engine with this game's rules via factory + engine = self._engine_factory(state) # Execute the action result: ActionResult = await engine.execute_action(state, player_id, action) diff --git a/backend/app/socketio/README.md b/backend/app/socketio/README.md new file mode 100644 index 0000000..1477885 --- /dev/null +++ b/backend/app/socketio/README.md @@ -0,0 +1,120 @@ +# Socket.IO Module + +Real-time WebSocket communication for active game sessions. + +## Architecture + +``` +Client <--WebSocket--> Socket.IO Server <--> GameService <--> GameEngine + | + v + ConnectionManager <--> Redis (presence tracking) +``` + +## Components + +| Component | Responsibility | DI Pattern | +|-----------|---------------|------------| +| `server.py` | Socket.IO server setup, event handlers | Uses `auth_handler` | +| `auth.py` | `AuthHandler` class for JWT authentication | `token_verifier`, `conn_manager` | + +## AuthHandler + +Handles JWT-based authentication with injectable dependencies: + +```python +class AuthHandler: + def __init__( + self, + token_verifier: TokenVerifier | None = None, # Callable[[str], UUID | None] + conn_manager: ConnectionManager | None = None, + ) -> None: + self._token_verifier = token_verifier or verify_access_token + self._connection_manager = conn_manager or connection_manager + +# Global singleton for production +auth_handler = AuthHandler() + +# Tests inject mocks via constructor +handler = AuthHandler( + token_verifier=lambda token: user_id, + conn_manager=mock_connection_manager, +) +``` + +### Methods + +| Method | Purpose | +|--------|---------| +| `authenticate_connection(sid, auth)` | Validate JWT, return `AuthResult` | +| `setup_authenticated_session(sio, sid, user_id)` | Save session, register connection | +| `cleanup_authenticated_session(sid)` | Unregister connection, return user_id | + +## Event Handlers + +All handlers are in the `/game` namespace: + +| Event | Direction | Purpose | +|-------|-----------|---------| +| `connect` | Client→Server | Authenticate and establish session | +| `disconnect` | Client→Server | Clean up session and notify opponent | +| `game:join` | Client→Server | Join/rejoin a game session | +| `game:action` | Client→Server | Execute a game action | +| `game:resign` | Client→Server | Resign from game | +| `game:heartbeat` | Client→Server | Keep connection alive | + +## Client Connection + +```javascript +const socket = io("wss://api.example.com", { + auth: { token: "JWT_ACCESS_TOKEN" } +}); + +socket.on("connect", () => { + socket.emit("game:join", { game_id: "uuid" }); +}); + +socket.on("game:state", (state) => { + // Full game state update +}); + +socket.on("auth_error", (error) => { + // { code: "invalid_token", message: "..." } +}); +``` + +## Testing + +Use dependency injection - no monkey patching: + +```python +@pytest.fixture +def mock_token_verifier(): + return MagicMock(return_value=uuid4()) + +@pytest.fixture +def mock_connection_manager(): + cm = AsyncMock() + cm.register_connection = AsyncMock() + return cm + +@pytest.fixture +def auth_handler(mock_token_verifier, mock_connection_manager): + return AuthHandler( + token_verifier=mock_token_verifier, + conn_manager=mock_connection_manager, + ) + +async def test_authenticate_success(auth_handler, mock_token_verifier): + """Test successful authentication with valid JWT.""" + result = await auth_handler.authenticate_connection("sid", {"token": "valid"}) + assert result.success + mock_token_verifier.assert_called_once_with("valid") +``` + +## See Also + +- `app/services/connection_manager.py` - Connection presence tracking +- `app/services/game_service.py` - Game orchestration +- `app/schemas/ws_messages.py` - WebSocket message schemas +- `CLAUDE.md` - Architecture guidelines diff --git a/backend/app/socketio/__init__.py b/backend/app/socketio/__init__.py index 1e5469f..aae08b2 100644 --- a/backend/app/socketio/__init__.py +++ b/backend/app/socketio/__init__.py @@ -23,12 +23,12 @@ Usage: """ from app.socketio.auth import ( + AuthHandler, AuthResult, - authenticate_connection, - cleanup_authenticated_session, + auth_handler, + extract_token, get_session_user_id, require_auth, - setup_authenticated_session, ) from app.socketio.server import create_socketio_app, sio @@ -37,10 +37,10 @@ __all__ = [ "create_socketio_app", "sio", # Auth + "AuthHandler", "AuthResult", - "authenticate_connection", - "cleanup_authenticated_session", + "auth_handler", + "extract_token", "get_session_user_id", "require_auth", - "setup_authenticated_session", ] diff --git a/backend/app/socketio/auth.py b/backend/app/socketio/auth.py index 125afec..9dfc797 100644 --- a/backend/app/socketio/auth.py +++ b/backend/app/socketio/auth.py @@ -16,7 +16,7 @@ Session Data: Example: # In connect handler: - auth_result = await authenticate_connection(sid, auth) + auth_result = await auth_handler.authenticate_connection(sid, auth) if not auth_result.success: return False # Reject connection @@ -26,15 +26,19 @@ Example: """ import logging +from collections.abc import Callable from dataclasses import dataclass from datetime import UTC, datetime from uuid import UUID -from app.services.connection_manager import connection_manager +from app.services.connection_manager import ConnectionManager, connection_manager from app.services.jwt_service import verify_access_token logger = logging.getLogger(__name__) +# Type alias for token verifier - takes token string, returns user_id or None +TokenVerifier = Callable[[str], UUID | None] + @dataclass class AuthResult: @@ -53,6 +57,133 @@ class AuthResult: error_message: str | None = None +class AuthHandler: + """Handler for Socket.IO authentication with injectable dependencies. + + This class encapsulates authentication logic with dependencies injected + via constructor, making it testable without monkey patching. + + Attributes: + _token_verifier: Function to verify JWT tokens. + _connection_manager: ConnectionManager for tracking connections. + """ + + def __init__( + self, + token_verifier: TokenVerifier | None = None, + conn_manager: ConnectionManager | None = None, + ) -> None: + """Initialize AuthHandler with dependencies. + + Args: + token_verifier: Function to verify tokens. Uses jwt_service if not provided. + conn_manager: ConnectionManager instance. Uses global if not provided. + """ + self._token_verifier = token_verifier or verify_access_token + self._connection_manager = conn_manager or connection_manager + + async def authenticate_connection( + self, + sid: str, + auth: dict[str, object] | None, + ) -> AuthResult: + """Authenticate a Socket.IO connection using JWT. + + Extracts the JWT from auth data, validates it, and returns the result. + Does NOT modify socket session - caller should handle that. + + Args: + sid: Socket session ID (for logging). + auth: Authentication data from connect event. + + Returns: + AuthResult with success status and user_id or error details. + """ + token = extract_token(auth) + + if token is None: + logger.debug(f"Connection {sid}: No token provided") + return AuthResult( + success=False, + error_code="missing_token", + error_message="Authentication token required", + ) + + user_id = self._token_verifier(token) + + if user_id is None: + logger.debug(f"Connection {sid}: Invalid or expired token") + return AuthResult( + success=False, + error_code="invalid_token", + error_message="Invalid or expired token", + ) + + logger.debug(f"Connection {sid}: Authenticated as user {user_id}") + return AuthResult( + success=True, + user_id=user_id, + ) + + async def setup_authenticated_session( + self, + sio: object, + sid: str, + user_id: UUID, + namespace: str = "/game", + ) -> None: + """Set up socket session with authenticated user data. + + Saves user_id and authentication timestamp to the socket session, + and registers the connection with ConnectionManager. + + Args: + sio: Socket.IO AsyncServer instance (required). + sid: Socket session ID. + user_id: Authenticated user's UUID. + namespace: Socket.IO namespace. + """ + session_data = { + "user_id": str(user_id), + "authenticated_at": datetime.now(UTC).isoformat(), + } + await sio.save_session(sid, session_data, namespace=namespace) + + await self._connection_manager.register_connection(sid, user_id) + + logger.info(f"Session established: sid={sid}, user_id={user_id}") + + async def cleanup_authenticated_session( + self, + sid: str, + namespace: str = "/game", + ) -> str | None: + """Clean up session data on disconnect. + + Unregisters the connection from ConnectionManager and returns + the user_id for any additional cleanup needed. + + Args: + sid: Socket session ID. + namespace: Socket.IO namespace. + + Returns: + user_id if session was authenticated, None otherwise. + """ + conn_info = await self._connection_manager.unregister_connection(sid) + + if conn_info: + logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}") + return conn_info.user_id + + logger.debug(f"No session to clean up for {sid}") + return None + + +# Global singleton instance +auth_handler = AuthHandler() + + def extract_token(auth: dict[str, object] | None) -> str | None: """Extract JWT token from Socket.IO auth data. @@ -96,131 +227,6 @@ def extract_token(auth: dict[str, object] | None) -> str | None: return None -async def authenticate_connection( - sid: str, - auth: dict[str, object] | None, -) -> AuthResult: - """Authenticate a Socket.IO connection using JWT. - - Extracts the JWT from auth data, validates it, and returns the result. - Does NOT modify socket session - caller should handle that. - - Args: - sid: Socket session ID (for logging). - auth: Authentication data from connect event. - - Returns: - AuthResult with success status and user_id or error details. - - Example: - result = await authenticate_connection(sid, auth) - if result.success: - await sio.save_session(sid, {"user_id": str(result.user_id)}) - else: - logger.warning(f"Auth failed: {result.error_message}") - return False # Reject connection - """ - # Extract token from auth data - token = extract_token(auth) - - if token is None: - logger.debug(f"Connection {sid}: No token provided") - return AuthResult( - success=False, - error_code="missing_token", - error_message="Authentication token required", - ) - - # Validate the token - user_id = verify_access_token(token) - - if user_id is None: - logger.debug(f"Connection {sid}: Invalid or expired token") - return AuthResult( - success=False, - error_code="invalid_token", - error_message="Invalid or expired token", - ) - - logger.debug(f"Connection {sid}: Authenticated as user {user_id}") - return AuthResult( - success=True, - user_id=user_id, - ) - - -async def setup_authenticated_session( - sio: object, - sid: str, - user_id: UUID, - namespace: str = "/game", -) -> None: - """Set up socket session with authenticated user data. - - Saves user_id and authentication timestamp to the socket session, - and registers the connection with ConnectionManager. - - Args: - sio: Socket.IO AsyncServer instance. - sid: Socket session ID. - user_id: Authenticated user's UUID. - namespace: Socket.IO namespace. - - Example: - if auth_result.success: - await setup_authenticated_session(sio, sid, auth_result.user_id) - """ - # Import here to avoid circular dependency - from app.socketio.server import sio as server_sio - - # Use provided sio or fall back to server sio - socket_server = sio if sio is not None else server_sio - - # Save to socket session - session_data = { - "user_id": str(user_id), - "authenticated_at": datetime.now(UTC).isoformat(), - } - await socket_server.save_session(sid, session_data, namespace=namespace) - - # Register with ConnectionManager - await connection_manager.register_connection(sid, user_id) - - logger.info(f"Session established: sid={sid}, user_id={user_id}") - - -async def cleanup_authenticated_session( - sid: str, - namespace: str = "/game", -) -> str | None: - """Clean up session data on disconnect. - - Unregisters the connection from ConnectionManager and returns - the user_id for any additional cleanup needed. - - Args: - sid: Socket session ID. - namespace: Socket.IO namespace. - - Returns: - user_id if session was authenticated, None otherwise. - - Example: - user_id = await cleanup_authenticated_session(sid) - if user_id: - # Notify opponent, etc. - """ - # Unregister from ConnectionManager - conn_info = await connection_manager.unregister_connection(sid) - - if conn_info: - logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}") - return conn_info.user_id - - logger.debug(f"No session to clean up for {sid}") - return None - - async def get_session_user_id( sio: object, sid: str, diff --git a/backend/app/socketio/server.py b/backend/app/socketio/server.py index 3f4127c..b409e8e 100644 --- a/backend/app/socketio/server.py +++ b/backend/app/socketio/server.py @@ -33,12 +33,7 @@ from typing import TYPE_CHECKING import socketio from app.config import settings -from app.socketio.auth import ( - authenticate_connection, - cleanup_authenticated_session, - require_auth, - setup_authenticated_session, -) +from app.socketio.auth import auth_handler, require_auth if TYPE_CHECKING: from fastapi import FastAPI @@ -84,7 +79,7 @@ async def connect( None is treated as True (accept). """ # Authenticate the connection - auth_result = await authenticate_connection(sid, auth) + auth_result = await auth_handler.authenticate_connection(sid, auth) if not auth_result.success: logger.warning( @@ -103,7 +98,7 @@ async def connect( return False # Set up authenticated session and register connection - await setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game") + await auth_handler.setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game") logger.info(f"Client authenticated to /game: sid={sid}, user_id={auth_result.user_id}") return True @@ -119,7 +114,7 @@ async def disconnect(sid: str) -> None: sid: Socket session ID of disconnecting client. """ # Clean up session and get user info - user_id = await cleanup_authenticated_session(sid, namespace="/game") + user_id = await auth_handler.cleanup_authenticated_session(sid, namespace="/game") if user_id: logger.info(f"Client disconnected from /game: sid={sid}, user_id={user_id}") diff --git a/backend/tests/socketio/test_auth.py b/backend/tests/socketio/test_auth.py index 073c1c0..422a4fa 100644 --- a/backend/tests/socketio/test_auth.py +++ b/backend/tests/socketio/test_auth.py @@ -4,19 +4,17 @@ This module tests JWT-based authentication for WebSocket connections, including token extraction, validation, and session management. """ -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 import pytest from app.socketio.auth import ( + AuthHandler, AuthResult, - authenticate_connection, - cleanup_authenticated_session, extract_token, get_session_user_id, require_auth, - setup_authenticated_session, ) @@ -106,95 +104,143 @@ class TestExtractToken: class TestAuthenticateConnection: - """Tests for connection authentication.""" + """Tests for connection authentication via AuthHandler.""" - @pytest.mark.asyncio - async def test_authenticate_success_with_valid_token(self) -> None: + @pytest.fixture + def mock_token_verifier(self) -> MagicMock: + """Create a mock token verifier function.""" + return MagicMock() + + @pytest.fixture + def mock_connection_manager(self) -> AsyncMock: + """Create a mock ConnectionManager.""" + cm = AsyncMock() + cm.register_connection = AsyncMock() + cm.unregister_connection = AsyncMock() + return cm + + @pytest.fixture + def auth_handler( + self, + mock_token_verifier: MagicMock, + mock_connection_manager: AsyncMock, + ) -> AuthHandler: + """Create AuthHandler with injected mocks.""" + return AuthHandler( + token_verifier=mock_token_verifier, + conn_manager=mock_connection_manager, + ) + + async def test_authenticate_success_with_valid_token( + self, + auth_handler: AuthHandler, + mock_token_verifier: MagicMock, + ) -> None: """Test successful authentication with a valid JWT. A valid access token should result in AuthResult with success=True and the user_id from the token. """ user_id = uuid4() + mock_token_verifier.return_value = user_id - with patch("app.socketio.auth.verify_access_token") as mock_verify: - mock_verify.return_value = user_id + result = await auth_handler.authenticate_connection("test-sid", {"token": "valid-token"}) - result = await authenticate_connection("test-sid", {"token": "valid-token"}) + assert result.success is True + assert result.user_id == user_id + assert result.error_code is None + mock_token_verifier.assert_called_once_with("valid-token") - assert result.success is True - assert result.user_id == user_id - assert result.error_code is None - mock_verify.assert_called_once_with("valid-token") - - @pytest.mark.asyncio - async def test_authenticate_fails_with_missing_token(self) -> None: + async def test_authenticate_fails_with_missing_token( + self, + auth_handler: AuthHandler, + ) -> None: """Test authentication failure when no token is provided. Connections without any auth data should fail with a 'missing_token' error code. """ - result = await authenticate_connection("test-sid", None) + result = await auth_handler.authenticate_connection("test-sid", None) assert result.success is False assert result.user_id is None assert result.error_code == "missing_token" assert "required" in result.error_message.lower() - @pytest.mark.asyncio - async def test_authenticate_fails_with_empty_auth(self) -> None: + async def test_authenticate_fails_with_empty_auth( + self, + auth_handler: AuthHandler, + ) -> None: """Test authentication failure with empty auth object. An auth object without any token fields should fail. """ - result = await authenticate_connection("test-sid", {}) + result = await auth_handler.authenticate_connection("test-sid", {}) assert result.success is False assert result.error_code == "missing_token" - @pytest.mark.asyncio - async def test_authenticate_fails_with_invalid_token(self) -> None: + async def test_authenticate_fails_with_invalid_token( + self, + auth_handler: AuthHandler, + mock_token_verifier: MagicMock, + ) -> None: """Test authentication failure with invalid/expired JWT. When verify_access_token returns None (invalid token), we should fail with 'invalid_token' error. """ - with patch("app.socketio.auth.verify_access_token") as mock_verify: - mock_verify.return_value = None # Token validation failed + mock_token_verifier.return_value = None - result = await authenticate_connection("test-sid", {"token": "invalid-token"}) + result = await auth_handler.authenticate_connection("test-sid", {"token": "invalid-token"}) - assert result.success is False - assert result.user_id is None - assert result.error_code == "invalid_token" - assert ( - "invalid" in result.error_message.lower() - or "expired" in result.error_message.lower() - ) + assert result.success is False + assert result.user_id is None + assert result.error_code == "invalid_token" + assert ( + "invalid" in result.error_message.lower() or "expired" in result.error_message.lower() + ) - @pytest.mark.asyncio - async def test_authenticate_extracts_token_from_bearer(self) -> None: + async def test_authenticate_extracts_token_from_bearer( + self, + auth_handler: AuthHandler, + mock_token_verifier: MagicMock, + ) -> None: """Test that authentication works with Bearer format. The authenticate function should handle Bearer token format through the extract_token function. """ user_id = uuid4() + mock_token_verifier.return_value = user_id - with patch("app.socketio.auth.verify_access_token") as mock_verify: - mock_verify.return_value = user_id + result = await auth_handler.authenticate_connection( + "test-sid", {"authorization": "Bearer my-token"} + ) - result = await authenticate_connection("test-sid", {"authorization": "Bearer my-token"}) - - assert result.success is True - mock_verify.assert_called_once_with("my-token") + assert result.success is True + mock_token_verifier.assert_called_once_with("my-token") class TestSetupAuthenticatedSession: """Tests for session setup after authentication.""" - @pytest.mark.asyncio - async def test_setup_saves_session_data(self) -> None: + @pytest.fixture + def mock_connection_manager(self) -> AsyncMock: + """Create a mock ConnectionManager.""" + cm = AsyncMock() + cm.register_connection = AsyncMock() + return cm + + @pytest.fixture + def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler: + """Create AuthHandler with injected mock.""" + return AuthHandler(conn_manager=mock_connection_manager) + + async def test_setup_saves_session_data( + self, + auth_handler: AuthHandler, + ) -> None: """Test that session setup saves user_id and timestamp. After authentication, the socket session should contain @@ -203,22 +249,21 @@ class TestSetupAuthenticatedSession: user_id = uuid4() mock_sio = AsyncMock() - with patch("app.socketio.auth.connection_manager") as mock_cm: - mock_cm.register_connection = AsyncMock() + await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id) - await setup_authenticated_session(mock_sio, "test-sid", user_id) + mock_sio.save_session.assert_called_once() + call_args = mock_sio.save_session.call_args + assert call_args.args[0] == "test-sid" - # Verify session was saved - mock_sio.save_session.assert_called_once() - call_args = mock_sio.save_session.call_args - assert call_args.args[0] == "test-sid" + session_data = call_args.args[1] + assert session_data["user_id"] == str(user_id) + assert "authenticated_at" in session_data - session_data = call_args.args[1] - assert session_data["user_id"] == str(user_id) - assert "authenticated_at" in session_data - - @pytest.mark.asyncio - async def test_setup_registers_with_connection_manager(self) -> None: + async def test_setup_registers_with_connection_manager( + self, + auth_handler: AuthHandler, + mock_connection_manager: AsyncMock, + ) -> None: """Test that session setup registers with ConnectionManager. The connection should be tracked in ConnectionManager for @@ -227,53 +272,63 @@ class TestSetupAuthenticatedSession: user_id = uuid4() mock_sio = AsyncMock() - with patch("app.socketio.auth.connection_manager") as mock_cm: - mock_cm.register_connection = AsyncMock() + await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id) - await setup_authenticated_session(mock_sio, "test-sid", user_id) - - mock_cm.register_connection.assert_called_once_with("test-sid", user_id) + mock_connection_manager.register_connection.assert_called_once_with("test-sid", user_id) class TestCleanupAuthenticatedSession: """Tests for session cleanup on disconnect.""" - @pytest.mark.asyncio - async def test_cleanup_unregisters_connection(self) -> None: + @pytest.fixture + def mock_connection_manager(self) -> AsyncMock: + """Create a mock ConnectionManager.""" + return AsyncMock() + + @pytest.fixture + def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler: + """Create AuthHandler with injected mock.""" + return AuthHandler(conn_manager=mock_connection_manager) + + async def test_cleanup_unregisters_connection( + self, + auth_handler: AuthHandler, + mock_connection_manager: AsyncMock, + ) -> None: """Test that cleanup unregisters from ConnectionManager. On disconnect, the connection should be removed from ConnectionManager to update presence tracking. """ - with patch("app.socketio.auth.connection_manager") as mock_cm: - mock_conn_info = MagicMock() - mock_conn_info.user_id = "user-123" - mock_cm.unregister_connection = AsyncMock(return_value=mock_conn_info) + mock_conn_info = MagicMock() + mock_conn_info.user_id = "user-123" + mock_connection_manager.unregister_connection = AsyncMock(return_value=mock_conn_info) - result = await cleanup_authenticated_session("test-sid") + result = await auth_handler.cleanup_authenticated_session("test-sid") - assert result == "user-123" - mock_cm.unregister_connection.assert_called_once_with("test-sid") + assert result == "user-123" + mock_connection_manager.unregister_connection.assert_called_once_with("test-sid") - @pytest.mark.asyncio - async def test_cleanup_returns_none_for_unknown_session(self) -> None: + async def test_cleanup_returns_none_for_unknown_session( + self, + auth_handler: AuthHandler, + mock_connection_manager: AsyncMock, + ) -> None: """Test cleanup returns None for non-existent sessions. If the connection wasn't registered (e.g., auth failed), cleanup should return None gracefully. """ - with patch("app.socketio.auth.connection_manager") as mock_cm: - mock_cm.unregister_connection = AsyncMock(return_value=None) + mock_connection_manager.unregister_connection = AsyncMock(return_value=None) - result = await cleanup_authenticated_session("unknown-sid") + result = await auth_handler.cleanup_authenticated_session("unknown-sid") - assert result is None + assert result is None class TestGetSessionUserId: """Tests for session user_id retrieval.""" - @pytest.mark.asyncio async def test_get_session_user_id_returns_id(self) -> None: """Test retrieving user_id from authenticated session. @@ -290,7 +345,6 @@ class TestGetSessionUserId: assert result == "user-123" mock_sio.get_session.assert_called_once_with("test-sid", namespace="/game") - @pytest.mark.asyncio async def test_get_session_user_id_returns_none_for_missing(self) -> None: """Test that missing session returns None. @@ -304,7 +358,6 @@ class TestGetSessionUserId: assert result is None - @pytest.mark.asyncio async def test_get_session_user_id_handles_exception(self) -> None: """Test that exceptions are caught and return None. @@ -322,7 +375,6 @@ class TestGetSessionUserId: class TestRequireAuth: """Tests for the require_auth helper.""" - @pytest.mark.asyncio async def test_require_auth_returns_user_id_for_authenticated(self) -> None: """Test that require_auth returns user_id for valid sessions. @@ -336,7 +388,6 @@ class TestRequireAuth: assert result == "user-123" - @pytest.mark.asyncio async def test_require_auth_returns_none_for_unauthenticated(self) -> None: """Test that require_auth returns None for unauthenticated sessions. diff --git a/backend/tests/unit/services/test_connection_manager.py b/backend/tests/unit/services/test_connection_manager.py index 7fbbf84..f7fc08e 100644 --- a/backend/tests/unit/services/test_connection_manager.py +++ b/backend/tests/unit/services/test_connection_manager.py @@ -1,12 +1,16 @@ """Tests for ConnectionManager service. This module tests WebSocket connection tracking with Redis. Since these are -unit tests, we mock the Redis operations to test the ConnectionManager logic +unit tests, we inject a mock Redis factory to test the ConnectionManager logic without requiring a real Redis instance. + +Uses dependency injection pattern - no monkey patching required. """ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from datetime import UTC, datetime, timedelta -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock from uuid import uuid4 import pytest @@ -21,15 +25,13 @@ from app.services.connection_manager import ( ) -@pytest.fixture -def manager() -> ConnectionManager: - """Create a ConnectionManager instance for testing.""" - return ConnectionManager(conn_ttl_seconds=3600) - - @pytest.fixture def mock_redis() -> AsyncMock: - """Create a mock Redis client.""" + """Create a mock Redis client. + + This mock provides all the Redis operations used by ConnectionManager + with sensible defaults that can be overridden per-test. + """ redis = AsyncMock() redis.hset = AsyncMock() redis.hget = AsyncMock() @@ -46,6 +48,21 @@ def mock_redis() -> AsyncMock: return redis +@pytest.fixture +def manager(mock_redis: AsyncMock) -> ConnectionManager: + """Create a ConnectionManager with injected mock Redis. + + The mock Redis is injected via the redis_factory parameter, + eliminating the need for monkey patching in tests. + """ + + @asynccontextmanager + async def mock_redis_factory() -> AsyncIterator[AsyncMock]: + yield mock_redis + + return ConnectionManager(conn_ttl_seconds=3600, redis_factory=mock_redis_factory) + + class TestConnectionInfoDataclass: """Tests for the ConnectionInfo dataclass.""" @@ -87,8 +104,8 @@ class TestConnectionInfoDataclass: def test_is_stale_with_custom_threshold(self) -> None: """Test is_stale with a custom threshold. - The threshold can be adjusted for different use cases like - more aggressive cleanup or more lenient timeout. + The threshold parameter allows callers to customize what counts + as stale for different use cases. """ now = datetime.now(UTC) last_seen = now - timedelta(seconds=60) @@ -96,7 +113,7 @@ class TestConnectionInfoDataclass: sid="test-sid", user_id="user-123", game_id=None, - connected_at=now, + connected_at=last_seen, last_seen=last_seen, ) @@ -125,23 +142,20 @@ class TestRegisterConnection: sid = "test-sid-123" user_id = str(uuid4()) - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + await manager.register_connection(sid, user_id) - await manager.register_connection(sid, user_id) + # Verify connection hash was created + conn_key = f"{CONN_PREFIX}{sid}" + mock_redis.hset.assert_called() + hset_call = mock_redis.hset.call_args + assert hset_call.args[0] == conn_key + mapping = hset_call.kwargs["mapping"] + assert mapping["user_id"] == user_id + assert mapping["game_id"] == "" - # Verify connection hash was created - conn_key = f"{CONN_PREFIX}{sid}" - mock_redis.hset.assert_called() - hset_call = mock_redis.hset.call_args - assert hset_call.args[0] == conn_key - mapping = hset_call.kwargs["mapping"] - assert mapping["user_id"] == user_id - assert mapping["game_id"] == "" - - # Verify user-to-connection mapping was created - user_conn_key = f"{USER_CONN_PREFIX}{user_id}" - mock_redis.set.assert_called_with(user_conn_key, sid) + # Verify user-to-connection mapping was created + user_conn_key = f"{USER_CONN_PREFIX}{user_id}" + mock_redis.set.assert_called_with(user_conn_key, sid) @pytest.mark.asyncio async def test_register_connection_replaces_old_connection( @@ -167,15 +181,12 @@ class TestRegisterConnection: "last_seen": datetime.now(UTC).isoformat(), } - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + await manager.register_connection(new_sid, user_id) - await manager.register_connection(new_sid, user_id) - - # Verify old connection was cleaned up (delete called for old conn key) - old_conn_key = f"{CONN_PREFIX}{old_sid}" - delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list] - assert old_conn_key in delete_calls + # Verify old connection was cleaned up (delete called for old conn key) + old_conn_key = f"{CONN_PREFIX}{old_sid}" + delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list] + assert old_conn_key in delete_calls @pytest.mark.asyncio async def test_register_connection_accepts_uuid( @@ -191,15 +202,12 @@ class TestRegisterConnection: sid = "test-sid" user_uuid = uuid4() - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + await manager.register_connection(sid, user_uuid) - await manager.register_connection(sid, user_uuid) - - # Verify user_id was converted to string - hset_call = mock_redis.hset.call_args - mapping = hset_call.kwargs["mapping"] - assert mapping["user_id"] == str(user_uuid) + # Verify user_id was converted to string + hset_call = mock_redis.hset.call_args + mapping = hset_call.kwargs["mapping"] + assert mapping["user_id"] == str(user_uuid) class TestUnregisterConnection: @@ -213,17 +221,14 @@ class TestUnregisterConnection: ) -> None: """Test that unregistering unknown connection returns None. - If the connection doesn't exist, we should return None rather than - raising an error. + Unregistering a non-existent connection should be a no-op + and return None to indicate nothing was cleaned up. """ mock_redis.hgetall.return_value = {} - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.unregister_connection("unknown-sid") - result = await manager.unregister_connection("unknown-sid") - - assert result is None + assert result is None @pytest.mark.asyncio async def test_unregister_cleans_up_all_data( @@ -231,44 +236,39 @@ class TestUnregisterConnection: manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: - """Test that unregistering cleans up all related Redis data. + """Test that unregistering cleans up all Redis data. - Cleanup should remove: + Unregistration should remove: 1. Connection hash - 2. User-to-connection mapping (if it still points to this sid) - 3. Game connection set membership + 2. User-to-connection mapping + 3. Remove from game connection set (if in a game) """ sid = "test-sid" user_id = "user-123" game_id = "game-456" + now = datetime.now(UTC) - # Mock: connection exists with game mock_redis.hgetall.return_value = { "user_id": user_id, "game_id": game_id, - "connected_at": datetime.now(UTC).isoformat(), - "last_seen": datetime.now(UTC).isoformat(), + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), } - mock_redis.get.return_value = sid # user mapping points to this sid + mock_redis.get.return_value = sid # User's current connection - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.unregister_connection(sid) - result = await manager.unregister_connection(sid) + assert result is not None + assert result.user_id == user_id + assert result.game_id == game_id - # Verify result - assert result is not None - assert result.sid == sid - assert result.user_id == user_id - assert result.game_id == game_id - - # Verify cleanup - mock_redis.srem.assert_called() # Removed from game set - mock_redis.delete.assert_called() # Connection deleted + # Verify cleanup calls + mock_redis.delete.assert_called() + mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) class TestGameAssociation: - """Tests for game join/leave operations.""" + """Tests for game association methods.""" @pytest.mark.asyncio async def test_join_game_adds_to_game_set( @@ -276,25 +276,22 @@ class TestGameAssociation: manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: - """Test that joining a game adds the connection to the game's set. + """Test that joining a game adds connection to game's set. - When a connection joins a game, it should: - 1. Update the connection's game_id - 2. Add the sid to the game's connection set + The connection should be tracked in the game's connection set + for broadcasting updates to all participants. """ sid = "test-sid" game_id = "game-123" mock_redis.exists.return_value = True + mock_redis.hget.return_value = "" # Not in a game yet - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.join_game(sid, game_id) - result = await manager.join_game(sid, game_id) - - assert result is True - mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id) - mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) + assert result is True + mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) + mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id) @pytest.mark.asyncio async def test_join_game_returns_false_for_unknown_connection( @@ -304,18 +301,13 @@ class TestGameAssociation: ) -> None: """Test that joining a game fails for unknown connections. - If the connection doesn't exist, we should return False rather - than creating orphan game association data. + Non-existent connections should not be able to join games. """ mock_redis.exists.return_value = False - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.join_game("unknown-sid", "game-123") - result = await manager.join_game("unknown-sid", "game-123") - - assert result is False - mock_redis.sadd.assert_not_called() + assert result is False @pytest.mark.asyncio async def test_join_game_leaves_previous_game( @@ -325,25 +317,22 @@ class TestGameAssociation: ) -> None: """Test that joining a new game leaves the previous game. - When switching games, the connection should be removed from the - old game's connection set before joining the new one. + A connection can only be in one game at a time, so joining + a new game should automatically leave any previous game. """ sid = "test-sid" - old_game = "game-old" - new_game = "game-new" + old_game = "old-game" + new_game = "new-game" mock_redis.exists.return_value = True - mock_redis.hget.return_value = old_game # Currently in old game + mock_redis.hget.return_value = old_game - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + await manager.join_game(sid, new_game) - await manager.join_game(sid, new_game) - - # Verify left old game - mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid) - # Verify joined new game - mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid) + # Should remove from old game + mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid) + # Should add to new game + mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid) @pytest.mark.asyncio async def test_leave_game_removes_from_set( @@ -351,25 +340,21 @@ class TestGameAssociation: manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: - """Test that leaving a game removes the connection from the set. + """Test that leaving a game removes connection from set. - Leave should: - 1. Remove sid from game's connection set - 2. Clear game_id on connection record + The connection should be removed from the game's set and + its game_id should be cleared. """ sid = "test-sid" game_id = "game-123" mock_redis.hget.return_value = game_id - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.leave_game(sid) - result = await manager.leave_game(sid) - - assert result == game_id - mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) - mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "") + assert result == game_id + mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid) + mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "") @pytest.mark.asyncio async def test_leave_game_returns_none_when_not_in_game( @@ -377,20 +362,16 @@ class TestGameAssociation: manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: - """Test that leave_game returns None when not in a game. + """Test that leaving when not in a game returns None. - If the connection isn't associated with any game, we should - return None without making unnecessary Redis calls. + If the connection isn't in a game, leave_game should return + None to indicate no game was left. """ - mock_redis.hget.return_value = "" # No game_id + mock_redis.hget.return_value = "" - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.leave_game("test-sid") - result = await manager.leave_game("test-sid") - - assert result is None - mock_redis.srem.assert_not_called() + assert result is None class TestHeartbeat: @@ -404,25 +385,18 @@ class TestHeartbeat: ) -> None: """Test that heartbeat updates the last_seen timestamp. - Heartbeats keep the connection alive by updating the timestamp - and refreshing TTLs on Redis records. + Heartbeats keep connections alive by updating timestamps + and refreshing TTLs on Redis keys. """ sid = "test-sid" mock_redis.exists.return_value = True mock_redis.hget.return_value = "user-123" - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.update_heartbeat(sid) - result = await manager.update_heartbeat(sid) - - assert result is True - # Verify last_seen was updated - hset_call = mock_redis.hset.call_args - assert hset_call.args[0] == f"{CONN_PREFIX}{sid}" - assert hset_call.args[1] == "last_seen" - # Verify TTL was refreshed - mock_redis.expire.assert_called() + assert result is True + mock_redis.hset.assert_called() + mock_redis.expire.assert_called() @pytest.mark.asyncio async def test_update_heartbeat_returns_false_for_unknown( @@ -430,18 +404,16 @@ class TestHeartbeat: manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: - """Test that heartbeat returns False for unknown connections. + """Test that heartbeat fails for unknown connections. - If the connection doesn't exist, we shouldn't try to update it. + Unknown connections should not be able to send heartbeats, + returning False to indicate failure. """ mock_redis.exists.return_value = False - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.update_heartbeat("unknown-sid") - result = await manager.update_heartbeat("unknown-sid") - - assert result is False + assert result is False class TestQueryMethods: @@ -453,13 +425,13 @@ class TestQueryMethods: manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: - """Test that get_connection returns ConnectionInfo for valid sid. + """Test that get_connection returns ConnectionInfo. - The returned ConnectionInfo should have all fields populated - from the Redis hash. + The returned info should contain all the stored connection data. """ sid = "test-sid" now = datetime.now(UTC) + mock_redis.hgetall.return_value = { "user_id": "user-123", "game_id": "game-456", @@ -467,15 +439,12 @@ class TestQueryMethods: "last_seen": now.isoformat(), } - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.get_connection(sid) - result = await manager.get_connection(sid) - - assert result is not None - assert result.sid == sid - assert result.user_id == "user-123" - assert result.game_id == "game-456" + assert result is not None + assert result.sid == sid + assert result.user_id == "user-123" + assert result.game_id == "game-456" @pytest.mark.asyncio async def test_get_connection_returns_none_for_unknown( @@ -483,15 +452,15 @@ class TestQueryMethods: manager: ConnectionManager, mock_redis: AsyncMock, ) -> None: - """Test that get_connection returns None for unknown sid.""" + """Test that get_connection returns None for unknown sid. + + Non-existent connections should return None rather than raising. + """ mock_redis.hgetall.return_value = {} - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.get_connection("unknown-sid") - result = await manager.get_connection("unknown-sid") - - assert result is None + assert result is None @pytest.mark.asyncio async def test_is_user_online_returns_true_for_connected_user( @@ -501,7 +470,7 @@ class TestQueryMethods: ) -> None: """Test that is_user_online returns True for connected users. - A user is online if they have an active connection that isn't stale. + Users with active, non-stale connections should be considered online. """ user_id = "user-123" now = datetime.now(UTC) @@ -514,12 +483,9 @@ class TestQueryMethods: "last_seen": now.isoformat(), } - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.is_user_online(user_id) - result = await manager.is_user_online(user_id) - - assert result is True + assert result is True @pytest.mark.asyncio async def test_is_user_online_returns_false_for_stale_connection( @@ -529,11 +495,11 @@ class TestQueryMethods: ) -> None: """Test that is_user_online returns False for stale connections. - Even if a connection record exists, if it's stale (no recent heartbeat), - the user should be considered offline. + Users with stale connections (no recent heartbeat) should be + considered offline even if their connection record exists. """ user_id = "user-123" - old_time = datetime.now(UTC) - timedelta(minutes=5) + old_time = datetime.now(UTC) - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4) mock_redis.get.return_value = "test-sid" mock_redis.hgetall.return_value = { @@ -543,12 +509,9 @@ class TestQueryMethods: "last_seen": old_time.isoformat(), } - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.is_user_online(user_id) - result = await manager.is_user_online(user_id) - - assert result is False + assert result is False @pytest.mark.asyncio async def test_get_game_connections_returns_all_participants( @@ -564,29 +527,32 @@ class TestQueryMethods: now = datetime.now(UTC) mock_redis.smembers.return_value = {"sid-1", "sid-2"} - mock_redis.hgetall.side_effect = [ - { - "user_id": "user-1", - "game_id": game_id, - "connected_at": now.isoformat(), - "last_seen": now.isoformat(), - }, - { - "user_id": "user-2", - "game_id": game_id, - "connected_at": now.isoformat(), - "last_seen": now.isoformat(), - }, - ] - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + # Use a function to return the right data based on the key queried + def hgetall_by_key(key: str) -> dict[str, str]: + if key == f"{CONN_PREFIX}sid-1": + return { + "user_id": "user-1", + "game_id": game_id, + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), + } + elif key == f"{CONN_PREFIX}sid-2": + return { + "user_id": "user-2", + "game_id": game_id, + "connected_at": now.isoformat(), + "last_seen": now.isoformat(), + } + return {} - result = await manager.get_game_connections(game_id) + mock_redis.hgetall.side_effect = hgetall_by_key - assert len(result) == 2 - user_ids = {conn.user_id for conn in result} - assert user_ids == {"user-1", "user-2"} + result = await manager.get_game_connections(game_id) + + assert len(result) == 2 + user_ids = {conn.user_id for conn in result} + assert user_ids == {"user-1", "user-2"} @pytest.mark.asyncio async def test_get_opponent_sid_returns_other_player( @@ -628,38 +594,38 @@ class TestQueryMethods: mock_redis.hgetall.side_effect = hgetall_by_key - with patch("app.services.connection_manager.get_redis") as mock_get_redis: - mock_get_redis.return_value.__aenter__.return_value = mock_redis + result = await manager.get_opponent_sid(game_id, current_user) - result = await manager.get_opponent_sid(game_id, current_user) - - assert result == opponent_sid + assert result == opponent_sid class TestKeyGeneration: """Tests for Redis key generation methods.""" - def test_conn_key_format(self, manager: ConnectionManager) -> None: + def test_conn_key_format(self) -> None: """Test that connection keys have correct format. Keys should follow the pattern conn:{sid} for easy identification and pattern matching. """ + manager = ConnectionManager() key = manager._conn_key("test-sid-123") assert key == "conn:test-sid-123" - def test_user_conn_key_format(self, manager: ConnectionManager) -> None: + def test_user_conn_key_format(self) -> None: """Test that user connection keys have correct format. Keys should follow the pattern user_conn:{user_id}. """ + manager = ConnectionManager() key = manager._user_conn_key("user-456") assert key == "user_conn:user-456" - def test_game_conns_key_format(self, manager: ConnectionManager) -> None: + def test_game_conns_key_format(self) -> None: """Test that game connections keys have correct format. Keys should follow the pattern game_conns:{game_id}. """ + manager = ConnectionManager() key = manager._game_conns_key("game-789") assert key == "game_conns:game-789" diff --git a/backend/tests/unit/services/test_game_service.py b/backend/tests/unit/services/test_game_service.py index 53916e0..f7aa19d 100644 --- a/backend/tests/unit/services/test_game_service.py +++ b/backend/tests/unit/services/test_game_service.py @@ -7,9 +7,11 @@ The GameService is STATELESS regarding game rules: - No GameEngine is stored in the service - Engine is created per-operation using rules from GameState - Rules come from frontend at game creation, stored in GameState + +Uses dependency injection pattern - no monkey patching required. """ -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 import pytest @@ -53,19 +55,35 @@ def mock_card_service() -> MagicMock: return MagicMock() +@pytest.fixture +def mock_engine() -> MagicMock: + """Create a mock GameEngine. + + The engine handles game logic - executing actions and checking + win conditions. + """ + engine = MagicMock() + engine.execute_action = AsyncMock( + return_value=ActionResult(success=True, message="Action executed") + ) + return engine + + @pytest.fixture def game_service( mock_state_manager: AsyncMock, mock_card_service: MagicMock, + mock_engine: MagicMock, ) -> GameService: - """Create a GameService with mocked dependencies. + """Create a GameService with injected mock dependencies. - Note: No engine is passed - GameService creates engines per-operation - using rules stored in each game's GameState. + The mock engine factory is injected via the engine_factory parameter, + eliminating the need for monkey patching in tests. """ return GameService( state_manager=mock_state_manager, card_service=mock_card_service, + engine_factory=lambda game: mock_engine, ) @@ -324,9 +342,8 @@ class TestJoinGame: class TestExecuteAction: """Tests for the execute_action method. - These tests verify action execution through GameService. Since GameService - creates engines per-operation, we patch _create_engine_for_game to return - a mock engine with controlled behavior. + These tests verify action execution through GameService. The mock engine + is injected via the engine_factory parameter in the game_service fixture. """ @pytest.mark.asyncio @@ -334,6 +351,7 @@ class TestExecuteAction: self, game_service: GameService, mock_state_manager: AsyncMock, + mock_engine: MagicMock, sample_game_state: GameState, ) -> None: """Test successful action execution. @@ -342,19 +360,14 @@ class TestExecuteAction: the state should be saved to cache. """ mock_state_manager.load_state.return_value = sample_game_state - - mock_engine = MagicMock() - mock_engine.execute_action = AsyncMock( - return_value=ActionResult( - success=True, - message="Attack executed", - state_changes=[{"type": "damage", "amount": 30}], - ) + mock_engine.execute_action.return_value = ActionResult( + success=True, + message="Attack executed", + state_changes=[{"type": "damage", "amount": 30}], ) - with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine): - action = AttackAction(attack_index=0) - result = await game_service.execute_action("game-123", "player-1", action) + action = AttackAction(attack_index=0) + result = await game_service.execute_action("game-123", "player-1", action) assert result.success is True assert result.action_type == "attack" @@ -434,6 +447,7 @@ class TestExecuteAction: self, game_service: GameService, mock_state_manager: AsyncMock, + mock_engine: MagicMock, sample_game_state: GameState, ) -> None: """Test that resignation is allowed even when not your turn. @@ -442,24 +456,19 @@ class TestExecuteAction: by either player. """ mock_state_manager.load_state.return_value = sample_game_state - - mock_engine = MagicMock() - mock_engine.execute_action = AsyncMock( - return_value=ActionResult( - success=True, - message="Player resigned", - win_result=WinResult( - winner_id="player-1", - loser_id="player-2", - end_reason=GameEndReason.RESIGNATION, - reason="Player resigned", - ), - ) + mock_engine.execute_action.return_value = ActionResult( + success=True, + message="Player resigned", + win_result=WinResult( + winner_id="player-1", + loser_id="player-2", + end_reason=GameEndReason.RESIGNATION, + reason="Player resigned", + ), ) - with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine): - # player-2 resigns even though it's player-1's turn - result = await game_service.execute_action("game-123", "player-2", ResignAction()) + # player-2 resigns even though it's player-1's turn + result = await game_service.execute_action("game-123", "player-2", ResignAction()) assert result.success is True assert result.game_over is True @@ -470,6 +479,7 @@ class TestExecuteAction: self, game_service: GameService, mock_state_manager: AsyncMock, + mock_engine: MagicMock, sample_game_state: GameState, ) -> None: """Test execute_action raises error for invalid actions. @@ -478,19 +488,12 @@ class TestExecuteAction: InvalidActionError with the reason. """ mock_state_manager.load_state.return_value = sample_game_state - - mock_engine = MagicMock() - mock_engine.execute_action = AsyncMock( - return_value=ActionResult( - success=False, - message="Not enough energy to attack", - ) + mock_engine.execute_action.return_value = ActionResult( + success=False, + message="Not enough energy to attack", ) - with ( - patch.object(game_service, "_create_engine_for_game", return_value=mock_engine), - pytest.raises(InvalidActionError) as exc_info, - ): + with pytest.raises(InvalidActionError) as exc_info: await game_service.execute_action("game-123", "player-1", AttackAction(attack_index=0)) assert "Not enough energy" in exc_info.value.reason @@ -500,6 +503,7 @@ class TestExecuteAction: self, game_service: GameService, mock_state_manager: AsyncMock, + mock_engine: MagicMock, sample_game_state: GameState, ) -> None: """Test execute_action detects game over and persists to DB. @@ -508,58 +512,26 @@ class TestExecuteAction: to the database for durability. """ mock_state_manager.load_state.return_value = sample_game_state - - mock_engine = MagicMock() - mock_engine.execute_action = AsyncMock( - return_value=ActionResult( - success=True, - message="Final prize taken!", - win_result=WinResult( - winner_id="player-1", - loser_id="player-2", - end_reason=GameEndReason.PRIZES_TAKEN, - reason="All prizes taken", - ), - ) + mock_engine.execute_action.return_value = ActionResult( + success=True, + message="Final prize taken!", + win_result=WinResult( + winner_id="player-1", + loser_id="player-2", + end_reason=GameEndReason.PRIZES_TAKEN, + reason="All prizes taken", + ), ) - with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine): - result = await game_service.execute_action( - "game-123", "player-1", AttackAction(attack_index=0) - ) + result = await game_service.execute_action( + "game-123", "player-1", AttackAction(attack_index=0) + ) assert result.game_over is True assert result.winner_id == "player-1" assert result.end_reason == GameEndReason.PRIZES_TAKEN mock_state_manager.persist_to_db.assert_called_once() - @pytest.mark.asyncio - async def test_execute_action_uses_game_rules( - self, - game_service: GameService, - mock_state_manager: AsyncMock, - sample_game_state: GameState, - ) -> None: - """Test that execute_action creates engine with game's rules. - - The engine should be created using the rules stored in the game - state, not any service-level defaults. - """ - mock_state_manager.load_state.return_value = sample_game_state - - mock_engine = MagicMock() - mock_engine.execute_action = AsyncMock( - return_value=ActionResult(success=True, message="OK") - ) - - with patch.object( - game_service, "_create_engine_for_game", return_value=mock_engine - ) as mock_create: - await game_service.execute_action("game-123", "player-1", PassAction()) - - # Verify engine was created with the game state - mock_create.assert_called_once_with(sample_game_state) - class TestResignGame: """Tests for the resign_game convenience method.""" @@ -569,6 +541,7 @@ class TestResignGame: self, game_service: GameService, mock_state_manager: AsyncMock, + mock_engine: MagicMock, sample_game_state: GameState, ) -> None: """Test that resign_game is a convenience wrapper for execute_action. @@ -577,23 +550,18 @@ class TestResignGame: and call execute_action. """ mock_state_manager.load_state.return_value = sample_game_state - - mock_engine = MagicMock() - mock_engine.execute_action = AsyncMock( - return_value=ActionResult( - success=True, - message="Player resigned", - win_result=WinResult( - winner_id="player-2", - loser_id="player-1", - end_reason=GameEndReason.RESIGNATION, - reason="Player resigned", - ), - ) + mock_engine.execute_action.return_value = ActionResult( + success=True, + message="Player resigned", + win_result=WinResult( + winner_id="player-2", + loser_id="player-1", + end_reason=GameEndReason.RESIGNATION, + reason="Player resigned", + ), ) - with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine): - result = await game_service.resign_game("game-123", "player-1") + result = await game_service.resign_game("game-123", "player-1") assert result.success is True assert result.action_type == "resign" @@ -692,31 +660,39 @@ class TestCreateGame: assert "GS-002" in str(exc_info.value) -class TestCreateEngineForGame: - """Tests for the _create_engine_for_game method. +class TestDefaultEngineFactory: + """Tests for the _default_engine_factory method. This method is responsible for creating a GameEngine configured with the rules from a specific game's state. """ - def test_create_engine_uses_game_rules( + def test_default_engine_uses_game_rules( self, - game_service: GameService, + mock_state_manager: AsyncMock, + mock_card_service: MagicMock, sample_game_state: GameState, ) -> None: - """Test that engine is created with the game's rules. + """Test that default engine is created with the game's rules. The engine should use the RulesConfig stored in the game state, not any default configuration. """ - engine = game_service._create_engine_for_game(sample_game_state) + # Create service without mock engine to test default factory + service = GameService( + state_manager=mock_state_manager, + card_service=mock_card_service, + ) + + engine = service._default_engine_factory(sample_game_state) # Engine should have the game's rules assert engine.rules == sample_game_state.rules - def test_create_engine_with_rng_seed( + def test_default_engine_with_rng_seed( self, - game_service: GameService, + mock_state_manager: AsyncMock, + mock_card_service: MagicMock, sample_game_state: GameState, ) -> None: """Test that engine uses seeded RNG when game has rng_seed. @@ -724,17 +700,21 @@ class TestCreateEngineForGame: When a game has an rng_seed set, the engine should use a deterministic RNG for replay support. """ + service = GameService( + state_manager=mock_state_manager, + card_service=mock_card_service, + ) sample_game_state.rng_seed = 12345 - engine = game_service._create_engine_for_game(sample_game_state) + engine = service._default_engine_factory(sample_game_state) - # Engine should have been created (we can't easily verify seed, - # but we can verify it doesn't error) + # Engine should have been created successfully assert engine is not None - def test_create_engine_without_rng_seed( + def test_default_engine_without_rng_seed( self, - game_service: GameService, + mock_state_manager: AsyncMock, + mock_card_service: MagicMock, sample_game_state: GameState, ) -> None: """Test that engine uses secure RNG when no seed is set. @@ -742,15 +722,20 @@ class TestCreateEngineForGame: Without an rng_seed, the engine should use cryptographically secure random number generation. """ + service = GameService( + state_manager=mock_state_manager, + card_service=mock_card_service, + ) sample_game_state.rng_seed = None - engine = game_service._create_engine_for_game(sample_game_state) + engine = service._default_engine_factory(sample_game_state) assert engine is not None - def test_create_engine_derives_unique_seed_per_action( + def test_default_engine_derives_unique_seed_per_action( self, - game_service: GameService, + mock_state_manager: AsyncMock, + mock_card_service: MagicMock, sample_game_state: GameState, ) -> None: """Test that different action counts produce different RNG sequences. @@ -758,15 +743,19 @@ class TestCreateEngineForGame: For deterministic replay, each action needs a unique but reproducible RNG seed based on game seed + action count. """ + service = GameService( + state_manager=mock_state_manager, + card_service=mock_card_service, + ) sample_game_state.rng_seed = 12345 # Simulate first action (action_log is empty) sample_game_state.action_log = [] - engine1 = game_service._create_engine_for_game(sample_game_state) + engine1 = service._default_engine_factory(sample_game_state) # Simulate second action (one action in log) sample_game_state.action_log = [{"type": "pass"}] - engine2 = game_service._create_engine_for_game(sample_game_state) + engine2 = service._default_engine_factory(sample_game_state) # Both engines should be created successfully # (They will have different seeds due to action count)