Refactor to dependency injection pattern - no monkey patching

- ConnectionManager: Add redis_factory constructor parameter
- GameService: Add engine_factory constructor parameter
- AuthHandler: New class replacing standalone functions with
  token_verifier and conn_manager injection
- Update all tests to use constructor DI instead of patch()
- Update CLAUDE.md with factory injection patterns
- Update services README with new patterns
- Add socketio README documenting AuthHandler and events

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-01-28 22:54:57 -06:00
parent 0c810e5b30
commit f512c7b2b3
11 changed files with 912 additions and 567 deletions

View File

@ -78,11 +78,11 @@ def validate_deck(cards, config: DeckConfig | None = None):
## Dependency Injection for Services
**Services use constructor-based dependency injection for repositories and other services.**
**Services use constructor-based dependency injection for repositories, services, and external resources.**
Services still use DI for data access, but config comes from method parameters (request).
Services use DI for data access and external dependencies, but config comes from method parameters (request).
### Required Pattern
### Required Pattern: Services and Repositories
```python
class DeckService:
@ -109,12 +109,78 @@ class DeckService:
...
```
### Required Pattern: Factory Injection for External Resources
For external resources (Redis, databases, engines), inject factory functions rather than instances:
```python
from collections.abc import AsyncIterator, Callable
# Type aliases for factories
RedisFactory = Callable[[], AsyncIterator["Redis"]]
EngineFactory = Callable[[GameState], GameEngine]
TokenVerifier = Callable[[str], UUID | None]
class ConnectionManager:
"""External resources injected as factories."""
def __init__(
self,
conn_ttl_seconds: int = 3600,
redis_factory: RedisFactory | None = None, # Factory, not instance
) -> None:
self.conn_ttl_seconds = conn_ttl_seconds
# Fall back to global only if not provided
self._get_redis = redis_factory if redis_factory is not None else get_redis
async def register_connection(self, sid: str, user_id: UUID) -> None:
async with self._get_redis() as redis: # Use injected factory
await redis.hset(...)
```
```python
class GameService:
"""Engine created via injected factory for testability."""
def __init__(
self,
state_manager: GameStateManager | None = None,
engine_factory: EngineFactory | None = None, # Factory for engine creation
) -> None:
self._state_manager = state_manager or game_state_manager
self._engine_factory = engine_factory or self._default_engine_factory
async def execute_action(self, game_id: str, player_id: str, action: Action):
state = await self.get_game_state(game_id)
engine = self._engine_factory(state) # Create via factory
result = await engine.execute_action(state, player_id, action)
```
### Global Singletons with DI Support
Create global instances for production, but always support injection for testing:
```python
# Global singleton for production use
connection_manager = ConnectionManager()
# Tests inject mocks via constructor - no patching needed
def test_register_connection(mock_redis):
@asynccontextmanager
async def mock_redis_factory():
yield mock_redis
manager = ConnectionManager(redis_factory=mock_redis_factory)
# Test with injected mock...
```
### Why This Matters
1. **Testability**: Dependencies can be mocked without patching globals
2. **Offline Fork**: Services can be swapped for local implementations
3. **Explicit Dependencies**: Constructor shows all requirements
4. **Stateless Operations**: Config comes from request, not server state
5. **No Monkey Patching**: Tests use constructor injection, not `patch()`
---
@ -253,6 +319,55 @@ def test_paralyzed_pokemon_cannot_attack():
| `tests/core/` | Core engine tests | No |
| `tests/api/` | API endpoint tests | Yes |
### No Monkey Patching - Use Dependency Injection
**Never use `patch()` or `monkeypatch` for unit tests.** Instead, inject mocks via constructor.
```python
# WRONG - Monkey patching
async def test_execute_action():
with patch.object(game_service, "_create_engine") as mock:
mock.return_value = mock_engine
result = await game_service.execute_action(...)
# CORRECT - Dependency injection
@pytest.fixture
def mock_engine():
engine = MagicMock()
engine.execute_action = AsyncMock(return_value=ActionResult(success=True))
return engine
@pytest.fixture
def game_service(mock_state_manager, mock_engine):
return GameService(
state_manager=mock_state_manager,
engine_factory=lambda state: mock_engine, # Inject via constructor
)
async def test_execute_action(game_service, mock_engine):
result = await game_service.execute_action(...)
mock_engine.execute_action.assert_called_once()
```
### Factory Fixtures for External Resources
```python
@pytest.fixture
def mock_redis():
redis = AsyncMock()
redis.hset = AsyncMock()
redis.hget = AsyncMock(return_value=None)
return redis
@pytest.fixture
def connection_manager(mock_redis):
@asynccontextmanager
async def mock_redis_factory():
yield mock_redis
return ConnectionManager(redis_factory=mock_redis_factory)
```
### Use Seeded RNG for Determinism
```python
@ -276,6 +391,12 @@ app/
│ ├── config.py # RulesConfig and sub-configs
│ └── engine.py # GameEngine orchestrator
├── services/ # Business logic (uses repositories)
│ ├── game_service.py # Game orchestration (uses engine_factory DI)
│ ├── connection_manager.py # WebSocket tracking (uses redis_factory DI)
│ └── ...
├── socketio/ # WebSocket real-time communication
│ ├── server.py # Socket.IO server setup and handlers
│ └── auth.py # AuthHandler class (uses token_verifier DI)
├── repositories/ # Data access layer
│ ├── protocols.py # Repository protocols (interfaces)
│ └── postgres/ # PostgreSQL implementations
@ -322,6 +443,9 @@ app/
| Hardcoded magic numbers | Use values from config parameter |
| Tests without docstrings | Always explain what and why |
| Unit tests in `tests/services/` | Use `tests/unit/` for no-DB tests |
| Using `patch()` in unit tests | Inject mocks via constructor DI |
| `async with get_redis()` in method body | Inject `redis_factory` via constructor |
| Importing dependencies inside functions | Import at module level, inject via constructor |
---
@ -330,3 +454,5 @@ app/
- `app/core/AGENTS.md` - Detailed core engine guidelines
- `app/core/README.md` - Core module documentation
- `app/core/effects/README.md` - Effect system documentation
- `app/services/README.md` - Services layer patterns
- `app/socketio/README.md` - WebSocket module documentation

View File

@ -19,6 +19,8 @@ Business logic layer between API endpoints and data access.
| `DeckValidator` | Pure deck validation functions | CardService (for lookups) |
| `UserService` | User profile management | Direct DB access |
| `GameStateManager` | Game state (Redis + Postgres) | Redis, AsyncSession |
| `GameService` | Game orchestration + actions | GameStateManager, CardService, engine_factory |
| `ConnectionManager` | WebSocket connection tracking | redis_factory |
| `JWTService` | Token creation/verification | Settings |
| `TokenStore` | Refresh token storage | Redis |
@ -93,6 +95,50 @@ Benefits:
- **Offline Fork**: Swap PostgresRepository for LocalRepository
- **Decoupling**: Services don't know about SQLAlchemy
### Factory Injection for External Resources
For external resources (Redis, GameEngine), inject factory functions rather than instances:
```python
from collections.abc import AsyncIterator, Callable
# Type aliases for factories
RedisFactory = Callable[[], AsyncIterator["Redis"]]
EngineFactory = Callable[[GameState], GameEngine]
class ConnectionManager:
"""External resources injected as factories."""
def __init__(
self,
conn_ttl_seconds: int = 3600,
redis_factory: RedisFactory | None = None,
) -> None:
self.conn_ttl_seconds = conn_ttl_seconds
self._get_redis = redis_factory if redis_factory is not None else get_redis
class GameService:
"""Engine created via injected factory for testability."""
def __init__(
self,
state_manager: GameStateManager | None = None,
engine_factory: EngineFactory | None = None,
) -> None:
self._state_manager = state_manager or game_state_manager
self._engine_factory = engine_factory or self._default_engine_factory
```
Global singletons are created for production, but constructors always support injection:
```python
# Production singleton
connection_manager = ConnectionManager()
# Tests inject mocks - no patching needed
manager = ConnectionManager(redis_factory=mock_redis_factory)
```
### Pure Validation Functions
Validation logic is extracted into pure functions for reuse:
@ -183,6 +229,27 @@ await manager.persist_to_db(game_id, game_state)
| `tests/unit/services/` | Pure unit tests, mocked deps | No |
| `tests/services/` | Integration tests | Yes (testcontainers) |
**No monkey patching** - Always use constructor DI for mocks:
```python
# WRONG - monkey patching
async def test_something():
with patch.object(service, "_method") as mock:
...
# CORRECT - constructor injection
@pytest.fixture
def mock_redis():
return AsyncMock()
@pytest.fixture
def connection_manager(mock_redis):
@asynccontextmanager
async def mock_redis_factory():
yield mock_redis
return ConnectionManager(redis_factory=mock_redis_factory)
```
### Unit Test Example
```python

View File

@ -35,14 +35,22 @@ Example:
"""
import logging
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from uuid import UUID
from app.db.redis import get_redis
if TYPE_CHECKING:
from redis.asyncio import Redis
logger = logging.getLogger(__name__)
# Type alias for redis factory - a callable that returns an async context manager
RedisFactory = Callable[[], AsyncIterator["Redis"]]
# Redis key patterns
CONN_PREFIX = "conn:"
USER_CONN_PREFIX = "user_conn:"
@ -98,13 +106,20 @@ class ConnectionManager:
conn_ttl_seconds: TTL for connection records in Redis.
"""
def __init__(self, conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS) -> None:
def __init__(
self,
conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS,
redis_factory: RedisFactory | None = None,
) -> None:
"""Initialize the ConnectionManager.
Args:
conn_ttl_seconds: How long to keep connection records in Redis.
redis_factory: Optional factory for Redis connections. If not provided,
uses the default get_redis from app.db.redis. Useful for testing.
"""
self.conn_ttl_seconds = conn_ttl_seconds
self._get_redis = redis_factory if redis_factory is not None else get_redis
def _conn_key(self, sid: str) -> str:
"""Generate Redis key for a connection."""
@ -142,7 +157,7 @@ class ConnectionManager:
user_id_str = str(user_id)
now = datetime.now(UTC).isoformat()
async with get_redis() as redis:
async with self._get_redis() as redis:
# Check for existing connection and clean it up
old_sid = await redis.get(self._user_conn_key(user_id_str))
if old_sid and old_sid != sid:
@ -204,7 +219,7 @@ class ConnectionManager:
Args:
sid: Socket.IO session ID.
"""
async with get_redis() as redis:
async with self._get_redis() as redis:
conn_key = self._conn_key(sid)
# Get connection data for cleanup
@ -250,7 +265,7 @@ class ConnectionManager:
Example:
await manager.join_game("abc123", "game-456")
"""
async with get_redis() as redis:
async with self._get_redis() as redis:
conn_key = self._conn_key(sid)
# Check connection exists
@ -291,7 +306,7 @@ class ConnectionManager:
Example:
game_id = await manager.leave_game("abc123")
"""
async with get_redis() as redis:
async with self._get_redis() as redis:
conn_key = self._conn_key(sid)
# Get current game
@ -330,7 +345,7 @@ class ConnectionManager:
"""
now = datetime.now(UTC).isoformat()
async with get_redis() as redis:
async with self._get_redis() as redis:
conn_key = self._conn_key(sid)
# Check exists and update atomically
@ -370,7 +385,7 @@ class ConnectionManager:
if info:
print(f"User {info.user_id} connected at {info.connected_at}")
"""
async with get_redis() as redis:
async with self._get_redis() as redis:
conn_key = self._conn_key(sid)
data = await redis.hgetall(conn_key)
@ -401,7 +416,7 @@ class ConnectionManager:
"""
user_id_str = str(user_id)
async with get_redis() as redis:
async with self._get_redis() as redis:
user_conn_key = self._user_conn_key(user_id_str)
sid = await redis.get(user_conn_key)
@ -448,7 +463,7 @@ class ConnectionManager:
for conn in connections:
# Send state update to each participant
"""
async with get_redis() as redis:
async with self._get_redis() as redis:
game_conns_key = self._game_conns_key(game_id)
sids = await redis.smembers(game_conns_key)
@ -530,7 +545,7 @@ class ConnectionManager:
logger.info(f"Cleaned up {count} stale connections")
"""
count = 0
async with get_redis() as redis:
async with self._get_redis() as redis:
# Scan for all connection keys
async for key in redis.scan_iter(match=f"{CONN_PREFIX}*"):
sid = key[len(CONN_PREFIX) :]
@ -555,7 +570,7 @@ class ConnectionManager:
Number of active connections.
"""
count = 0
async with get_redis() as redis:
async with self._get_redis() as redis:
async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"):
count += 1
return count
@ -569,7 +584,7 @@ class ConnectionManager:
Returns:
Number of connected participants.
"""
async with get_redis() as redis:
async with self._get_redis() as redis:
game_conns_key = self._game_conns_key(game_id)
return await redis.scard(game_conns_key)

View File

@ -36,6 +36,7 @@ Example:
"""
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
@ -51,6 +52,9 @@ from app.services.game_state_manager import GameStateManager, game_state_manager
logger = logging.getLogger(__name__)
# Type alias for engine factory - takes GameState, returns GameEngine
EngineFactory = Callable[[GameState], GameEngine]
# =============================================================================
# Exceptions
@ -180,12 +184,14 @@ class GameService:
Attributes:
_state_manager: GameStateManager for persistence.
_card_service: CardService for card definitions.
_engine_factory: Factory function to create GameEngine instances.
"""
def __init__(
self,
state_manager: GameStateManager | None = None,
card_service: CardService | None = None,
engine_factory: EngineFactory | None = None,
) -> None:
"""Initialize the GameService.
@ -196,12 +202,16 @@ class GameService:
Args:
state_manager: GameStateManager instance. Uses global if not provided.
card_service: CardService instance. Uses global if not provided.
engine_factory: Optional factory for creating GameEngine instances.
If not provided, uses the default _default_engine_factory method.
Useful for testing with mock engines.
"""
self._state_manager = state_manager or game_state_manager
self._card_service = card_service or get_card_service()
self._engine_factory = engine_factory or self._default_engine_factory
def _create_engine_for_game(self, game: GameState) -> GameEngine:
"""Create a GameEngine configured for a specific game's rules.
def _default_engine_factory(self, game: GameState) -> GameEngine:
"""Default factory for creating a GameEngine from game state.
The engine is created on-demand using the rules stored in the
game state. This ensures each game uses its own configuration.
@ -412,8 +422,8 @@ class GameService:
if not isinstance(action, ResignAction) and state.current_player_id != player_id:
raise NotPlayerTurnError(game_id, player_id, state.current_player_id)
# Create engine with this game's rules
engine = self._create_engine_for_game(state)
# Create engine with this game's rules via factory
engine = self._engine_factory(state)
# Execute the action
result: ActionResult = await engine.execute_action(state, player_id, action)

View File

@ -0,0 +1,120 @@
# Socket.IO Module
Real-time WebSocket communication for active game sessions.
## Architecture
```
Client <--WebSocket--> Socket.IO Server <--> GameService <--> GameEngine
|
v
ConnectionManager <--> Redis (presence tracking)
```
## Components
| Component | Responsibility | DI Pattern |
|-----------|---------------|------------|
| `server.py` | Socket.IO server setup, event handlers | Uses `auth_handler` |
| `auth.py` | `AuthHandler` class for JWT authentication | `token_verifier`, `conn_manager` |
## AuthHandler
Handles JWT-based authentication with injectable dependencies:
```python
class AuthHandler:
def __init__(
self,
token_verifier: TokenVerifier | None = None, # Callable[[str], UUID | None]
conn_manager: ConnectionManager | None = None,
) -> None:
self._token_verifier = token_verifier or verify_access_token
self._connection_manager = conn_manager or connection_manager
# Global singleton for production
auth_handler = AuthHandler()
# Tests inject mocks via constructor
handler = AuthHandler(
token_verifier=lambda token: user_id,
conn_manager=mock_connection_manager,
)
```
### Methods
| Method | Purpose |
|--------|---------|
| `authenticate_connection(sid, auth)` | Validate JWT, return `AuthResult` |
| `setup_authenticated_session(sio, sid, user_id)` | Save session, register connection |
| `cleanup_authenticated_session(sid)` | Unregister connection, return user_id |
## Event Handlers
All handlers are in the `/game` namespace:
| Event | Direction | Purpose |
|-------|-----------|---------|
| `connect` | Client→Server | Authenticate and establish session |
| `disconnect` | Client→Server | Clean up session and notify opponent |
| `game:join` | Client→Server | Join/rejoin a game session |
| `game:action` | Client→Server | Execute a game action |
| `game:resign` | Client→Server | Resign from game |
| `game:heartbeat` | Client→Server | Keep connection alive |
## Client Connection
```javascript
const socket = io("wss://api.example.com", {
auth: { token: "JWT_ACCESS_TOKEN" }
});
socket.on("connect", () => {
socket.emit("game:join", { game_id: "uuid" });
});
socket.on("game:state", (state) => {
// Full game state update
});
socket.on("auth_error", (error) => {
// { code: "invalid_token", message: "..." }
});
```
## Testing
Use dependency injection - no monkey patching:
```python
@pytest.fixture
def mock_token_verifier():
return MagicMock(return_value=uuid4())
@pytest.fixture
def mock_connection_manager():
cm = AsyncMock()
cm.register_connection = AsyncMock()
return cm
@pytest.fixture
def auth_handler(mock_token_verifier, mock_connection_manager):
return AuthHandler(
token_verifier=mock_token_verifier,
conn_manager=mock_connection_manager,
)
async def test_authenticate_success(auth_handler, mock_token_verifier):
"""Test successful authentication with valid JWT."""
result = await auth_handler.authenticate_connection("sid", {"token": "valid"})
assert result.success
mock_token_verifier.assert_called_once_with("valid")
```
## See Also
- `app/services/connection_manager.py` - Connection presence tracking
- `app/services/game_service.py` - Game orchestration
- `app/schemas/ws_messages.py` - WebSocket message schemas
- `CLAUDE.md` - Architecture guidelines

View File

@ -23,12 +23,12 @@ Usage:
"""
from app.socketio.auth import (
AuthHandler,
AuthResult,
authenticate_connection,
cleanup_authenticated_session,
auth_handler,
extract_token,
get_session_user_id,
require_auth,
setup_authenticated_session,
)
from app.socketio.server import create_socketio_app, sio
@ -37,10 +37,10 @@ __all__ = [
"create_socketio_app",
"sio",
# Auth
"AuthHandler",
"AuthResult",
"authenticate_connection",
"cleanup_authenticated_session",
"auth_handler",
"extract_token",
"get_session_user_id",
"require_auth",
"setup_authenticated_session",
]

View File

@ -16,7 +16,7 @@ Session Data:
Example:
# In connect handler:
auth_result = await authenticate_connection(sid, auth)
auth_result = await auth_handler.authenticate_connection(sid, auth)
if not auth_result.success:
return False # Reject connection
@ -26,15 +26,19 @@ Example:
"""
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID
from app.services.connection_manager import connection_manager
from app.services.connection_manager import ConnectionManager, connection_manager
from app.services.jwt_service import verify_access_token
logger = logging.getLogger(__name__)
# Type alias for token verifier - takes token string, returns user_id or None
TokenVerifier = Callable[[str], UUID | None]
@dataclass
class AuthResult:
@ -53,6 +57,133 @@ class AuthResult:
error_message: str | None = None
class AuthHandler:
"""Handler for Socket.IO authentication with injectable dependencies.
This class encapsulates authentication logic with dependencies injected
via constructor, making it testable without monkey patching.
Attributes:
_token_verifier: Function to verify JWT tokens.
_connection_manager: ConnectionManager for tracking connections.
"""
def __init__(
self,
token_verifier: TokenVerifier | None = None,
conn_manager: ConnectionManager | None = None,
) -> None:
"""Initialize AuthHandler with dependencies.
Args:
token_verifier: Function to verify tokens. Uses jwt_service if not provided.
conn_manager: ConnectionManager instance. Uses global if not provided.
"""
self._token_verifier = token_verifier or verify_access_token
self._connection_manager = conn_manager or connection_manager
async def authenticate_connection(
self,
sid: str,
auth: dict[str, object] | None,
) -> AuthResult:
"""Authenticate a Socket.IO connection using JWT.
Extracts the JWT from auth data, validates it, and returns the result.
Does NOT modify socket session - caller should handle that.
Args:
sid: Socket session ID (for logging).
auth: Authentication data from connect event.
Returns:
AuthResult with success status and user_id or error details.
"""
token = extract_token(auth)
if token is None:
logger.debug(f"Connection {sid}: No token provided")
return AuthResult(
success=False,
error_code="missing_token",
error_message="Authentication token required",
)
user_id = self._token_verifier(token)
if user_id is None:
logger.debug(f"Connection {sid}: Invalid or expired token")
return AuthResult(
success=False,
error_code="invalid_token",
error_message="Invalid or expired token",
)
logger.debug(f"Connection {sid}: Authenticated as user {user_id}")
return AuthResult(
success=True,
user_id=user_id,
)
async def setup_authenticated_session(
self,
sio: object,
sid: str,
user_id: UUID,
namespace: str = "/game",
) -> None:
"""Set up socket session with authenticated user data.
Saves user_id and authentication timestamp to the socket session,
and registers the connection with ConnectionManager.
Args:
sio: Socket.IO AsyncServer instance (required).
sid: Socket session ID.
user_id: Authenticated user's UUID.
namespace: Socket.IO namespace.
"""
session_data = {
"user_id": str(user_id),
"authenticated_at": datetime.now(UTC).isoformat(),
}
await sio.save_session(sid, session_data, namespace=namespace)
await self._connection_manager.register_connection(sid, user_id)
logger.info(f"Session established: sid={sid}, user_id={user_id}")
async def cleanup_authenticated_session(
self,
sid: str,
namespace: str = "/game",
) -> str | None:
"""Clean up session data on disconnect.
Unregisters the connection from ConnectionManager and returns
the user_id for any additional cleanup needed.
Args:
sid: Socket session ID.
namespace: Socket.IO namespace.
Returns:
user_id if session was authenticated, None otherwise.
"""
conn_info = await self._connection_manager.unregister_connection(sid)
if conn_info:
logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}")
return conn_info.user_id
logger.debug(f"No session to clean up for {sid}")
return None
# Global singleton instance
auth_handler = AuthHandler()
def extract_token(auth: dict[str, object] | None) -> str | None:
"""Extract JWT token from Socket.IO auth data.
@ -96,131 +227,6 @@ def extract_token(auth: dict[str, object] | None) -> str | None:
return None
async def authenticate_connection(
sid: str,
auth: dict[str, object] | None,
) -> AuthResult:
"""Authenticate a Socket.IO connection using JWT.
Extracts the JWT from auth data, validates it, and returns the result.
Does NOT modify socket session - caller should handle that.
Args:
sid: Socket session ID (for logging).
auth: Authentication data from connect event.
Returns:
AuthResult with success status and user_id or error details.
Example:
result = await authenticate_connection(sid, auth)
if result.success:
await sio.save_session(sid, {"user_id": str(result.user_id)})
else:
logger.warning(f"Auth failed: {result.error_message}")
return False # Reject connection
"""
# Extract token from auth data
token = extract_token(auth)
if token is None:
logger.debug(f"Connection {sid}: No token provided")
return AuthResult(
success=False,
error_code="missing_token",
error_message="Authentication token required",
)
# Validate the token
user_id = verify_access_token(token)
if user_id is None:
logger.debug(f"Connection {sid}: Invalid or expired token")
return AuthResult(
success=False,
error_code="invalid_token",
error_message="Invalid or expired token",
)
logger.debug(f"Connection {sid}: Authenticated as user {user_id}")
return AuthResult(
success=True,
user_id=user_id,
)
async def setup_authenticated_session(
sio: object,
sid: str,
user_id: UUID,
namespace: str = "/game",
) -> None:
"""Set up socket session with authenticated user data.
Saves user_id and authentication timestamp to the socket session,
and registers the connection with ConnectionManager.
Args:
sio: Socket.IO AsyncServer instance.
sid: Socket session ID.
user_id: Authenticated user's UUID.
namespace: Socket.IO namespace.
Example:
if auth_result.success:
await setup_authenticated_session(sio, sid, auth_result.user_id)
"""
# Import here to avoid circular dependency
from app.socketio.server import sio as server_sio
# Use provided sio or fall back to server sio
socket_server = sio if sio is not None else server_sio
# Save to socket session
session_data = {
"user_id": str(user_id),
"authenticated_at": datetime.now(UTC).isoformat(),
}
await socket_server.save_session(sid, session_data, namespace=namespace)
# Register with ConnectionManager
await connection_manager.register_connection(sid, user_id)
logger.info(f"Session established: sid={sid}, user_id={user_id}")
async def cleanup_authenticated_session(
sid: str,
namespace: str = "/game",
) -> str | None:
"""Clean up session data on disconnect.
Unregisters the connection from ConnectionManager and returns
the user_id for any additional cleanup needed.
Args:
sid: Socket session ID.
namespace: Socket.IO namespace.
Returns:
user_id if session was authenticated, None otherwise.
Example:
user_id = await cleanup_authenticated_session(sid)
if user_id:
# Notify opponent, etc.
"""
# Unregister from ConnectionManager
conn_info = await connection_manager.unregister_connection(sid)
if conn_info:
logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}")
return conn_info.user_id
logger.debug(f"No session to clean up for {sid}")
return None
async def get_session_user_id(
sio: object,
sid: str,

View File

@ -33,12 +33,7 @@ from typing import TYPE_CHECKING
import socketio
from app.config import settings
from app.socketio.auth import (
authenticate_connection,
cleanup_authenticated_session,
require_auth,
setup_authenticated_session,
)
from app.socketio.auth import auth_handler, require_auth
if TYPE_CHECKING:
from fastapi import FastAPI
@ -84,7 +79,7 @@ async def connect(
None is treated as True (accept).
"""
# Authenticate the connection
auth_result = await authenticate_connection(sid, auth)
auth_result = await auth_handler.authenticate_connection(sid, auth)
if not auth_result.success:
logger.warning(
@ -103,7 +98,7 @@ async def connect(
return False
# Set up authenticated session and register connection
await setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game")
await auth_handler.setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game")
logger.info(f"Client authenticated to /game: sid={sid}, user_id={auth_result.user_id}")
return True
@ -119,7 +114,7 @@ async def disconnect(sid: str) -> None:
sid: Socket session ID of disconnecting client.
"""
# Clean up session and get user info
user_id = await cleanup_authenticated_session(sid, namespace="/game")
user_id = await auth_handler.cleanup_authenticated_session(sid, namespace="/game")
if user_id:
logger.info(f"Client disconnected from /game: sid={sid}, user_id={user_id}")

View File

@ -4,19 +4,17 @@ This module tests JWT-based authentication for WebSocket connections,
including token extraction, validation, and session management.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.socketio.auth import (
AuthHandler,
AuthResult,
authenticate_connection,
cleanup_authenticated_session,
extract_token,
get_session_user_id,
require_auth,
setup_authenticated_session,
)
@ -106,95 +104,143 @@ class TestExtractToken:
class TestAuthenticateConnection:
"""Tests for connection authentication."""
"""Tests for connection authentication via AuthHandler."""
@pytest.mark.asyncio
async def test_authenticate_success_with_valid_token(self) -> None:
@pytest.fixture
def mock_token_verifier(self) -> MagicMock:
"""Create a mock token verifier function."""
return MagicMock()
@pytest.fixture
def mock_connection_manager(self) -> AsyncMock:
"""Create a mock ConnectionManager."""
cm = AsyncMock()
cm.register_connection = AsyncMock()
cm.unregister_connection = AsyncMock()
return cm
@pytest.fixture
def auth_handler(
self,
mock_token_verifier: MagicMock,
mock_connection_manager: AsyncMock,
) -> AuthHandler:
"""Create AuthHandler with injected mocks."""
return AuthHandler(
token_verifier=mock_token_verifier,
conn_manager=mock_connection_manager,
)
async def test_authenticate_success_with_valid_token(
self,
auth_handler: AuthHandler,
mock_token_verifier: MagicMock,
) -> None:
"""Test successful authentication with a valid JWT.
A valid access token should result in AuthResult with
success=True and the user_id from the token.
"""
user_id = uuid4()
mock_token_verifier.return_value = user_id
with patch("app.socketio.auth.verify_access_token") as mock_verify:
mock_verify.return_value = user_id
result = await auth_handler.authenticate_connection("test-sid", {"token": "valid-token"})
result = await authenticate_connection("test-sid", {"token": "valid-token"})
assert result.success is True
assert result.user_id == user_id
assert result.error_code is None
mock_token_verifier.assert_called_once_with("valid-token")
assert result.success is True
assert result.user_id == user_id
assert result.error_code is None
mock_verify.assert_called_once_with("valid-token")
@pytest.mark.asyncio
async def test_authenticate_fails_with_missing_token(self) -> None:
async def test_authenticate_fails_with_missing_token(
self,
auth_handler: AuthHandler,
) -> None:
"""Test authentication failure when no token is provided.
Connections without any auth data should fail with
a 'missing_token' error code.
"""
result = await authenticate_connection("test-sid", None)
result = await auth_handler.authenticate_connection("test-sid", None)
assert result.success is False
assert result.user_id is None
assert result.error_code == "missing_token"
assert "required" in result.error_message.lower()
@pytest.mark.asyncio
async def test_authenticate_fails_with_empty_auth(self) -> None:
async def test_authenticate_fails_with_empty_auth(
self,
auth_handler: AuthHandler,
) -> None:
"""Test authentication failure with empty auth object.
An auth object without any token fields should fail.
"""
result = await authenticate_connection("test-sid", {})
result = await auth_handler.authenticate_connection("test-sid", {})
assert result.success is False
assert result.error_code == "missing_token"
@pytest.mark.asyncio
async def test_authenticate_fails_with_invalid_token(self) -> None:
async def test_authenticate_fails_with_invalid_token(
self,
auth_handler: AuthHandler,
mock_token_verifier: MagicMock,
) -> None:
"""Test authentication failure with invalid/expired JWT.
When verify_access_token returns None (invalid token),
we should fail with 'invalid_token' error.
"""
with patch("app.socketio.auth.verify_access_token") as mock_verify:
mock_verify.return_value = None # Token validation failed
mock_token_verifier.return_value = None
result = await authenticate_connection("test-sid", {"token": "invalid-token"})
result = await auth_handler.authenticate_connection("test-sid", {"token": "invalid-token"})
assert result.success is False
assert result.user_id is None
assert result.error_code == "invalid_token"
assert (
"invalid" in result.error_message.lower()
or "expired" in result.error_message.lower()
)
assert result.success is False
assert result.user_id is None
assert result.error_code == "invalid_token"
assert (
"invalid" in result.error_message.lower() or "expired" in result.error_message.lower()
)
@pytest.mark.asyncio
async def test_authenticate_extracts_token_from_bearer(self) -> None:
async def test_authenticate_extracts_token_from_bearer(
self,
auth_handler: AuthHandler,
mock_token_verifier: MagicMock,
) -> None:
"""Test that authentication works with Bearer format.
The authenticate function should handle Bearer token format
through the extract_token function.
"""
user_id = uuid4()
mock_token_verifier.return_value = user_id
with patch("app.socketio.auth.verify_access_token") as mock_verify:
mock_verify.return_value = user_id
result = await auth_handler.authenticate_connection(
"test-sid", {"authorization": "Bearer my-token"}
)
result = await authenticate_connection("test-sid", {"authorization": "Bearer my-token"})
assert result.success is True
mock_verify.assert_called_once_with("my-token")
assert result.success is True
mock_token_verifier.assert_called_once_with("my-token")
class TestSetupAuthenticatedSession:
"""Tests for session setup after authentication."""
@pytest.mark.asyncio
async def test_setup_saves_session_data(self) -> None:
@pytest.fixture
def mock_connection_manager(self) -> AsyncMock:
"""Create a mock ConnectionManager."""
cm = AsyncMock()
cm.register_connection = AsyncMock()
return cm
@pytest.fixture
def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler:
"""Create AuthHandler with injected mock."""
return AuthHandler(conn_manager=mock_connection_manager)
async def test_setup_saves_session_data(
self,
auth_handler: AuthHandler,
) -> None:
"""Test that session setup saves user_id and timestamp.
After authentication, the socket session should contain
@ -203,22 +249,21 @@ class TestSetupAuthenticatedSession:
user_id = uuid4()
mock_sio = AsyncMock()
with patch("app.socketio.auth.connection_manager") as mock_cm:
mock_cm.register_connection = AsyncMock()
await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id)
await setup_authenticated_session(mock_sio, "test-sid", user_id)
mock_sio.save_session.assert_called_once()
call_args = mock_sio.save_session.call_args
assert call_args.args[0] == "test-sid"
# Verify session was saved
mock_sio.save_session.assert_called_once()
call_args = mock_sio.save_session.call_args
assert call_args.args[0] == "test-sid"
session_data = call_args.args[1]
assert session_data["user_id"] == str(user_id)
assert "authenticated_at" in session_data
session_data = call_args.args[1]
assert session_data["user_id"] == str(user_id)
assert "authenticated_at" in session_data
@pytest.mark.asyncio
async def test_setup_registers_with_connection_manager(self) -> None:
async def test_setup_registers_with_connection_manager(
self,
auth_handler: AuthHandler,
mock_connection_manager: AsyncMock,
) -> None:
"""Test that session setup registers with ConnectionManager.
The connection should be tracked in ConnectionManager for
@ -227,53 +272,63 @@ class TestSetupAuthenticatedSession:
user_id = uuid4()
mock_sio = AsyncMock()
with patch("app.socketio.auth.connection_manager") as mock_cm:
mock_cm.register_connection = AsyncMock()
await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id)
await setup_authenticated_session(mock_sio, "test-sid", user_id)
mock_cm.register_connection.assert_called_once_with("test-sid", user_id)
mock_connection_manager.register_connection.assert_called_once_with("test-sid", user_id)
class TestCleanupAuthenticatedSession:
"""Tests for session cleanup on disconnect."""
@pytest.mark.asyncio
async def test_cleanup_unregisters_connection(self) -> None:
@pytest.fixture
def mock_connection_manager(self) -> AsyncMock:
"""Create a mock ConnectionManager."""
return AsyncMock()
@pytest.fixture
def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler:
"""Create AuthHandler with injected mock."""
return AuthHandler(conn_manager=mock_connection_manager)
async def test_cleanup_unregisters_connection(
self,
auth_handler: AuthHandler,
mock_connection_manager: AsyncMock,
) -> None:
"""Test that cleanup unregisters from ConnectionManager.
On disconnect, the connection should be removed from
ConnectionManager to update presence tracking.
"""
with patch("app.socketio.auth.connection_manager") as mock_cm:
mock_conn_info = MagicMock()
mock_conn_info.user_id = "user-123"
mock_cm.unregister_connection = AsyncMock(return_value=mock_conn_info)
mock_conn_info = MagicMock()
mock_conn_info.user_id = "user-123"
mock_connection_manager.unregister_connection = AsyncMock(return_value=mock_conn_info)
result = await cleanup_authenticated_session("test-sid")
result = await auth_handler.cleanup_authenticated_session("test-sid")
assert result == "user-123"
mock_cm.unregister_connection.assert_called_once_with("test-sid")
assert result == "user-123"
mock_connection_manager.unregister_connection.assert_called_once_with("test-sid")
@pytest.mark.asyncio
async def test_cleanup_returns_none_for_unknown_session(self) -> None:
async def test_cleanup_returns_none_for_unknown_session(
self,
auth_handler: AuthHandler,
mock_connection_manager: AsyncMock,
) -> None:
"""Test cleanup returns None for non-existent sessions.
If the connection wasn't registered (e.g., auth failed),
cleanup should return None gracefully.
"""
with patch("app.socketio.auth.connection_manager") as mock_cm:
mock_cm.unregister_connection = AsyncMock(return_value=None)
mock_connection_manager.unregister_connection = AsyncMock(return_value=None)
result = await cleanup_authenticated_session("unknown-sid")
result = await auth_handler.cleanup_authenticated_session("unknown-sid")
assert result is None
assert result is None
class TestGetSessionUserId:
"""Tests for session user_id retrieval."""
@pytest.mark.asyncio
async def test_get_session_user_id_returns_id(self) -> None:
"""Test retrieving user_id from authenticated session.
@ -290,7 +345,6 @@ class TestGetSessionUserId:
assert result == "user-123"
mock_sio.get_session.assert_called_once_with("test-sid", namespace="/game")
@pytest.mark.asyncio
async def test_get_session_user_id_returns_none_for_missing(self) -> None:
"""Test that missing session returns None.
@ -304,7 +358,6 @@ class TestGetSessionUserId:
assert result is None
@pytest.mark.asyncio
async def test_get_session_user_id_handles_exception(self) -> None:
"""Test that exceptions are caught and return None.
@ -322,7 +375,6 @@ class TestGetSessionUserId:
class TestRequireAuth:
"""Tests for the require_auth helper."""
@pytest.mark.asyncio
async def test_require_auth_returns_user_id_for_authenticated(self) -> None:
"""Test that require_auth returns user_id for valid sessions.
@ -336,7 +388,6 @@ class TestRequireAuth:
assert result == "user-123"
@pytest.mark.asyncio
async def test_require_auth_returns_none_for_unauthenticated(self) -> None:
"""Test that require_auth returns None for unauthenticated sessions.

View File

@ -1,12 +1,16 @@
"""Tests for ConnectionManager service.
This module tests WebSocket connection tracking with Redis. Since these are
unit tests, we mock the Redis operations to test the ConnectionManager logic
unit tests, we inject a mock Redis factory to test the ConnectionManager logic
without requiring a real Redis instance.
Uses dependency injection pattern - no monkey patching required.
"""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
@ -21,15 +25,13 @@ from app.services.connection_manager import (
)
@pytest.fixture
def manager() -> ConnectionManager:
"""Create a ConnectionManager instance for testing."""
return ConnectionManager(conn_ttl_seconds=3600)
@pytest.fixture
def mock_redis() -> AsyncMock:
"""Create a mock Redis client."""
"""Create a mock Redis client.
This mock provides all the Redis operations used by ConnectionManager
with sensible defaults that can be overridden per-test.
"""
redis = AsyncMock()
redis.hset = AsyncMock()
redis.hget = AsyncMock()
@ -46,6 +48,21 @@ def mock_redis() -> AsyncMock:
return redis
@pytest.fixture
def manager(mock_redis: AsyncMock) -> ConnectionManager:
"""Create a ConnectionManager with injected mock Redis.
The mock Redis is injected via the redis_factory parameter,
eliminating the need for monkey patching in tests.
"""
@asynccontextmanager
async def mock_redis_factory() -> AsyncIterator[AsyncMock]:
yield mock_redis
return ConnectionManager(conn_ttl_seconds=3600, redis_factory=mock_redis_factory)
class TestConnectionInfoDataclass:
"""Tests for the ConnectionInfo dataclass."""
@ -87,8 +104,8 @@ class TestConnectionInfoDataclass:
def test_is_stale_with_custom_threshold(self) -> None:
"""Test is_stale with a custom threshold.
The threshold can be adjusted for different use cases like
more aggressive cleanup or more lenient timeout.
The threshold parameter allows callers to customize what counts
as stale for different use cases.
"""
now = datetime.now(UTC)
last_seen = now - timedelta(seconds=60)
@ -96,7 +113,7 @@ class TestConnectionInfoDataclass:
sid="test-sid",
user_id="user-123",
game_id=None,
connected_at=now,
connected_at=last_seen,
last_seen=last_seen,
)
@ -125,23 +142,20 @@ class TestRegisterConnection:
sid = "test-sid-123"
user_id = str(uuid4())
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
await manager.register_connection(sid, user_id)
await manager.register_connection(sid, user_id)
# Verify connection hash was created
conn_key = f"{CONN_PREFIX}{sid}"
mock_redis.hset.assert_called()
hset_call = mock_redis.hset.call_args
assert hset_call.args[0] == conn_key
mapping = hset_call.kwargs["mapping"]
assert mapping["user_id"] == user_id
assert mapping["game_id"] == ""
# Verify connection hash was created
conn_key = f"{CONN_PREFIX}{sid}"
mock_redis.hset.assert_called()
hset_call = mock_redis.hset.call_args
assert hset_call.args[0] == conn_key
mapping = hset_call.kwargs["mapping"]
assert mapping["user_id"] == user_id
assert mapping["game_id"] == ""
# Verify user-to-connection mapping was created
user_conn_key = f"{USER_CONN_PREFIX}{user_id}"
mock_redis.set.assert_called_with(user_conn_key, sid)
# Verify user-to-connection mapping was created
user_conn_key = f"{USER_CONN_PREFIX}{user_id}"
mock_redis.set.assert_called_with(user_conn_key, sid)
@pytest.mark.asyncio
async def test_register_connection_replaces_old_connection(
@ -167,15 +181,12 @@ class TestRegisterConnection:
"last_seen": datetime.now(UTC).isoformat(),
}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
await manager.register_connection(new_sid, user_id)
await manager.register_connection(new_sid, user_id)
# Verify old connection was cleaned up (delete called for old conn key)
old_conn_key = f"{CONN_PREFIX}{old_sid}"
delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list]
assert old_conn_key in delete_calls
# Verify old connection was cleaned up (delete called for old conn key)
old_conn_key = f"{CONN_PREFIX}{old_sid}"
delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list]
assert old_conn_key in delete_calls
@pytest.mark.asyncio
async def test_register_connection_accepts_uuid(
@ -191,15 +202,12 @@ class TestRegisterConnection:
sid = "test-sid"
user_uuid = uuid4()
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
await manager.register_connection(sid, user_uuid)
await manager.register_connection(sid, user_uuid)
# Verify user_id was converted to string
hset_call = mock_redis.hset.call_args
mapping = hset_call.kwargs["mapping"]
assert mapping["user_id"] == str(user_uuid)
# Verify user_id was converted to string
hset_call = mock_redis.hset.call_args
mapping = hset_call.kwargs["mapping"]
assert mapping["user_id"] == str(user_uuid)
class TestUnregisterConnection:
@ -213,17 +221,14 @@ class TestUnregisterConnection:
) -> None:
"""Test that unregistering unknown connection returns None.
If the connection doesn't exist, we should return None rather than
raising an error.
Unregistering a non-existent connection should be a no-op
and return None to indicate nothing was cleaned up.
"""
mock_redis.hgetall.return_value = {}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.unregister_connection("unknown-sid")
result = await manager.unregister_connection("unknown-sid")
assert result is None
assert result is None
@pytest.mark.asyncio
async def test_unregister_cleans_up_all_data(
@ -231,44 +236,39 @@ class TestUnregisterConnection:
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that unregistering cleans up all related Redis data.
"""Test that unregistering cleans up all Redis data.
Cleanup should remove:
Unregistration should remove:
1. Connection hash
2. User-to-connection mapping (if it still points to this sid)
3. Game connection set membership
2. User-to-connection mapping
3. Remove from game connection set (if in a game)
"""
sid = "test-sid"
user_id = "user-123"
game_id = "game-456"
now = datetime.now(UTC)
# Mock: connection exists with game
mock_redis.hgetall.return_value = {
"user_id": user_id,
"game_id": game_id,
"connected_at": datetime.now(UTC).isoformat(),
"last_seen": datetime.now(UTC).isoformat(),
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
mock_redis.get.return_value = sid # user mapping points to this sid
mock_redis.get.return_value = sid # User's current connection
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.unregister_connection(sid)
result = await manager.unregister_connection(sid)
assert result is not None
assert result.user_id == user_id
assert result.game_id == game_id
# Verify result
assert result is not None
assert result.sid == sid
assert result.user_id == user_id
assert result.game_id == game_id
# Verify cleanup
mock_redis.srem.assert_called() # Removed from game set
mock_redis.delete.assert_called() # Connection deleted
# Verify cleanup calls
mock_redis.delete.assert_called()
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
class TestGameAssociation:
"""Tests for game join/leave operations."""
"""Tests for game association methods."""
@pytest.mark.asyncio
async def test_join_game_adds_to_game_set(
@ -276,25 +276,22 @@ class TestGameAssociation:
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that joining a game adds the connection to the game's set.
"""Test that joining a game adds connection to game's set.
When a connection joins a game, it should:
1. Update the connection's game_id
2. Add the sid to the game's connection set
The connection should be tracked in the game's connection set
for broadcasting updates to all participants.
"""
sid = "test-sid"
game_id = "game-123"
mock_redis.exists.return_value = True
mock_redis.hget.return_value = "" # Not in a game yet
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.join_game(sid, game_id)
result = await manager.join_game(sid, game_id)
assert result is True
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id)
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
assert result is True
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id)
@pytest.mark.asyncio
async def test_join_game_returns_false_for_unknown_connection(
@ -304,18 +301,13 @@ class TestGameAssociation:
) -> None:
"""Test that joining a game fails for unknown connections.
If the connection doesn't exist, we should return False rather
than creating orphan game association data.
Non-existent connections should not be able to join games.
"""
mock_redis.exists.return_value = False
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.join_game("unknown-sid", "game-123")
result = await manager.join_game("unknown-sid", "game-123")
assert result is False
mock_redis.sadd.assert_not_called()
assert result is False
@pytest.mark.asyncio
async def test_join_game_leaves_previous_game(
@ -325,25 +317,22 @@ class TestGameAssociation:
) -> None:
"""Test that joining a new game leaves the previous game.
When switching games, the connection should be removed from the
old game's connection set before joining the new one.
A connection can only be in one game at a time, so joining
a new game should automatically leave any previous game.
"""
sid = "test-sid"
old_game = "game-old"
new_game = "game-new"
old_game = "old-game"
new_game = "new-game"
mock_redis.exists.return_value = True
mock_redis.hget.return_value = old_game # Currently in old game
mock_redis.hget.return_value = old_game
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
await manager.join_game(sid, new_game)
await manager.join_game(sid, new_game)
# Verify left old game
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid)
# Verify joined new game
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid)
# Should remove from old game
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid)
# Should add to new game
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid)
@pytest.mark.asyncio
async def test_leave_game_removes_from_set(
@ -351,25 +340,21 @@ class TestGameAssociation:
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that leaving a game removes the connection from the set.
"""Test that leaving a game removes connection from set.
Leave should:
1. Remove sid from game's connection set
2. Clear game_id on connection record
The connection should be removed from the game's set and
its game_id should be cleared.
"""
sid = "test-sid"
game_id = "game-123"
mock_redis.hget.return_value = game_id
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.leave_game(sid)
result = await manager.leave_game(sid)
assert result == game_id
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")
assert result == game_id
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")
@pytest.mark.asyncio
async def test_leave_game_returns_none_when_not_in_game(
@ -377,20 +362,16 @@ class TestGameAssociation:
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that leave_game returns None when not in a game.
"""Test that leaving when not in a game returns None.
If the connection isn't associated with any game, we should
return None without making unnecessary Redis calls.
If the connection isn't in a game, leave_game should return
None to indicate no game was left.
"""
mock_redis.hget.return_value = "" # No game_id
mock_redis.hget.return_value = ""
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.leave_game("test-sid")
result = await manager.leave_game("test-sid")
assert result is None
mock_redis.srem.assert_not_called()
assert result is None
class TestHeartbeat:
@ -404,25 +385,18 @@ class TestHeartbeat:
) -> None:
"""Test that heartbeat updates the last_seen timestamp.
Heartbeats keep the connection alive by updating the timestamp
and refreshing TTLs on Redis records.
Heartbeats keep connections alive by updating timestamps
and refreshing TTLs on Redis keys.
"""
sid = "test-sid"
mock_redis.exists.return_value = True
mock_redis.hget.return_value = "user-123"
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.update_heartbeat(sid)
result = await manager.update_heartbeat(sid)
assert result is True
# Verify last_seen was updated
hset_call = mock_redis.hset.call_args
assert hset_call.args[0] == f"{CONN_PREFIX}{sid}"
assert hset_call.args[1] == "last_seen"
# Verify TTL was refreshed
mock_redis.expire.assert_called()
assert result is True
mock_redis.hset.assert_called()
mock_redis.expire.assert_called()
@pytest.mark.asyncio
async def test_update_heartbeat_returns_false_for_unknown(
@ -430,18 +404,16 @@ class TestHeartbeat:
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that heartbeat returns False for unknown connections.
"""Test that heartbeat fails for unknown connections.
If the connection doesn't exist, we shouldn't try to update it.
Unknown connections should not be able to send heartbeats,
returning False to indicate failure.
"""
mock_redis.exists.return_value = False
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.update_heartbeat("unknown-sid")
result = await manager.update_heartbeat("unknown-sid")
assert result is False
assert result is False
class TestQueryMethods:
@ -453,13 +425,13 @@ class TestQueryMethods:
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_connection returns ConnectionInfo for valid sid.
"""Test that get_connection returns ConnectionInfo.
The returned ConnectionInfo should have all fields populated
from the Redis hash.
The returned info should contain all the stored connection data.
"""
sid = "test-sid"
now = datetime.now(UTC)
mock_redis.hgetall.return_value = {
"user_id": "user-123",
"game_id": "game-456",
@ -467,15 +439,12 @@ class TestQueryMethods:
"last_seen": now.isoformat(),
}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.get_connection(sid)
result = await manager.get_connection(sid)
assert result is not None
assert result.sid == sid
assert result.user_id == "user-123"
assert result.game_id == "game-456"
assert result is not None
assert result.sid == sid
assert result.user_id == "user-123"
assert result.game_id == "game-456"
@pytest.mark.asyncio
async def test_get_connection_returns_none_for_unknown(
@ -483,15 +452,15 @@ class TestQueryMethods:
manager: ConnectionManager,
mock_redis: AsyncMock,
) -> None:
"""Test that get_connection returns None for unknown sid."""
"""Test that get_connection returns None for unknown sid.
Non-existent connections should return None rather than raising.
"""
mock_redis.hgetall.return_value = {}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.get_connection("unknown-sid")
result = await manager.get_connection("unknown-sid")
assert result is None
assert result is None
@pytest.mark.asyncio
async def test_is_user_online_returns_true_for_connected_user(
@ -501,7 +470,7 @@ class TestQueryMethods:
) -> None:
"""Test that is_user_online returns True for connected users.
A user is online if they have an active connection that isn't stale.
Users with active, non-stale connections should be considered online.
"""
user_id = "user-123"
now = datetime.now(UTC)
@ -514,12 +483,9 @@ class TestQueryMethods:
"last_seen": now.isoformat(),
}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.is_user_online(user_id)
result = await manager.is_user_online(user_id)
assert result is True
assert result is True
@pytest.mark.asyncio
async def test_is_user_online_returns_false_for_stale_connection(
@ -529,11 +495,11 @@ class TestQueryMethods:
) -> None:
"""Test that is_user_online returns False for stale connections.
Even if a connection record exists, if it's stale (no recent heartbeat),
the user should be considered offline.
Users with stale connections (no recent heartbeat) should be
considered offline even if their connection record exists.
"""
user_id = "user-123"
old_time = datetime.now(UTC) - timedelta(minutes=5)
old_time = datetime.now(UTC) - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4)
mock_redis.get.return_value = "test-sid"
mock_redis.hgetall.return_value = {
@ -543,12 +509,9 @@ class TestQueryMethods:
"last_seen": old_time.isoformat(),
}
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.is_user_online(user_id)
result = await manager.is_user_online(user_id)
assert result is False
assert result is False
@pytest.mark.asyncio
async def test_get_game_connections_returns_all_participants(
@ -564,29 +527,32 @@ class TestQueryMethods:
now = datetime.now(UTC)
mock_redis.smembers.return_value = {"sid-1", "sid-2"}
mock_redis.hgetall.side_effect = [
{
"user_id": "user-1",
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
},
{
"user_id": "user-2",
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
},
]
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
# Use a function to return the right data based on the key queried
def hgetall_by_key(key: str) -> dict[str, str]:
if key == f"{CONN_PREFIX}sid-1":
return {
"user_id": "user-1",
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
elif key == f"{CONN_PREFIX}sid-2":
return {
"user_id": "user-2",
"game_id": game_id,
"connected_at": now.isoformat(),
"last_seen": now.isoformat(),
}
return {}
result = await manager.get_game_connections(game_id)
mock_redis.hgetall.side_effect = hgetall_by_key
assert len(result) == 2
user_ids = {conn.user_id for conn in result}
assert user_ids == {"user-1", "user-2"}
result = await manager.get_game_connections(game_id)
assert len(result) == 2
user_ids = {conn.user_id for conn in result}
assert user_ids == {"user-1", "user-2"}
@pytest.mark.asyncio
async def test_get_opponent_sid_returns_other_player(
@ -628,38 +594,38 @@ class TestQueryMethods:
mock_redis.hgetall.side_effect = hgetall_by_key
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
mock_get_redis.return_value.__aenter__.return_value = mock_redis
result = await manager.get_opponent_sid(game_id, current_user)
result = await manager.get_opponent_sid(game_id, current_user)
assert result == opponent_sid
assert result == opponent_sid
class TestKeyGeneration:
"""Tests for Redis key generation methods."""
def test_conn_key_format(self, manager: ConnectionManager) -> None:
def test_conn_key_format(self) -> None:
"""Test that connection keys have correct format.
Keys should follow the pattern conn:{sid} for easy identification
and pattern matching.
"""
manager = ConnectionManager()
key = manager._conn_key("test-sid-123")
assert key == "conn:test-sid-123"
def test_user_conn_key_format(self, manager: ConnectionManager) -> None:
def test_user_conn_key_format(self) -> None:
"""Test that user connection keys have correct format.
Keys should follow the pattern user_conn:{user_id}.
"""
manager = ConnectionManager()
key = manager._user_conn_key("user-456")
assert key == "user_conn:user-456"
def test_game_conns_key_format(self, manager: ConnectionManager) -> None:
def test_game_conns_key_format(self) -> None:
"""Test that game connections keys have correct format.
Keys should follow the pattern game_conns:{game_id}.
"""
manager = ConnectionManager()
key = manager._game_conns_key("game-789")
assert key == "game_conns:game-789"

View File

@ -7,9 +7,11 @@ The GameService is STATELESS regarding game rules:
- No GameEngine is stored in the service
- Engine is created per-operation using rules from GameState
- Rules come from frontend at game creation, stored in GameState
Uses dependency injection pattern - no monkey patching required.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
@ -53,19 +55,35 @@ def mock_card_service() -> MagicMock:
return MagicMock()
@pytest.fixture
def mock_engine() -> MagicMock:
"""Create a mock GameEngine.
The engine handles game logic - executing actions and checking
win conditions.
"""
engine = MagicMock()
engine.execute_action = AsyncMock(
return_value=ActionResult(success=True, message="Action executed")
)
return engine
@pytest.fixture
def game_service(
mock_state_manager: AsyncMock,
mock_card_service: MagicMock,
mock_engine: MagicMock,
) -> GameService:
"""Create a GameService with mocked dependencies.
"""Create a GameService with injected mock dependencies.
Note: No engine is passed - GameService creates engines per-operation
using rules stored in each game's GameState.
The mock engine factory is injected via the engine_factory parameter,
eliminating the need for monkey patching in tests.
"""
return GameService(
state_manager=mock_state_manager,
card_service=mock_card_service,
engine_factory=lambda game: mock_engine,
)
@ -324,9 +342,8 @@ class TestJoinGame:
class TestExecuteAction:
"""Tests for the execute_action method.
These tests verify action execution through GameService. Since GameService
creates engines per-operation, we patch _create_engine_for_game to return
a mock engine with controlled behavior.
These tests verify action execution through GameService. The mock engine
is injected via the engine_factory parameter in the game_service fixture.
"""
@pytest.mark.asyncio
@ -334,6 +351,7 @@ class TestExecuteAction:
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test successful action execution.
@ -342,19 +360,14 @@ class TestExecuteAction:
the state should be saved to cache.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=True,
message="Attack executed",
state_changes=[{"type": "damage", "amount": 30}],
)
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Attack executed",
state_changes=[{"type": "damage", "amount": 30}],
)
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
action = AttackAction(attack_index=0)
result = await game_service.execute_action("game-123", "player-1", action)
action = AttackAction(attack_index=0)
result = await game_service.execute_action("game-123", "player-1", action)
assert result.success is True
assert result.action_type == "attack"
@ -434,6 +447,7 @@ class TestExecuteAction:
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test that resignation is allowed even when not your turn.
@ -442,24 +456,19 @@ class TestExecuteAction:
by either player.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=True,
message="Player resigned",
win_result=WinResult(
winner_id="player-1",
loser_id="player-2",
end_reason=GameEndReason.RESIGNATION,
reason="Player resigned",
),
)
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Player resigned",
win_result=WinResult(
winner_id="player-1",
loser_id="player-2",
end_reason=GameEndReason.RESIGNATION,
reason="Player resigned",
),
)
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
# player-2 resigns even though it's player-1's turn
result = await game_service.execute_action("game-123", "player-2", ResignAction())
# player-2 resigns even though it's player-1's turn
result = await game_service.execute_action("game-123", "player-2", ResignAction())
assert result.success is True
assert result.game_over is True
@ -470,6 +479,7 @@ class TestExecuteAction:
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test execute_action raises error for invalid actions.
@ -478,19 +488,12 @@ class TestExecuteAction:
InvalidActionError with the reason.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=False,
message="Not enough energy to attack",
)
mock_engine.execute_action.return_value = ActionResult(
success=False,
message="Not enough energy to attack",
)
with (
patch.object(game_service, "_create_engine_for_game", return_value=mock_engine),
pytest.raises(InvalidActionError) as exc_info,
):
with pytest.raises(InvalidActionError) as exc_info:
await game_service.execute_action("game-123", "player-1", AttackAction(attack_index=0))
assert "Not enough energy" in exc_info.value.reason
@ -500,6 +503,7 @@ class TestExecuteAction:
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test execute_action detects game over and persists to DB.
@ -508,58 +512,26 @@ class TestExecuteAction:
to the database for durability.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=True,
message="Final prize taken!",
win_result=WinResult(
winner_id="player-1",
loser_id="player-2",
end_reason=GameEndReason.PRIZES_TAKEN,
reason="All prizes taken",
),
)
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Final prize taken!",
win_result=WinResult(
winner_id="player-1",
loser_id="player-2",
end_reason=GameEndReason.PRIZES_TAKEN,
reason="All prizes taken",
),
)
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
result = await game_service.execute_action(
"game-123", "player-1", AttackAction(attack_index=0)
)
result = await game_service.execute_action(
"game-123", "player-1", AttackAction(attack_index=0)
)
assert result.game_over is True
assert result.winner_id == "player-1"
assert result.end_reason == GameEndReason.PRIZES_TAKEN
mock_state_manager.persist_to_db.assert_called_once()
@pytest.mark.asyncio
async def test_execute_action_uses_game_rules(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
sample_game_state: GameState,
) -> None:
"""Test that execute_action creates engine with game's rules.
The engine should be created using the rules stored in the game
state, not any service-level defaults.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(success=True, message="OK")
)
with patch.object(
game_service, "_create_engine_for_game", return_value=mock_engine
) as mock_create:
await game_service.execute_action("game-123", "player-1", PassAction())
# Verify engine was created with the game state
mock_create.assert_called_once_with(sample_game_state)
class TestResignGame:
"""Tests for the resign_game convenience method."""
@ -569,6 +541,7 @@ class TestResignGame:
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_engine: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test that resign_game is a convenience wrapper for execute_action.
@ -577,23 +550,18 @@ class TestResignGame:
and call execute_action.
"""
mock_state_manager.load_state.return_value = sample_game_state
mock_engine = MagicMock()
mock_engine.execute_action = AsyncMock(
return_value=ActionResult(
success=True,
message="Player resigned",
win_result=WinResult(
winner_id="player-2",
loser_id="player-1",
end_reason=GameEndReason.RESIGNATION,
reason="Player resigned",
),
)
mock_engine.execute_action.return_value = ActionResult(
success=True,
message="Player resigned",
win_result=WinResult(
winner_id="player-2",
loser_id="player-1",
end_reason=GameEndReason.RESIGNATION,
reason="Player resigned",
),
)
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
result = await game_service.resign_game("game-123", "player-1")
result = await game_service.resign_game("game-123", "player-1")
assert result.success is True
assert result.action_type == "resign"
@ -692,31 +660,39 @@ class TestCreateGame:
assert "GS-002" in str(exc_info.value)
class TestCreateEngineForGame:
"""Tests for the _create_engine_for_game method.
class TestDefaultEngineFactory:
"""Tests for the _default_engine_factory method.
This method is responsible for creating a GameEngine configured
with the rules from a specific game's state.
"""
def test_create_engine_uses_game_rules(
def test_default_engine_uses_game_rules(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_card_service: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test that engine is created with the game's rules.
"""Test that default engine is created with the game's rules.
The engine should use the RulesConfig stored in the game state,
not any default configuration.
"""
engine = game_service._create_engine_for_game(sample_game_state)
# Create service without mock engine to test default factory
service = GameService(
state_manager=mock_state_manager,
card_service=mock_card_service,
)
engine = service._default_engine_factory(sample_game_state)
# Engine should have the game's rules
assert engine.rules == sample_game_state.rules
def test_create_engine_with_rng_seed(
def test_default_engine_with_rng_seed(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_card_service: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test that engine uses seeded RNG when game has rng_seed.
@ -724,17 +700,21 @@ class TestCreateEngineForGame:
When a game has an rng_seed set, the engine should use a
deterministic RNG for replay support.
"""
service = GameService(
state_manager=mock_state_manager,
card_service=mock_card_service,
)
sample_game_state.rng_seed = 12345
engine = game_service._create_engine_for_game(sample_game_state)
engine = service._default_engine_factory(sample_game_state)
# Engine should have been created (we can't easily verify seed,
# but we can verify it doesn't error)
# Engine should have been created successfully
assert engine is not None
def test_create_engine_without_rng_seed(
def test_default_engine_without_rng_seed(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_card_service: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test that engine uses secure RNG when no seed is set.
@ -742,15 +722,20 @@ class TestCreateEngineForGame:
Without an rng_seed, the engine should use cryptographically
secure random number generation.
"""
service = GameService(
state_manager=mock_state_manager,
card_service=mock_card_service,
)
sample_game_state.rng_seed = None
engine = game_service._create_engine_for_game(sample_game_state)
engine = service._default_engine_factory(sample_game_state)
assert engine is not None
def test_create_engine_derives_unique_seed_per_action(
def test_default_engine_derives_unique_seed_per_action(
self,
game_service: GameService,
mock_state_manager: AsyncMock,
mock_card_service: MagicMock,
sample_game_state: GameState,
) -> None:
"""Test that different action counts produce different RNG sequences.
@ -758,15 +743,19 @@ class TestCreateEngineForGame:
For deterministic replay, each action needs a unique but
reproducible RNG seed based on game seed + action count.
"""
service = GameService(
state_manager=mock_state_manager,
card_service=mock_card_service,
)
sample_game_state.rng_seed = 12345
# Simulate first action (action_log is empty)
sample_game_state.action_log = []
engine1 = game_service._create_engine_for_game(sample_game_state)
engine1 = service._default_engine_factory(sample_game_state)
# Simulate second action (one action in log)
sample_game_state.action_log = [{"type": "pass"}]
engine2 = game_service._create_engine_for_game(sample_game_state)
engine2 = service._default_engine_factory(sample_game_state)
# Both engines should be created successfully
# (They will have different seeds due to action count)