Refactor to dependency injection pattern - no monkey patching
- ConnectionManager: Add redis_factory constructor parameter - GameService: Add engine_factory constructor parameter - AuthHandler: New class replacing standalone functions with token_verifier and conn_manager injection - Update all tests to use constructor DI instead of patch() - Update CLAUDE.md with factory injection patterns - Update services README with new patterns - Add socketio README documenting AuthHandler and events Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
0c810e5b30
commit
f512c7b2b3
@ -78,11 +78,11 @@ def validate_deck(cards, config: DeckConfig | None = None):
|
||||
|
||||
## Dependency Injection for Services
|
||||
|
||||
**Services use constructor-based dependency injection for repositories and other services.**
|
||||
**Services use constructor-based dependency injection for repositories, services, and external resources.**
|
||||
|
||||
Services still use DI for data access, but config comes from method parameters (request).
|
||||
Services use DI for data access and external dependencies, but config comes from method parameters (request).
|
||||
|
||||
### Required Pattern
|
||||
### Required Pattern: Services and Repositories
|
||||
|
||||
```python
|
||||
class DeckService:
|
||||
@ -109,12 +109,78 @@ class DeckService:
|
||||
...
|
||||
```
|
||||
|
||||
### Required Pattern: Factory Injection for External Resources
|
||||
|
||||
For external resources (Redis, databases, engines), inject factory functions rather than instances:
|
||||
|
||||
```python
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
|
||||
# Type aliases for factories
|
||||
RedisFactory = Callable[[], AsyncIterator["Redis"]]
|
||||
EngineFactory = Callable[[GameState], GameEngine]
|
||||
TokenVerifier = Callable[[str], UUID | None]
|
||||
|
||||
class ConnectionManager:
|
||||
"""External resources injected as factories."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn_ttl_seconds: int = 3600,
|
||||
redis_factory: RedisFactory | None = None, # Factory, not instance
|
||||
) -> None:
|
||||
self.conn_ttl_seconds = conn_ttl_seconds
|
||||
# Fall back to global only if not provided
|
||||
self._get_redis = redis_factory if redis_factory is not None else get_redis
|
||||
|
||||
async def register_connection(self, sid: str, user_id: UUID) -> None:
|
||||
async with self._get_redis() as redis: # Use injected factory
|
||||
await redis.hset(...)
|
||||
```
|
||||
|
||||
```python
|
||||
class GameService:
|
||||
"""Engine created via injected factory for testability."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_manager: GameStateManager | None = None,
|
||||
engine_factory: EngineFactory | None = None, # Factory for engine creation
|
||||
) -> None:
|
||||
self._state_manager = state_manager or game_state_manager
|
||||
self._engine_factory = engine_factory or self._default_engine_factory
|
||||
|
||||
async def execute_action(self, game_id: str, player_id: str, action: Action):
|
||||
state = await self.get_game_state(game_id)
|
||||
engine = self._engine_factory(state) # Create via factory
|
||||
result = await engine.execute_action(state, player_id, action)
|
||||
```
|
||||
|
||||
### Global Singletons with DI Support
|
||||
|
||||
Create global instances for production, but always support injection for testing:
|
||||
|
||||
```python
|
||||
# Global singleton for production use
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
# Tests inject mocks via constructor - no patching needed
|
||||
def test_register_connection(mock_redis):
|
||||
@asynccontextmanager
|
||||
async def mock_redis_factory():
|
||||
yield mock_redis
|
||||
|
||||
manager = ConnectionManager(redis_factory=mock_redis_factory)
|
||||
# Test with injected mock...
|
||||
```
|
||||
|
||||
### Why This Matters
|
||||
|
||||
1. **Testability**: Dependencies can be mocked without patching globals
|
||||
2. **Offline Fork**: Services can be swapped for local implementations
|
||||
3. **Explicit Dependencies**: Constructor shows all requirements
|
||||
4. **Stateless Operations**: Config comes from request, not server state
|
||||
5. **No Monkey Patching**: Tests use constructor injection, not `patch()`
|
||||
|
||||
---
|
||||
|
||||
@ -253,6 +319,55 @@ def test_paralyzed_pokemon_cannot_attack():
|
||||
| `tests/core/` | Core engine tests | No |
|
||||
| `tests/api/` | API endpoint tests | Yes |
|
||||
|
||||
### No Monkey Patching - Use Dependency Injection
|
||||
|
||||
**Never use `patch()` or `monkeypatch` for unit tests.** Instead, inject mocks via constructor.
|
||||
|
||||
```python
|
||||
# WRONG - Monkey patching
|
||||
async def test_execute_action():
|
||||
with patch.object(game_service, "_create_engine") as mock:
|
||||
mock.return_value = mock_engine
|
||||
result = await game_service.execute_action(...)
|
||||
|
||||
# CORRECT - Dependency injection
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
engine = MagicMock()
|
||||
engine.execute_action = AsyncMock(return_value=ActionResult(success=True))
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
def game_service(mock_state_manager, mock_engine):
|
||||
return GameService(
|
||||
state_manager=mock_state_manager,
|
||||
engine_factory=lambda state: mock_engine, # Inject via constructor
|
||||
)
|
||||
|
||||
async def test_execute_action(game_service, mock_engine):
|
||||
result = await game_service.execute_action(...)
|
||||
mock_engine.execute_action.assert_called_once()
|
||||
```
|
||||
|
||||
### Factory Fixtures for External Resources
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
redis = AsyncMock()
|
||||
redis.hset = AsyncMock()
|
||||
redis.hget = AsyncMock(return_value=None)
|
||||
return redis
|
||||
|
||||
@pytest.fixture
|
||||
def connection_manager(mock_redis):
|
||||
@asynccontextmanager
|
||||
async def mock_redis_factory():
|
||||
yield mock_redis
|
||||
|
||||
return ConnectionManager(redis_factory=mock_redis_factory)
|
||||
```
|
||||
|
||||
### Use Seeded RNG for Determinism
|
||||
|
||||
```python
|
||||
@ -276,6 +391,12 @@ app/
|
||||
│ ├── config.py # RulesConfig and sub-configs
|
||||
│ └── engine.py # GameEngine orchestrator
|
||||
├── services/ # Business logic (uses repositories)
|
||||
│ ├── game_service.py # Game orchestration (uses engine_factory DI)
|
||||
│ ├── connection_manager.py # WebSocket tracking (uses redis_factory DI)
|
||||
│ └── ...
|
||||
├── socketio/ # WebSocket real-time communication
|
||||
│ ├── server.py # Socket.IO server setup and handlers
|
||||
│ └── auth.py # AuthHandler class (uses token_verifier DI)
|
||||
├── repositories/ # Data access layer
|
||||
│ ├── protocols.py # Repository protocols (interfaces)
|
||||
│ └── postgres/ # PostgreSQL implementations
|
||||
@ -322,6 +443,9 @@ app/
|
||||
| Hardcoded magic numbers | Use values from config parameter |
|
||||
| Tests without docstrings | Always explain what and why |
|
||||
| Unit tests in `tests/services/` | Use `tests/unit/` for no-DB tests |
|
||||
| Using `patch()` in unit tests | Inject mocks via constructor DI |
|
||||
| `async with get_redis()` in method body | Inject `redis_factory` via constructor |
|
||||
| Importing dependencies inside functions | Import at module level, inject via constructor |
|
||||
|
||||
---
|
||||
|
||||
@ -330,3 +454,5 @@ app/
|
||||
- `app/core/AGENTS.md` - Detailed core engine guidelines
|
||||
- `app/core/README.md` - Core module documentation
|
||||
- `app/core/effects/README.md` - Effect system documentation
|
||||
- `app/services/README.md` - Services layer patterns
|
||||
- `app/socketio/README.md` - WebSocket module documentation
|
||||
|
||||
@ -19,6 +19,8 @@ Business logic layer between API endpoints and data access.
|
||||
| `DeckValidator` | Pure deck validation functions | CardService (for lookups) |
|
||||
| `UserService` | User profile management | Direct DB access |
|
||||
| `GameStateManager` | Game state (Redis + Postgres) | Redis, AsyncSession |
|
||||
| `GameService` | Game orchestration + actions | GameStateManager, CardService, engine_factory |
|
||||
| `ConnectionManager` | WebSocket connection tracking | redis_factory |
|
||||
| `JWTService` | Token creation/verification | Settings |
|
||||
| `TokenStore` | Refresh token storage | Redis |
|
||||
|
||||
@ -93,6 +95,50 @@ Benefits:
|
||||
- **Offline Fork**: Swap PostgresRepository for LocalRepository
|
||||
- **Decoupling**: Services don't know about SQLAlchemy
|
||||
|
||||
### Factory Injection for External Resources
|
||||
|
||||
For external resources (Redis, GameEngine), inject factory functions rather than instances:
|
||||
|
||||
```python
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
|
||||
# Type aliases for factories
|
||||
RedisFactory = Callable[[], AsyncIterator["Redis"]]
|
||||
EngineFactory = Callable[[GameState], GameEngine]
|
||||
|
||||
class ConnectionManager:
|
||||
"""External resources injected as factories."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn_ttl_seconds: int = 3600,
|
||||
redis_factory: RedisFactory | None = None,
|
||||
) -> None:
|
||||
self.conn_ttl_seconds = conn_ttl_seconds
|
||||
self._get_redis = redis_factory if redis_factory is not None else get_redis
|
||||
|
||||
class GameService:
|
||||
"""Engine created via injected factory for testability."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_manager: GameStateManager | None = None,
|
||||
engine_factory: EngineFactory | None = None,
|
||||
) -> None:
|
||||
self._state_manager = state_manager or game_state_manager
|
||||
self._engine_factory = engine_factory or self._default_engine_factory
|
||||
```
|
||||
|
||||
Global singletons are created for production, but constructors always support injection:
|
||||
|
||||
```python
|
||||
# Production singleton
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
# Tests inject mocks - no patching needed
|
||||
manager = ConnectionManager(redis_factory=mock_redis_factory)
|
||||
```
|
||||
|
||||
### Pure Validation Functions
|
||||
|
||||
Validation logic is extracted into pure functions for reuse:
|
||||
@ -183,6 +229,27 @@ await manager.persist_to_db(game_id, game_state)
|
||||
| `tests/unit/services/` | Pure unit tests, mocked deps | No |
|
||||
| `tests/services/` | Integration tests | Yes (testcontainers) |
|
||||
|
||||
**No monkey patching** - Always use constructor DI for mocks:
|
||||
|
||||
```python
|
||||
# WRONG - monkey patching
|
||||
async def test_something():
|
||||
with patch.object(service, "_method") as mock:
|
||||
...
|
||||
|
||||
# CORRECT - constructor injection
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def connection_manager(mock_redis):
|
||||
@asynccontextmanager
|
||||
async def mock_redis_factory():
|
||||
yield mock_redis
|
||||
return ConnectionManager(redis_factory=mock_redis_factory)
|
||||
```
|
||||
|
||||
### Unit Test Example
|
||||
|
||||
```python
|
||||
|
||||
@ -35,14 +35,22 @@ Example:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from app.db.redis import get_redis
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for redis factory - a callable that returns an async context manager
|
||||
RedisFactory = Callable[[], AsyncIterator["Redis"]]
|
||||
|
||||
# Redis key patterns
|
||||
CONN_PREFIX = "conn:"
|
||||
USER_CONN_PREFIX = "user_conn:"
|
||||
@ -98,13 +106,20 @@ class ConnectionManager:
|
||||
conn_ttl_seconds: TTL for connection records in Redis.
|
||||
"""
|
||||
|
||||
def __init__(self, conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
conn_ttl_seconds: int = DEFAULT_CONN_TTL_SECONDS,
|
||||
redis_factory: RedisFactory | None = None,
|
||||
) -> None:
|
||||
"""Initialize the ConnectionManager.
|
||||
|
||||
Args:
|
||||
conn_ttl_seconds: How long to keep connection records in Redis.
|
||||
redis_factory: Optional factory for Redis connections. If not provided,
|
||||
uses the default get_redis from app.db.redis. Useful for testing.
|
||||
"""
|
||||
self.conn_ttl_seconds = conn_ttl_seconds
|
||||
self._get_redis = redis_factory if redis_factory is not None else get_redis
|
||||
|
||||
def _conn_key(self, sid: str) -> str:
|
||||
"""Generate Redis key for a connection."""
|
||||
@ -142,7 +157,7 @@ class ConnectionManager:
|
||||
user_id_str = str(user_id)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
# Check for existing connection and clean it up
|
||||
old_sid = await redis.get(self._user_conn_key(user_id_str))
|
||||
if old_sid and old_sid != sid:
|
||||
@ -204,7 +219,7 @@ class ConnectionManager:
|
||||
Args:
|
||||
sid: Socket.IO session ID.
|
||||
"""
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
conn_key = self._conn_key(sid)
|
||||
|
||||
# Get connection data for cleanup
|
||||
@ -250,7 +265,7 @@ class ConnectionManager:
|
||||
Example:
|
||||
await manager.join_game("abc123", "game-456")
|
||||
"""
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
conn_key = self._conn_key(sid)
|
||||
|
||||
# Check connection exists
|
||||
@ -291,7 +306,7 @@ class ConnectionManager:
|
||||
Example:
|
||||
game_id = await manager.leave_game("abc123")
|
||||
"""
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
conn_key = self._conn_key(sid)
|
||||
|
||||
# Get current game
|
||||
@ -330,7 +345,7 @@ class ConnectionManager:
|
||||
"""
|
||||
now = datetime.now(UTC).isoformat()
|
||||
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
conn_key = self._conn_key(sid)
|
||||
|
||||
# Check exists and update atomically
|
||||
@ -370,7 +385,7 @@ class ConnectionManager:
|
||||
if info:
|
||||
print(f"User {info.user_id} connected at {info.connected_at}")
|
||||
"""
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
conn_key = self._conn_key(sid)
|
||||
data = await redis.hgetall(conn_key)
|
||||
|
||||
@ -401,7 +416,7 @@ class ConnectionManager:
|
||||
"""
|
||||
user_id_str = str(user_id)
|
||||
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
user_conn_key = self._user_conn_key(user_id_str)
|
||||
sid = await redis.get(user_conn_key)
|
||||
|
||||
@ -448,7 +463,7 @@ class ConnectionManager:
|
||||
for conn in connections:
|
||||
# Send state update to each participant
|
||||
"""
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
game_conns_key = self._game_conns_key(game_id)
|
||||
sids = await redis.smembers(game_conns_key)
|
||||
|
||||
@ -530,7 +545,7 @@ class ConnectionManager:
|
||||
logger.info(f"Cleaned up {count} stale connections")
|
||||
"""
|
||||
count = 0
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
# Scan for all connection keys
|
||||
async for key in redis.scan_iter(match=f"{CONN_PREFIX}*"):
|
||||
sid = key[len(CONN_PREFIX) :]
|
||||
@ -555,7 +570,7 @@ class ConnectionManager:
|
||||
Number of active connections.
|
||||
"""
|
||||
count = 0
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
async for _ in redis.scan_iter(match=f"{CONN_PREFIX}*"):
|
||||
count += 1
|
||||
return count
|
||||
@ -569,7 +584,7 @@ class ConnectionManager:
|
||||
Returns:
|
||||
Number of connected participants.
|
||||
"""
|
||||
async with get_redis() as redis:
|
||||
async with self._get_redis() as redis:
|
||||
game_conns_key = self._game_conns_key(game_id)
|
||||
return await redis.scard(game_conns_key)
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ Example:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
@ -51,6 +52,9 @@ from app.services.game_state_manager import GameStateManager, game_state_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for engine factory - takes GameState, returns GameEngine
|
||||
EngineFactory = Callable[[GameState], GameEngine]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exceptions
|
||||
@ -180,12 +184,14 @@ class GameService:
|
||||
Attributes:
|
||||
_state_manager: GameStateManager for persistence.
|
||||
_card_service: CardService for card definitions.
|
||||
_engine_factory: Factory function to create GameEngine instances.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_manager: GameStateManager | None = None,
|
||||
card_service: CardService | None = None,
|
||||
engine_factory: EngineFactory | None = None,
|
||||
) -> None:
|
||||
"""Initialize the GameService.
|
||||
|
||||
@ -196,12 +202,16 @@ class GameService:
|
||||
Args:
|
||||
state_manager: GameStateManager instance. Uses global if not provided.
|
||||
card_service: CardService instance. Uses global if not provided.
|
||||
engine_factory: Optional factory for creating GameEngine instances.
|
||||
If not provided, uses the default _default_engine_factory method.
|
||||
Useful for testing with mock engines.
|
||||
"""
|
||||
self._state_manager = state_manager or game_state_manager
|
||||
self._card_service = card_service or get_card_service()
|
||||
self._engine_factory = engine_factory or self._default_engine_factory
|
||||
|
||||
def _create_engine_for_game(self, game: GameState) -> GameEngine:
|
||||
"""Create a GameEngine configured for a specific game's rules.
|
||||
def _default_engine_factory(self, game: GameState) -> GameEngine:
|
||||
"""Default factory for creating a GameEngine from game state.
|
||||
|
||||
The engine is created on-demand using the rules stored in the
|
||||
game state. This ensures each game uses its own configuration.
|
||||
@ -412,8 +422,8 @@ class GameService:
|
||||
if not isinstance(action, ResignAction) and state.current_player_id != player_id:
|
||||
raise NotPlayerTurnError(game_id, player_id, state.current_player_id)
|
||||
|
||||
# Create engine with this game's rules
|
||||
engine = self._create_engine_for_game(state)
|
||||
# Create engine with this game's rules via factory
|
||||
engine = self._engine_factory(state)
|
||||
|
||||
# Execute the action
|
||||
result: ActionResult = await engine.execute_action(state, player_id, action)
|
||||
|
||||
120
backend/app/socketio/README.md
Normal file
120
backend/app/socketio/README.md
Normal file
@ -0,0 +1,120 @@
|
||||
# Socket.IO Module
|
||||
|
||||
Real-time WebSocket communication for active game sessions.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Client <--WebSocket--> Socket.IO Server <--> GameService <--> GameEngine
|
||||
|
|
||||
v
|
||||
ConnectionManager <--> Redis (presence tracking)
|
||||
```
|
||||
|
||||
## Components
|
||||
|
||||
| Component | Responsibility | DI Pattern |
|
||||
|-----------|---------------|------------|
|
||||
| `server.py` | Socket.IO server setup, event handlers | Uses `auth_handler` |
|
||||
| `auth.py` | `AuthHandler` class for JWT authentication | `token_verifier`, `conn_manager` |
|
||||
|
||||
## AuthHandler
|
||||
|
||||
Handles JWT-based authentication with injectable dependencies:
|
||||
|
||||
```python
|
||||
class AuthHandler:
|
||||
def __init__(
|
||||
self,
|
||||
token_verifier: TokenVerifier | None = None, # Callable[[str], UUID | None]
|
||||
conn_manager: ConnectionManager | None = None,
|
||||
) -> None:
|
||||
self._token_verifier = token_verifier or verify_access_token
|
||||
self._connection_manager = conn_manager or connection_manager
|
||||
|
||||
# Global singleton for production
|
||||
auth_handler = AuthHandler()
|
||||
|
||||
# Tests inject mocks via constructor
|
||||
handler = AuthHandler(
|
||||
token_verifier=lambda token: user_id,
|
||||
conn_manager=mock_connection_manager,
|
||||
)
|
||||
```
|
||||
|
||||
### Methods
|
||||
|
||||
| Method | Purpose |
|
||||
|--------|---------|
|
||||
| `authenticate_connection(sid, auth)` | Validate JWT, return `AuthResult` |
|
||||
| `setup_authenticated_session(sio, sid, user_id)` | Save session, register connection |
|
||||
| `cleanup_authenticated_session(sid)` | Unregister connection, return user_id |
|
||||
|
||||
## Event Handlers
|
||||
|
||||
All handlers are in the `/game` namespace:
|
||||
|
||||
| Event | Direction | Purpose |
|
||||
|-------|-----------|---------|
|
||||
| `connect` | Client→Server | Authenticate and establish session |
|
||||
| `disconnect` | Client→Server | Clean up session and notify opponent |
|
||||
| `game:join` | Client→Server | Join/rejoin a game session |
|
||||
| `game:action` | Client→Server | Execute a game action |
|
||||
| `game:resign` | Client→Server | Resign from game |
|
||||
| `game:heartbeat` | Client→Server | Keep connection alive |
|
||||
|
||||
## Client Connection
|
||||
|
||||
```javascript
|
||||
const socket = io("wss://api.example.com", {
|
||||
auth: { token: "JWT_ACCESS_TOKEN" }
|
||||
});
|
||||
|
||||
socket.on("connect", () => {
|
||||
socket.emit("game:join", { game_id: "uuid" });
|
||||
});
|
||||
|
||||
socket.on("game:state", (state) => {
|
||||
// Full game state update
|
||||
});
|
||||
|
||||
socket.on("auth_error", (error) => {
|
||||
// { code: "invalid_token", message: "..." }
|
||||
});
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Use dependency injection - no monkey patching:
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def mock_token_verifier():
|
||||
return MagicMock(return_value=uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection_manager():
|
||||
cm = AsyncMock()
|
||||
cm.register_connection = AsyncMock()
|
||||
return cm
|
||||
|
||||
@pytest.fixture
|
||||
def auth_handler(mock_token_verifier, mock_connection_manager):
|
||||
return AuthHandler(
|
||||
token_verifier=mock_token_verifier,
|
||||
conn_manager=mock_connection_manager,
|
||||
)
|
||||
|
||||
async def test_authenticate_success(auth_handler, mock_token_verifier):
|
||||
"""Test successful authentication with valid JWT."""
|
||||
result = await auth_handler.authenticate_connection("sid", {"token": "valid"})
|
||||
assert result.success
|
||||
mock_token_verifier.assert_called_once_with("valid")
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- `app/services/connection_manager.py` - Connection presence tracking
|
||||
- `app/services/game_service.py` - Game orchestration
|
||||
- `app/schemas/ws_messages.py` - WebSocket message schemas
|
||||
- `CLAUDE.md` - Architecture guidelines
|
||||
@ -23,12 +23,12 @@ Usage:
|
||||
"""
|
||||
|
||||
from app.socketio.auth import (
|
||||
AuthHandler,
|
||||
AuthResult,
|
||||
authenticate_connection,
|
||||
cleanup_authenticated_session,
|
||||
auth_handler,
|
||||
extract_token,
|
||||
get_session_user_id,
|
||||
require_auth,
|
||||
setup_authenticated_session,
|
||||
)
|
||||
from app.socketio.server import create_socketio_app, sio
|
||||
|
||||
@ -37,10 +37,10 @@ __all__ = [
|
||||
"create_socketio_app",
|
||||
"sio",
|
||||
# Auth
|
||||
"AuthHandler",
|
||||
"AuthResult",
|
||||
"authenticate_connection",
|
||||
"cleanup_authenticated_session",
|
||||
"auth_handler",
|
||||
"extract_token",
|
||||
"get_session_user_id",
|
||||
"require_auth",
|
||||
"setup_authenticated_session",
|
||||
]
|
||||
|
||||
@ -16,7 +16,7 @@ Session Data:
|
||||
|
||||
Example:
|
||||
# In connect handler:
|
||||
auth_result = await authenticate_connection(sid, auth)
|
||||
auth_result = await auth_handler.authenticate_connection(sid, auth)
|
||||
if not auth_result.success:
|
||||
return False # Reject connection
|
||||
|
||||
@ -26,15 +26,19 @@ Example:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.connection_manager import connection_manager
|
||||
from app.services.connection_manager import ConnectionManager, connection_manager
|
||||
from app.services.jwt_service import verify_access_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for token verifier - takes token string, returns user_id or None
|
||||
TokenVerifier = Callable[[str], UUID | None]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
@ -53,6 +57,133 @@ class AuthResult:
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
"""Handler for Socket.IO authentication with injectable dependencies.
|
||||
|
||||
This class encapsulates authentication logic with dependencies injected
|
||||
via constructor, making it testable without monkey patching.
|
||||
|
||||
Attributes:
|
||||
_token_verifier: Function to verify JWT tokens.
|
||||
_connection_manager: ConnectionManager for tracking connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_verifier: TokenVerifier | None = None,
|
||||
conn_manager: ConnectionManager | None = None,
|
||||
) -> None:
|
||||
"""Initialize AuthHandler with dependencies.
|
||||
|
||||
Args:
|
||||
token_verifier: Function to verify tokens. Uses jwt_service if not provided.
|
||||
conn_manager: ConnectionManager instance. Uses global if not provided.
|
||||
"""
|
||||
self._token_verifier = token_verifier or verify_access_token
|
||||
self._connection_manager = conn_manager or connection_manager
|
||||
|
||||
async def authenticate_connection(
|
||||
self,
|
||||
sid: str,
|
||||
auth: dict[str, object] | None,
|
||||
) -> AuthResult:
|
||||
"""Authenticate a Socket.IO connection using JWT.
|
||||
|
||||
Extracts the JWT from auth data, validates it, and returns the result.
|
||||
Does NOT modify socket session - caller should handle that.
|
||||
|
||||
Args:
|
||||
sid: Socket session ID (for logging).
|
||||
auth: Authentication data from connect event.
|
||||
|
||||
Returns:
|
||||
AuthResult with success status and user_id or error details.
|
||||
"""
|
||||
token = extract_token(auth)
|
||||
|
||||
if token is None:
|
||||
logger.debug(f"Connection {sid}: No token provided")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_code="missing_token",
|
||||
error_message="Authentication token required",
|
||||
)
|
||||
|
||||
user_id = self._token_verifier(token)
|
||||
|
||||
if user_id is None:
|
||||
logger.debug(f"Connection {sid}: Invalid or expired token")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_code="invalid_token",
|
||||
error_message="Invalid or expired token",
|
||||
)
|
||||
|
||||
logger.debug(f"Connection {sid}: Authenticated as user {user_id}")
|
||||
return AuthResult(
|
||||
success=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
async def setup_authenticated_session(
|
||||
self,
|
||||
sio: object,
|
||||
sid: str,
|
||||
user_id: UUID,
|
||||
namespace: str = "/game",
|
||||
) -> None:
|
||||
"""Set up socket session with authenticated user data.
|
||||
|
||||
Saves user_id and authentication timestamp to the socket session,
|
||||
and registers the connection with ConnectionManager.
|
||||
|
||||
Args:
|
||||
sio: Socket.IO AsyncServer instance (required).
|
||||
sid: Socket session ID.
|
||||
user_id: Authenticated user's UUID.
|
||||
namespace: Socket.IO namespace.
|
||||
"""
|
||||
session_data = {
|
||||
"user_id": str(user_id),
|
||||
"authenticated_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
await sio.save_session(sid, session_data, namespace=namespace)
|
||||
|
||||
await self._connection_manager.register_connection(sid, user_id)
|
||||
|
||||
logger.info(f"Session established: sid={sid}, user_id={user_id}")
|
||||
|
||||
async def cleanup_authenticated_session(
|
||||
self,
|
||||
sid: str,
|
||||
namespace: str = "/game",
|
||||
) -> str | None:
|
||||
"""Clean up session data on disconnect.
|
||||
|
||||
Unregisters the connection from ConnectionManager and returns
|
||||
the user_id for any additional cleanup needed.
|
||||
|
||||
Args:
|
||||
sid: Socket session ID.
|
||||
namespace: Socket.IO namespace.
|
||||
|
||||
Returns:
|
||||
user_id if session was authenticated, None otherwise.
|
||||
"""
|
||||
conn_info = await self._connection_manager.unregister_connection(sid)
|
||||
|
||||
if conn_info:
|
||||
logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}")
|
||||
return conn_info.user_id
|
||||
|
||||
logger.debug(f"No session to clean up for {sid}")
|
||||
return None
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
auth_handler = AuthHandler()
|
||||
|
||||
|
||||
def extract_token(auth: dict[str, object] | None) -> str | None:
|
||||
"""Extract JWT token from Socket.IO auth data.
|
||||
|
||||
@ -96,131 +227,6 @@ def extract_token(auth: dict[str, object] | None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
async def authenticate_connection(
|
||||
sid: str,
|
||||
auth: dict[str, object] | None,
|
||||
) -> AuthResult:
|
||||
"""Authenticate a Socket.IO connection using JWT.
|
||||
|
||||
Extracts the JWT from auth data, validates it, and returns the result.
|
||||
Does NOT modify socket session - caller should handle that.
|
||||
|
||||
Args:
|
||||
sid: Socket session ID (for logging).
|
||||
auth: Authentication data from connect event.
|
||||
|
||||
Returns:
|
||||
AuthResult with success status and user_id or error details.
|
||||
|
||||
Example:
|
||||
result = await authenticate_connection(sid, auth)
|
||||
if result.success:
|
||||
await sio.save_session(sid, {"user_id": str(result.user_id)})
|
||||
else:
|
||||
logger.warning(f"Auth failed: {result.error_message}")
|
||||
return False # Reject connection
|
||||
"""
|
||||
# Extract token from auth data
|
||||
token = extract_token(auth)
|
||||
|
||||
if token is None:
|
||||
logger.debug(f"Connection {sid}: No token provided")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_code="missing_token",
|
||||
error_message="Authentication token required",
|
||||
)
|
||||
|
||||
# Validate the token
|
||||
user_id = verify_access_token(token)
|
||||
|
||||
if user_id is None:
|
||||
logger.debug(f"Connection {sid}: Invalid or expired token")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_code="invalid_token",
|
||||
error_message="Invalid or expired token",
|
||||
)
|
||||
|
||||
logger.debug(f"Connection {sid}: Authenticated as user {user_id}")
|
||||
return AuthResult(
|
||||
success=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
async def setup_authenticated_session(
|
||||
sio: object,
|
||||
sid: str,
|
||||
user_id: UUID,
|
||||
namespace: str = "/game",
|
||||
) -> None:
|
||||
"""Set up socket session with authenticated user data.
|
||||
|
||||
Saves user_id and authentication timestamp to the socket session,
|
||||
and registers the connection with ConnectionManager.
|
||||
|
||||
Args:
|
||||
sio: Socket.IO AsyncServer instance.
|
||||
sid: Socket session ID.
|
||||
user_id: Authenticated user's UUID.
|
||||
namespace: Socket.IO namespace.
|
||||
|
||||
Example:
|
||||
if auth_result.success:
|
||||
await setup_authenticated_session(sio, sid, auth_result.user_id)
|
||||
"""
|
||||
# Import here to avoid circular dependency
|
||||
from app.socketio.server import sio as server_sio
|
||||
|
||||
# Use provided sio or fall back to server sio
|
||||
socket_server = sio if sio is not None else server_sio
|
||||
|
||||
# Save to socket session
|
||||
session_data = {
|
||||
"user_id": str(user_id),
|
||||
"authenticated_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
await socket_server.save_session(sid, session_data, namespace=namespace)
|
||||
|
||||
# Register with ConnectionManager
|
||||
await connection_manager.register_connection(sid, user_id)
|
||||
|
||||
logger.info(f"Session established: sid={sid}, user_id={user_id}")
|
||||
|
||||
|
||||
async def cleanup_authenticated_session(
|
||||
sid: str,
|
||||
namespace: str = "/game",
|
||||
) -> str | None:
|
||||
"""Clean up session data on disconnect.
|
||||
|
||||
Unregisters the connection from ConnectionManager and returns
|
||||
the user_id for any additional cleanup needed.
|
||||
|
||||
Args:
|
||||
sid: Socket session ID.
|
||||
namespace: Socket.IO namespace.
|
||||
|
||||
Returns:
|
||||
user_id if session was authenticated, None otherwise.
|
||||
|
||||
Example:
|
||||
user_id = await cleanup_authenticated_session(sid)
|
||||
if user_id:
|
||||
# Notify opponent, etc.
|
||||
"""
|
||||
# Unregister from ConnectionManager
|
||||
conn_info = await connection_manager.unregister_connection(sid)
|
||||
|
||||
if conn_info:
|
||||
logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}")
|
||||
return conn_info.user_id
|
||||
|
||||
logger.debug(f"No session to clean up for {sid}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_session_user_id(
|
||||
sio: object,
|
||||
sid: str,
|
||||
|
||||
@ -33,12 +33,7 @@ from typing import TYPE_CHECKING
|
||||
import socketio
|
||||
|
||||
from app.config import settings
|
||||
from app.socketio.auth import (
|
||||
authenticate_connection,
|
||||
cleanup_authenticated_session,
|
||||
require_auth,
|
||||
setup_authenticated_session,
|
||||
)
|
||||
from app.socketio.auth import auth_handler, require_auth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
@ -84,7 +79,7 @@ async def connect(
|
||||
None is treated as True (accept).
|
||||
"""
|
||||
# Authenticate the connection
|
||||
auth_result = await authenticate_connection(sid, auth)
|
||||
auth_result = await auth_handler.authenticate_connection(sid, auth)
|
||||
|
||||
if not auth_result.success:
|
||||
logger.warning(
|
||||
@ -103,7 +98,7 @@ async def connect(
|
||||
return False
|
||||
|
||||
# Set up authenticated session and register connection
|
||||
await setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game")
|
||||
await auth_handler.setup_authenticated_session(sio, sid, auth_result.user_id, namespace="/game")
|
||||
|
||||
logger.info(f"Client authenticated to /game: sid={sid}, user_id={auth_result.user_id}")
|
||||
return True
|
||||
@ -119,7 +114,7 @@ async def disconnect(sid: str) -> None:
|
||||
sid: Socket session ID of disconnecting client.
|
||||
"""
|
||||
# Clean up session and get user info
|
||||
user_id = await cleanup_authenticated_session(sid, namespace="/game")
|
||||
user_id = await auth_handler.cleanup_authenticated_session(sid, namespace="/game")
|
||||
|
||||
if user_id:
|
||||
logger.info(f"Client disconnected from /game: sid={sid}, user_id={user_id}")
|
||||
|
||||
@ -4,19 +4,17 @@ This module tests JWT-based authentication for WebSocket connections,
|
||||
including token extraction, validation, and session management.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.socketio.auth import (
|
||||
AuthHandler,
|
||||
AuthResult,
|
||||
authenticate_connection,
|
||||
cleanup_authenticated_session,
|
||||
extract_token,
|
||||
get_session_user_id,
|
||||
require_auth,
|
||||
setup_authenticated_session,
|
||||
)
|
||||
|
||||
|
||||
@ -106,95 +104,143 @@ class TestExtractToken:
|
||||
|
||||
|
||||
class TestAuthenticateConnection:
|
||||
"""Tests for connection authentication."""
|
||||
"""Tests for connection authentication via AuthHandler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_success_with_valid_token(self) -> None:
|
||||
@pytest.fixture
|
||||
def mock_token_verifier(self) -> MagicMock:
|
||||
"""Create a mock token verifier function."""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection_manager(self) -> AsyncMock:
|
||||
"""Create a mock ConnectionManager."""
|
||||
cm = AsyncMock()
|
||||
cm.register_connection = AsyncMock()
|
||||
cm.unregister_connection = AsyncMock()
|
||||
return cm
|
||||
|
||||
@pytest.fixture
|
||||
def auth_handler(
|
||||
self,
|
||||
mock_token_verifier: MagicMock,
|
||||
mock_connection_manager: AsyncMock,
|
||||
) -> AuthHandler:
|
||||
"""Create AuthHandler with injected mocks."""
|
||||
return AuthHandler(
|
||||
token_verifier=mock_token_verifier,
|
||||
conn_manager=mock_connection_manager,
|
||||
)
|
||||
|
||||
async def test_authenticate_success_with_valid_token(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
mock_token_verifier: MagicMock,
|
||||
) -> None:
|
||||
"""Test successful authentication with a valid JWT.
|
||||
|
||||
A valid access token should result in AuthResult with
|
||||
success=True and the user_id from the token.
|
||||
"""
|
||||
user_id = uuid4()
|
||||
mock_token_verifier.return_value = user_id
|
||||
|
||||
with patch("app.socketio.auth.verify_access_token") as mock_verify:
|
||||
mock_verify.return_value = user_id
|
||||
result = await auth_handler.authenticate_connection("test-sid", {"token": "valid-token"})
|
||||
|
||||
result = await authenticate_connection("test-sid", {"token": "valid-token"})
|
||||
assert result.success is True
|
||||
assert result.user_id == user_id
|
||||
assert result.error_code is None
|
||||
mock_token_verifier.assert_called_once_with("valid-token")
|
||||
|
||||
assert result.success is True
|
||||
assert result.user_id == user_id
|
||||
assert result.error_code is None
|
||||
mock_verify.assert_called_once_with("valid-token")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_fails_with_missing_token(self) -> None:
|
||||
async def test_authenticate_fails_with_missing_token(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
) -> None:
|
||||
"""Test authentication failure when no token is provided.
|
||||
|
||||
Connections without any auth data should fail with
|
||||
a 'missing_token' error code.
|
||||
"""
|
||||
result = await authenticate_connection("test-sid", None)
|
||||
result = await auth_handler.authenticate_connection("test-sid", None)
|
||||
|
||||
assert result.success is False
|
||||
assert result.user_id is None
|
||||
assert result.error_code == "missing_token"
|
||||
assert "required" in result.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_fails_with_empty_auth(self) -> None:
|
||||
async def test_authenticate_fails_with_empty_auth(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
) -> None:
|
||||
"""Test authentication failure with empty auth object.
|
||||
|
||||
An auth object without any token fields should fail.
|
||||
"""
|
||||
result = await authenticate_connection("test-sid", {})
|
||||
result = await auth_handler.authenticate_connection("test-sid", {})
|
||||
|
||||
assert result.success is False
|
||||
assert result.error_code == "missing_token"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_fails_with_invalid_token(self) -> None:
|
||||
async def test_authenticate_fails_with_invalid_token(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
mock_token_verifier: MagicMock,
|
||||
) -> None:
|
||||
"""Test authentication failure with invalid/expired JWT.
|
||||
|
||||
When verify_access_token returns None (invalid token),
|
||||
we should fail with 'invalid_token' error.
|
||||
"""
|
||||
with patch("app.socketio.auth.verify_access_token") as mock_verify:
|
||||
mock_verify.return_value = None # Token validation failed
|
||||
mock_token_verifier.return_value = None
|
||||
|
||||
result = await authenticate_connection("test-sid", {"token": "invalid-token"})
|
||||
result = await auth_handler.authenticate_connection("test-sid", {"token": "invalid-token"})
|
||||
|
||||
assert result.success is False
|
||||
assert result.user_id is None
|
||||
assert result.error_code == "invalid_token"
|
||||
assert (
|
||||
"invalid" in result.error_message.lower()
|
||||
or "expired" in result.error_message.lower()
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.user_id is None
|
||||
assert result.error_code == "invalid_token"
|
||||
assert (
|
||||
"invalid" in result.error_message.lower() or "expired" in result.error_message.lower()
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_extracts_token_from_bearer(self) -> None:
|
||||
async def test_authenticate_extracts_token_from_bearer(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
mock_token_verifier: MagicMock,
|
||||
) -> None:
|
||||
"""Test that authentication works with Bearer format.
|
||||
|
||||
The authenticate function should handle Bearer token format
|
||||
through the extract_token function.
|
||||
"""
|
||||
user_id = uuid4()
|
||||
mock_token_verifier.return_value = user_id
|
||||
|
||||
with patch("app.socketio.auth.verify_access_token") as mock_verify:
|
||||
mock_verify.return_value = user_id
|
||||
result = await auth_handler.authenticate_connection(
|
||||
"test-sid", {"authorization": "Bearer my-token"}
|
||||
)
|
||||
|
||||
result = await authenticate_connection("test-sid", {"authorization": "Bearer my-token"})
|
||||
|
||||
assert result.success is True
|
||||
mock_verify.assert_called_once_with("my-token")
|
||||
assert result.success is True
|
||||
mock_token_verifier.assert_called_once_with("my-token")
|
||||
|
||||
|
||||
class TestSetupAuthenticatedSession:
|
||||
"""Tests for session setup after authentication."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_saves_session_data(self) -> None:
|
||||
@pytest.fixture
|
||||
def mock_connection_manager(self) -> AsyncMock:
|
||||
"""Create a mock ConnectionManager."""
|
||||
cm = AsyncMock()
|
||||
cm.register_connection = AsyncMock()
|
||||
return cm
|
||||
|
||||
@pytest.fixture
|
||||
def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler:
|
||||
"""Create AuthHandler with injected mock."""
|
||||
return AuthHandler(conn_manager=mock_connection_manager)
|
||||
|
||||
async def test_setup_saves_session_data(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
) -> None:
|
||||
"""Test that session setup saves user_id and timestamp.
|
||||
|
||||
After authentication, the socket session should contain
|
||||
@ -203,22 +249,21 @@ class TestSetupAuthenticatedSession:
|
||||
user_id = uuid4()
|
||||
mock_sio = AsyncMock()
|
||||
|
||||
with patch("app.socketio.auth.connection_manager") as mock_cm:
|
||||
mock_cm.register_connection = AsyncMock()
|
||||
await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id)
|
||||
|
||||
await setup_authenticated_session(mock_sio, "test-sid", user_id)
|
||||
mock_sio.save_session.assert_called_once()
|
||||
call_args = mock_sio.save_session.call_args
|
||||
assert call_args.args[0] == "test-sid"
|
||||
|
||||
# Verify session was saved
|
||||
mock_sio.save_session.assert_called_once()
|
||||
call_args = mock_sio.save_session.call_args
|
||||
assert call_args.args[0] == "test-sid"
|
||||
session_data = call_args.args[1]
|
||||
assert session_data["user_id"] == str(user_id)
|
||||
assert "authenticated_at" in session_data
|
||||
|
||||
session_data = call_args.args[1]
|
||||
assert session_data["user_id"] == str(user_id)
|
||||
assert "authenticated_at" in session_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_registers_with_connection_manager(self) -> None:
|
||||
async def test_setup_registers_with_connection_manager(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
mock_connection_manager: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that session setup registers with ConnectionManager.
|
||||
|
||||
The connection should be tracked in ConnectionManager for
|
||||
@ -227,53 +272,63 @@ class TestSetupAuthenticatedSession:
|
||||
user_id = uuid4()
|
||||
mock_sio = AsyncMock()
|
||||
|
||||
with patch("app.socketio.auth.connection_manager") as mock_cm:
|
||||
mock_cm.register_connection = AsyncMock()
|
||||
await auth_handler.setup_authenticated_session(mock_sio, "test-sid", user_id)
|
||||
|
||||
await setup_authenticated_session(mock_sio, "test-sid", user_id)
|
||||
|
||||
mock_cm.register_connection.assert_called_once_with("test-sid", user_id)
|
||||
mock_connection_manager.register_connection.assert_called_once_with("test-sid", user_id)
|
||||
|
||||
|
||||
class TestCleanupAuthenticatedSession:
|
||||
"""Tests for session cleanup on disconnect."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_unregisters_connection(self) -> None:
|
||||
@pytest.fixture
|
||||
def mock_connection_manager(self) -> AsyncMock:
|
||||
"""Create a mock ConnectionManager."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def auth_handler(self, mock_connection_manager: AsyncMock) -> AuthHandler:
|
||||
"""Create AuthHandler with injected mock."""
|
||||
return AuthHandler(conn_manager=mock_connection_manager)
|
||||
|
||||
async def test_cleanup_unregisters_connection(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
mock_connection_manager: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that cleanup unregisters from ConnectionManager.
|
||||
|
||||
On disconnect, the connection should be removed from
|
||||
ConnectionManager to update presence tracking.
|
||||
"""
|
||||
with patch("app.socketio.auth.connection_manager") as mock_cm:
|
||||
mock_conn_info = MagicMock()
|
||||
mock_conn_info.user_id = "user-123"
|
||||
mock_cm.unregister_connection = AsyncMock(return_value=mock_conn_info)
|
||||
mock_conn_info = MagicMock()
|
||||
mock_conn_info.user_id = "user-123"
|
||||
mock_connection_manager.unregister_connection = AsyncMock(return_value=mock_conn_info)
|
||||
|
||||
result = await cleanup_authenticated_session("test-sid")
|
||||
result = await auth_handler.cleanup_authenticated_session("test-sid")
|
||||
|
||||
assert result == "user-123"
|
||||
mock_cm.unregister_connection.assert_called_once_with("test-sid")
|
||||
assert result == "user-123"
|
||||
mock_connection_manager.unregister_connection.assert_called_once_with("test-sid")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_returns_none_for_unknown_session(self) -> None:
|
||||
async def test_cleanup_returns_none_for_unknown_session(
|
||||
self,
|
||||
auth_handler: AuthHandler,
|
||||
mock_connection_manager: AsyncMock,
|
||||
) -> None:
|
||||
"""Test cleanup returns None for non-existent sessions.
|
||||
|
||||
If the connection wasn't registered (e.g., auth failed),
|
||||
cleanup should return None gracefully.
|
||||
"""
|
||||
with patch("app.socketio.auth.connection_manager") as mock_cm:
|
||||
mock_cm.unregister_connection = AsyncMock(return_value=None)
|
||||
mock_connection_manager.unregister_connection = AsyncMock(return_value=None)
|
||||
|
||||
result = await cleanup_authenticated_session("unknown-sid")
|
||||
result = await auth_handler.cleanup_authenticated_session("unknown-sid")
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetSessionUserId:
|
||||
"""Tests for session user_id retrieval."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_user_id_returns_id(self) -> None:
|
||||
"""Test retrieving user_id from authenticated session.
|
||||
|
||||
@ -290,7 +345,6 @@ class TestGetSessionUserId:
|
||||
assert result == "user-123"
|
||||
mock_sio.get_session.assert_called_once_with("test-sid", namespace="/game")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_user_id_returns_none_for_missing(self) -> None:
|
||||
"""Test that missing session returns None.
|
||||
|
||||
@ -304,7 +358,6 @@ class TestGetSessionUserId:
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_user_id_handles_exception(self) -> None:
|
||||
"""Test that exceptions are caught and return None.
|
||||
|
||||
@ -322,7 +375,6 @@ class TestGetSessionUserId:
|
||||
class TestRequireAuth:
|
||||
"""Tests for the require_auth helper."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_auth_returns_user_id_for_authenticated(self) -> None:
|
||||
"""Test that require_auth returns user_id for valid sessions.
|
||||
|
||||
@ -336,7 +388,6 @@ class TestRequireAuth:
|
||||
|
||||
assert result == "user-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_auth_returns_none_for_unauthenticated(self) -> None:
|
||||
"""Test that require_auth returns None for unauthenticated sessions.
|
||||
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
"""Tests for ConnectionManager service.
|
||||
|
||||
This module tests WebSocket connection tracking with Redis. Since these are
|
||||
unit tests, we mock the Redis operations to test the ConnectionManager logic
|
||||
unit tests, we inject a mock Redis factory to test the ConnectionManager logic
|
||||
without requiring a real Redis instance.
|
||||
|
||||
Uses dependency injection pattern - no monkey patching required.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -21,15 +25,13 @@ from app.services.connection_manager import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager() -> ConnectionManager:
|
||||
"""Create a ConnectionManager instance for testing."""
|
||||
return ConnectionManager(conn_ttl_seconds=3600)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis() -> AsyncMock:
|
||||
"""Create a mock Redis client."""
|
||||
"""Create a mock Redis client.
|
||||
|
||||
This mock provides all the Redis operations used by ConnectionManager
|
||||
with sensible defaults that can be overridden per-test.
|
||||
"""
|
||||
redis = AsyncMock()
|
||||
redis.hset = AsyncMock()
|
||||
redis.hget = AsyncMock()
|
||||
@ -46,6 +48,21 @@ def mock_redis() -> AsyncMock:
|
||||
return redis
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(mock_redis: AsyncMock) -> ConnectionManager:
|
||||
"""Create a ConnectionManager with injected mock Redis.
|
||||
|
||||
The mock Redis is injected via the redis_factory parameter,
|
||||
eliminating the need for monkey patching in tests.
|
||||
"""
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_redis_factory() -> AsyncIterator[AsyncMock]:
|
||||
yield mock_redis
|
||||
|
||||
return ConnectionManager(conn_ttl_seconds=3600, redis_factory=mock_redis_factory)
|
||||
|
||||
|
||||
class TestConnectionInfoDataclass:
|
||||
"""Tests for the ConnectionInfo dataclass."""
|
||||
|
||||
@ -87,8 +104,8 @@ class TestConnectionInfoDataclass:
|
||||
def test_is_stale_with_custom_threshold(self) -> None:
|
||||
"""Test is_stale with a custom threshold.
|
||||
|
||||
The threshold can be adjusted for different use cases like
|
||||
more aggressive cleanup or more lenient timeout.
|
||||
The threshold parameter allows callers to customize what counts
|
||||
as stale for different use cases.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
last_seen = now - timedelta(seconds=60)
|
||||
@ -96,7 +113,7 @@ class TestConnectionInfoDataclass:
|
||||
sid="test-sid",
|
||||
user_id="user-123",
|
||||
game_id=None,
|
||||
connected_at=now,
|
||||
connected_at=last_seen,
|
||||
last_seen=last_seen,
|
||||
)
|
||||
|
||||
@ -125,23 +142,20 @@ class TestRegisterConnection:
|
||||
sid = "test-sid-123"
|
||||
user_id = str(uuid4())
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
await manager.register_connection(sid, user_id)
|
||||
|
||||
await manager.register_connection(sid, user_id)
|
||||
# Verify connection hash was created
|
||||
conn_key = f"{CONN_PREFIX}{sid}"
|
||||
mock_redis.hset.assert_called()
|
||||
hset_call = mock_redis.hset.call_args
|
||||
assert hset_call.args[0] == conn_key
|
||||
mapping = hset_call.kwargs["mapping"]
|
||||
assert mapping["user_id"] == user_id
|
||||
assert mapping["game_id"] == ""
|
||||
|
||||
# Verify connection hash was created
|
||||
conn_key = f"{CONN_PREFIX}{sid}"
|
||||
mock_redis.hset.assert_called()
|
||||
hset_call = mock_redis.hset.call_args
|
||||
assert hset_call.args[0] == conn_key
|
||||
mapping = hset_call.kwargs["mapping"]
|
||||
assert mapping["user_id"] == user_id
|
||||
assert mapping["game_id"] == ""
|
||||
|
||||
# Verify user-to-connection mapping was created
|
||||
user_conn_key = f"{USER_CONN_PREFIX}{user_id}"
|
||||
mock_redis.set.assert_called_with(user_conn_key, sid)
|
||||
# Verify user-to-connection mapping was created
|
||||
user_conn_key = f"{USER_CONN_PREFIX}{user_id}"
|
||||
mock_redis.set.assert_called_with(user_conn_key, sid)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_connection_replaces_old_connection(
|
||||
@ -167,15 +181,12 @@ class TestRegisterConnection:
|
||||
"last_seen": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
await manager.register_connection(new_sid, user_id)
|
||||
|
||||
await manager.register_connection(new_sid, user_id)
|
||||
|
||||
# Verify old connection was cleaned up (delete called for old conn key)
|
||||
old_conn_key = f"{CONN_PREFIX}{old_sid}"
|
||||
delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list]
|
||||
assert old_conn_key in delete_calls
|
||||
# Verify old connection was cleaned up (delete called for old conn key)
|
||||
old_conn_key = f"{CONN_PREFIX}{old_sid}"
|
||||
delete_calls = [call.args[0] for call in mock_redis.delete.call_args_list]
|
||||
assert old_conn_key in delete_calls
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_connection_accepts_uuid(
|
||||
@ -191,15 +202,12 @@ class TestRegisterConnection:
|
||||
sid = "test-sid"
|
||||
user_uuid = uuid4()
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
await manager.register_connection(sid, user_uuid)
|
||||
|
||||
await manager.register_connection(sid, user_uuid)
|
||||
|
||||
# Verify user_id was converted to string
|
||||
hset_call = mock_redis.hset.call_args
|
||||
mapping = hset_call.kwargs["mapping"]
|
||||
assert mapping["user_id"] == str(user_uuid)
|
||||
# Verify user_id was converted to string
|
||||
hset_call = mock_redis.hset.call_args
|
||||
mapping = hset_call.kwargs["mapping"]
|
||||
assert mapping["user_id"] == str(user_uuid)
|
||||
|
||||
|
||||
class TestUnregisterConnection:
|
||||
@ -213,17 +221,14 @@ class TestUnregisterConnection:
|
||||
) -> None:
|
||||
"""Test that unregistering unknown connection returns None.
|
||||
|
||||
If the connection doesn't exist, we should return None rather than
|
||||
raising an error.
|
||||
Unregistering a non-existent connection should be a no-op
|
||||
and return None to indicate nothing was cleaned up.
|
||||
"""
|
||||
mock_redis.hgetall.return_value = {}
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.unregister_connection("unknown-sid")
|
||||
|
||||
result = await manager.unregister_connection("unknown-sid")
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_cleans_up_all_data(
|
||||
@ -231,44 +236,39 @@ class TestUnregisterConnection:
|
||||
manager: ConnectionManager,
|
||||
mock_redis: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that unregistering cleans up all related Redis data.
|
||||
"""Test that unregistering cleans up all Redis data.
|
||||
|
||||
Cleanup should remove:
|
||||
Unregistration should remove:
|
||||
1. Connection hash
|
||||
2. User-to-connection mapping (if it still points to this sid)
|
||||
3. Game connection set membership
|
||||
2. User-to-connection mapping
|
||||
3. Remove from game connection set (if in a game)
|
||||
"""
|
||||
sid = "test-sid"
|
||||
user_id = "user-123"
|
||||
game_id = "game-456"
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Mock: connection exists with game
|
||||
mock_redis.hgetall.return_value = {
|
||||
"user_id": user_id,
|
||||
"game_id": game_id,
|
||||
"connected_at": datetime.now(UTC).isoformat(),
|
||||
"last_seen": datetime.now(UTC).isoformat(),
|
||||
"connected_at": now.isoformat(),
|
||||
"last_seen": now.isoformat(),
|
||||
}
|
||||
mock_redis.get.return_value = sid # user mapping points to this sid
|
||||
mock_redis.get.return_value = sid # User's current connection
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.unregister_connection(sid)
|
||||
|
||||
result = await manager.unregister_connection(sid)
|
||||
assert result is not None
|
||||
assert result.user_id == user_id
|
||||
assert result.game_id == game_id
|
||||
|
||||
# Verify result
|
||||
assert result is not None
|
||||
assert result.sid == sid
|
||||
assert result.user_id == user_id
|
||||
assert result.game_id == game_id
|
||||
|
||||
# Verify cleanup
|
||||
mock_redis.srem.assert_called() # Removed from game set
|
||||
mock_redis.delete.assert_called() # Connection deleted
|
||||
# Verify cleanup calls
|
||||
mock_redis.delete.assert_called()
|
||||
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
|
||||
|
||||
|
||||
class TestGameAssociation:
|
||||
"""Tests for game join/leave operations."""
|
||||
"""Tests for game association methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_game_adds_to_game_set(
|
||||
@ -276,25 +276,22 @@ class TestGameAssociation:
|
||||
manager: ConnectionManager,
|
||||
mock_redis: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that joining a game adds the connection to the game's set.
|
||||
"""Test that joining a game adds connection to game's set.
|
||||
|
||||
When a connection joins a game, it should:
|
||||
1. Update the connection's game_id
|
||||
2. Add the sid to the game's connection set
|
||||
The connection should be tracked in the game's connection set
|
||||
for broadcasting updates to all participants.
|
||||
"""
|
||||
sid = "test-sid"
|
||||
game_id = "game-123"
|
||||
|
||||
mock_redis.exists.return_value = True
|
||||
mock_redis.hget.return_value = "" # Not in a game yet
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.join_game(sid, game_id)
|
||||
|
||||
result = await manager.join_game(sid, game_id)
|
||||
|
||||
assert result is True
|
||||
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id)
|
||||
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
|
||||
assert result is True
|
||||
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
|
||||
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", game_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_game_returns_false_for_unknown_connection(
|
||||
@ -304,18 +301,13 @@ class TestGameAssociation:
|
||||
) -> None:
|
||||
"""Test that joining a game fails for unknown connections.
|
||||
|
||||
If the connection doesn't exist, we should return False rather
|
||||
than creating orphan game association data.
|
||||
Non-existent connections should not be able to join games.
|
||||
"""
|
||||
mock_redis.exists.return_value = False
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.join_game("unknown-sid", "game-123")
|
||||
|
||||
result = await manager.join_game("unknown-sid", "game-123")
|
||||
|
||||
assert result is False
|
||||
mock_redis.sadd.assert_not_called()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_game_leaves_previous_game(
|
||||
@ -325,25 +317,22 @@ class TestGameAssociation:
|
||||
) -> None:
|
||||
"""Test that joining a new game leaves the previous game.
|
||||
|
||||
When switching games, the connection should be removed from the
|
||||
old game's connection set before joining the new one.
|
||||
A connection can only be in one game at a time, so joining
|
||||
a new game should automatically leave any previous game.
|
||||
"""
|
||||
sid = "test-sid"
|
||||
old_game = "game-old"
|
||||
new_game = "game-new"
|
||||
old_game = "old-game"
|
||||
new_game = "new-game"
|
||||
|
||||
mock_redis.exists.return_value = True
|
||||
mock_redis.hget.return_value = old_game # Currently in old game
|
||||
mock_redis.hget.return_value = old_game
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
await manager.join_game(sid, new_game)
|
||||
|
||||
await manager.join_game(sid, new_game)
|
||||
|
||||
# Verify left old game
|
||||
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid)
|
||||
# Verify joined new game
|
||||
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid)
|
||||
# Should remove from old game
|
||||
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{old_game}", sid)
|
||||
# Should add to new game
|
||||
mock_redis.sadd.assert_called_with(f"{GAME_CONNS_PREFIX}{new_game}", sid)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_game_removes_from_set(
|
||||
@ -351,25 +340,21 @@ class TestGameAssociation:
|
||||
manager: ConnectionManager,
|
||||
mock_redis: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that leaving a game removes the connection from the set.
|
||||
"""Test that leaving a game removes connection from set.
|
||||
|
||||
Leave should:
|
||||
1. Remove sid from game's connection set
|
||||
2. Clear game_id on connection record
|
||||
The connection should be removed from the game's set and
|
||||
its game_id should be cleared.
|
||||
"""
|
||||
sid = "test-sid"
|
||||
game_id = "game-123"
|
||||
|
||||
mock_redis.hget.return_value = game_id
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.leave_game(sid)
|
||||
|
||||
result = await manager.leave_game(sid)
|
||||
|
||||
assert result == game_id
|
||||
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
|
||||
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")
|
||||
assert result == game_id
|
||||
mock_redis.srem.assert_called_with(f"{GAME_CONNS_PREFIX}{game_id}", sid)
|
||||
mock_redis.hset.assert_called_with(f"{CONN_PREFIX}{sid}", "game_id", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_game_returns_none_when_not_in_game(
|
||||
@ -377,20 +362,16 @@ class TestGameAssociation:
|
||||
manager: ConnectionManager,
|
||||
mock_redis: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that leave_game returns None when not in a game.
|
||||
"""Test that leaving when not in a game returns None.
|
||||
|
||||
If the connection isn't associated with any game, we should
|
||||
return None without making unnecessary Redis calls.
|
||||
If the connection isn't in a game, leave_game should return
|
||||
None to indicate no game was left.
|
||||
"""
|
||||
mock_redis.hget.return_value = "" # No game_id
|
||||
mock_redis.hget.return_value = ""
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.leave_game("test-sid")
|
||||
|
||||
result = await manager.leave_game("test-sid")
|
||||
|
||||
assert result is None
|
||||
mock_redis.srem.assert_not_called()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestHeartbeat:
|
||||
@ -404,25 +385,18 @@ class TestHeartbeat:
|
||||
) -> None:
|
||||
"""Test that heartbeat updates the last_seen timestamp.
|
||||
|
||||
Heartbeats keep the connection alive by updating the timestamp
|
||||
and refreshing TTLs on Redis records.
|
||||
Heartbeats keep connections alive by updating timestamps
|
||||
and refreshing TTLs on Redis keys.
|
||||
"""
|
||||
sid = "test-sid"
|
||||
mock_redis.exists.return_value = True
|
||||
mock_redis.hget.return_value = "user-123"
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.update_heartbeat(sid)
|
||||
|
||||
result = await manager.update_heartbeat(sid)
|
||||
|
||||
assert result is True
|
||||
# Verify last_seen was updated
|
||||
hset_call = mock_redis.hset.call_args
|
||||
assert hset_call.args[0] == f"{CONN_PREFIX}{sid}"
|
||||
assert hset_call.args[1] == "last_seen"
|
||||
# Verify TTL was refreshed
|
||||
mock_redis.expire.assert_called()
|
||||
assert result is True
|
||||
mock_redis.hset.assert_called()
|
||||
mock_redis.expire.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_heartbeat_returns_false_for_unknown(
|
||||
@ -430,18 +404,16 @@ class TestHeartbeat:
|
||||
manager: ConnectionManager,
|
||||
mock_redis: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that heartbeat returns False for unknown connections.
|
||||
"""Test that heartbeat fails for unknown connections.
|
||||
|
||||
If the connection doesn't exist, we shouldn't try to update it.
|
||||
Unknown connections should not be able to send heartbeats,
|
||||
returning False to indicate failure.
|
||||
"""
|
||||
mock_redis.exists.return_value = False
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.update_heartbeat("unknown-sid")
|
||||
|
||||
result = await manager.update_heartbeat("unknown-sid")
|
||||
|
||||
assert result is False
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestQueryMethods:
|
||||
@ -453,13 +425,13 @@ class TestQueryMethods:
|
||||
manager: ConnectionManager,
|
||||
mock_redis: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that get_connection returns ConnectionInfo for valid sid.
|
||||
"""Test that get_connection returns ConnectionInfo.
|
||||
|
||||
The returned ConnectionInfo should have all fields populated
|
||||
from the Redis hash.
|
||||
The returned info should contain all the stored connection data.
|
||||
"""
|
||||
sid = "test-sid"
|
||||
now = datetime.now(UTC)
|
||||
|
||||
mock_redis.hgetall.return_value = {
|
||||
"user_id": "user-123",
|
||||
"game_id": "game-456",
|
||||
@ -467,15 +439,12 @@ class TestQueryMethods:
|
||||
"last_seen": now.isoformat(),
|
||||
}
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.get_connection(sid)
|
||||
|
||||
result = await manager.get_connection(sid)
|
||||
|
||||
assert result is not None
|
||||
assert result.sid == sid
|
||||
assert result.user_id == "user-123"
|
||||
assert result.game_id == "game-456"
|
||||
assert result is not None
|
||||
assert result.sid == sid
|
||||
assert result.user_id == "user-123"
|
||||
assert result.game_id == "game-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_connection_returns_none_for_unknown(
|
||||
@ -483,15 +452,15 @@ class TestQueryMethods:
|
||||
manager: ConnectionManager,
|
||||
mock_redis: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that get_connection returns None for unknown sid."""
|
||||
"""Test that get_connection returns None for unknown sid.
|
||||
|
||||
Non-existent connections should return None rather than raising.
|
||||
"""
|
||||
mock_redis.hgetall.return_value = {}
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.get_connection("unknown-sid")
|
||||
|
||||
result = await manager.get_connection("unknown-sid")
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_online_returns_true_for_connected_user(
|
||||
@ -501,7 +470,7 @@ class TestQueryMethods:
|
||||
) -> None:
|
||||
"""Test that is_user_online returns True for connected users.
|
||||
|
||||
A user is online if they have an active connection that isn't stale.
|
||||
Users with active, non-stale connections should be considered online.
|
||||
"""
|
||||
user_id = "user-123"
|
||||
now = datetime.now(UTC)
|
||||
@ -514,12 +483,9 @@ class TestQueryMethods:
|
||||
"last_seen": now.isoformat(),
|
||||
}
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.is_user_online(user_id)
|
||||
|
||||
result = await manager.is_user_online(user_id)
|
||||
|
||||
assert result is True
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_online_returns_false_for_stale_connection(
|
||||
@ -529,11 +495,11 @@ class TestQueryMethods:
|
||||
) -> None:
|
||||
"""Test that is_user_online returns False for stale connections.
|
||||
|
||||
Even if a connection record exists, if it's stale (no recent heartbeat),
|
||||
the user should be considered offline.
|
||||
Users with stale connections (no recent heartbeat) should be
|
||||
considered offline even if their connection record exists.
|
||||
"""
|
||||
user_id = "user-123"
|
||||
old_time = datetime.now(UTC) - timedelta(minutes=5)
|
||||
old_time = datetime.now(UTC) - timedelta(seconds=HEARTBEAT_INTERVAL_SECONDS * 4)
|
||||
|
||||
mock_redis.get.return_value = "test-sid"
|
||||
mock_redis.hgetall.return_value = {
|
||||
@ -543,12 +509,9 @@ class TestQueryMethods:
|
||||
"last_seen": old_time.isoformat(),
|
||||
}
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.is_user_online(user_id)
|
||||
|
||||
result = await manager.is_user_online(user_id)
|
||||
|
||||
assert result is False
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_game_connections_returns_all_participants(
|
||||
@ -564,29 +527,32 @@ class TestQueryMethods:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
mock_redis.smembers.return_value = {"sid-1", "sid-2"}
|
||||
mock_redis.hgetall.side_effect = [
|
||||
{
|
||||
"user_id": "user-1",
|
||||
"game_id": game_id,
|
||||
"connected_at": now.isoformat(),
|
||||
"last_seen": now.isoformat(),
|
||||
},
|
||||
{
|
||||
"user_id": "user-2",
|
||||
"game_id": game_id,
|
||||
"connected_at": now.isoformat(),
|
||||
"last_seen": now.isoformat(),
|
||||
},
|
||||
]
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
# Use a function to return the right data based on the key queried
|
||||
def hgetall_by_key(key: str) -> dict[str, str]:
|
||||
if key == f"{CONN_PREFIX}sid-1":
|
||||
return {
|
||||
"user_id": "user-1",
|
||||
"game_id": game_id,
|
||||
"connected_at": now.isoformat(),
|
||||
"last_seen": now.isoformat(),
|
||||
}
|
||||
elif key == f"{CONN_PREFIX}sid-2":
|
||||
return {
|
||||
"user_id": "user-2",
|
||||
"game_id": game_id,
|
||||
"connected_at": now.isoformat(),
|
||||
"last_seen": now.isoformat(),
|
||||
}
|
||||
return {}
|
||||
|
||||
result = await manager.get_game_connections(game_id)
|
||||
mock_redis.hgetall.side_effect = hgetall_by_key
|
||||
|
||||
assert len(result) == 2
|
||||
user_ids = {conn.user_id for conn in result}
|
||||
assert user_ids == {"user-1", "user-2"}
|
||||
result = await manager.get_game_connections(game_id)
|
||||
|
||||
assert len(result) == 2
|
||||
user_ids = {conn.user_id for conn in result}
|
||||
assert user_ids == {"user-1", "user-2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_opponent_sid_returns_other_player(
|
||||
@ -628,38 +594,38 @@ class TestQueryMethods:
|
||||
|
||||
mock_redis.hgetall.side_effect = hgetall_by_key
|
||||
|
||||
with patch("app.services.connection_manager.get_redis") as mock_get_redis:
|
||||
mock_get_redis.return_value.__aenter__.return_value = mock_redis
|
||||
result = await manager.get_opponent_sid(game_id, current_user)
|
||||
|
||||
result = await manager.get_opponent_sid(game_id, current_user)
|
||||
|
||||
assert result == opponent_sid
|
||||
assert result == opponent_sid
|
||||
|
||||
|
||||
class TestKeyGeneration:
|
||||
"""Tests for Redis key generation methods."""
|
||||
|
||||
def test_conn_key_format(self, manager: ConnectionManager) -> None:
|
||||
def test_conn_key_format(self) -> None:
|
||||
"""Test that connection keys have correct format.
|
||||
|
||||
Keys should follow the pattern conn:{sid} for easy identification
|
||||
and pattern matching.
|
||||
"""
|
||||
manager = ConnectionManager()
|
||||
key = manager._conn_key("test-sid-123")
|
||||
assert key == "conn:test-sid-123"
|
||||
|
||||
def test_user_conn_key_format(self, manager: ConnectionManager) -> None:
|
||||
def test_user_conn_key_format(self) -> None:
|
||||
"""Test that user connection keys have correct format.
|
||||
|
||||
Keys should follow the pattern user_conn:{user_id}.
|
||||
"""
|
||||
manager = ConnectionManager()
|
||||
key = manager._user_conn_key("user-456")
|
||||
assert key == "user_conn:user-456"
|
||||
|
||||
def test_game_conns_key_format(self, manager: ConnectionManager) -> None:
|
||||
def test_game_conns_key_format(self) -> None:
|
||||
"""Test that game connections keys have correct format.
|
||||
|
||||
Keys should follow the pattern game_conns:{game_id}.
|
||||
"""
|
||||
manager = ConnectionManager()
|
||||
key = manager._game_conns_key("game-789")
|
||||
assert key == "game_conns:game-789"
|
||||
|
||||
@ -7,9 +7,11 @@ The GameService is STATELESS regarding game rules:
|
||||
- No GameEngine is stored in the service
|
||||
- Engine is created per-operation using rules from GameState
|
||||
- Rules come from frontend at game creation, stored in GameState
|
||||
|
||||
Uses dependency injection pattern - no monkey patching required.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -53,19 +55,35 @@ def mock_card_service() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine() -> MagicMock:
|
||||
"""Create a mock GameEngine.
|
||||
|
||||
The engine handles game logic - executing actions and checking
|
||||
win conditions.
|
||||
"""
|
||||
engine = MagicMock()
|
||||
engine.execute_action = AsyncMock(
|
||||
return_value=ActionResult(success=True, message="Action executed")
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def game_service(
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_card_service: MagicMock,
|
||||
mock_engine: MagicMock,
|
||||
) -> GameService:
|
||||
"""Create a GameService with mocked dependencies.
|
||||
"""Create a GameService with injected mock dependencies.
|
||||
|
||||
Note: No engine is passed - GameService creates engines per-operation
|
||||
using rules stored in each game's GameState.
|
||||
The mock engine factory is injected via the engine_factory parameter,
|
||||
eliminating the need for monkey patching in tests.
|
||||
"""
|
||||
return GameService(
|
||||
state_manager=mock_state_manager,
|
||||
card_service=mock_card_service,
|
||||
engine_factory=lambda game: mock_engine,
|
||||
)
|
||||
|
||||
|
||||
@ -324,9 +342,8 @@ class TestJoinGame:
|
||||
class TestExecuteAction:
|
||||
"""Tests for the execute_action method.
|
||||
|
||||
These tests verify action execution through GameService. Since GameService
|
||||
creates engines per-operation, we patch _create_engine_for_game to return
|
||||
a mock engine with controlled behavior.
|
||||
These tests verify action execution through GameService. The mock engine
|
||||
is injected via the engine_factory parameter in the game_service fixture.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -334,6 +351,7 @@ class TestExecuteAction:
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_engine: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test successful action execution.
|
||||
@ -342,19 +360,14 @@ class TestExecuteAction:
|
||||
the state should be saved to cache.
|
||||
"""
|
||||
mock_state_manager.load_state.return_value = sample_game_state
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.execute_action = AsyncMock(
|
||||
return_value=ActionResult(
|
||||
success=True,
|
||||
message="Attack executed",
|
||||
state_changes=[{"type": "damage", "amount": 30}],
|
||||
)
|
||||
mock_engine.execute_action.return_value = ActionResult(
|
||||
success=True,
|
||||
message="Attack executed",
|
||||
state_changes=[{"type": "damage", "amount": 30}],
|
||||
)
|
||||
|
||||
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
|
||||
action = AttackAction(attack_index=0)
|
||||
result = await game_service.execute_action("game-123", "player-1", action)
|
||||
action = AttackAction(attack_index=0)
|
||||
result = await game_service.execute_action("game-123", "player-1", action)
|
||||
|
||||
assert result.success is True
|
||||
assert result.action_type == "attack"
|
||||
@ -434,6 +447,7 @@ class TestExecuteAction:
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_engine: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test that resignation is allowed even when not your turn.
|
||||
@ -442,24 +456,19 @@ class TestExecuteAction:
|
||||
by either player.
|
||||
"""
|
||||
mock_state_manager.load_state.return_value = sample_game_state
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.execute_action = AsyncMock(
|
||||
return_value=ActionResult(
|
||||
success=True,
|
||||
message="Player resigned",
|
||||
win_result=WinResult(
|
||||
winner_id="player-1",
|
||||
loser_id="player-2",
|
||||
end_reason=GameEndReason.RESIGNATION,
|
||||
reason="Player resigned",
|
||||
),
|
||||
)
|
||||
mock_engine.execute_action.return_value = ActionResult(
|
||||
success=True,
|
||||
message="Player resigned",
|
||||
win_result=WinResult(
|
||||
winner_id="player-1",
|
||||
loser_id="player-2",
|
||||
end_reason=GameEndReason.RESIGNATION,
|
||||
reason="Player resigned",
|
||||
),
|
||||
)
|
||||
|
||||
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
|
||||
# player-2 resigns even though it's player-1's turn
|
||||
result = await game_service.execute_action("game-123", "player-2", ResignAction())
|
||||
# player-2 resigns even though it's player-1's turn
|
||||
result = await game_service.execute_action("game-123", "player-2", ResignAction())
|
||||
|
||||
assert result.success is True
|
||||
assert result.game_over is True
|
||||
@ -470,6 +479,7 @@ class TestExecuteAction:
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_engine: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test execute_action raises error for invalid actions.
|
||||
@ -478,19 +488,12 @@ class TestExecuteAction:
|
||||
InvalidActionError with the reason.
|
||||
"""
|
||||
mock_state_manager.load_state.return_value = sample_game_state
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.execute_action = AsyncMock(
|
||||
return_value=ActionResult(
|
||||
success=False,
|
||||
message="Not enough energy to attack",
|
||||
)
|
||||
mock_engine.execute_action.return_value = ActionResult(
|
||||
success=False,
|
||||
message="Not enough energy to attack",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(game_service, "_create_engine_for_game", return_value=mock_engine),
|
||||
pytest.raises(InvalidActionError) as exc_info,
|
||||
):
|
||||
with pytest.raises(InvalidActionError) as exc_info:
|
||||
await game_service.execute_action("game-123", "player-1", AttackAction(attack_index=0))
|
||||
|
||||
assert "Not enough energy" in exc_info.value.reason
|
||||
@ -500,6 +503,7 @@ class TestExecuteAction:
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_engine: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test execute_action detects game over and persists to DB.
|
||||
@ -508,58 +512,26 @@ class TestExecuteAction:
|
||||
to the database for durability.
|
||||
"""
|
||||
mock_state_manager.load_state.return_value = sample_game_state
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.execute_action = AsyncMock(
|
||||
return_value=ActionResult(
|
||||
success=True,
|
||||
message="Final prize taken!",
|
||||
win_result=WinResult(
|
||||
winner_id="player-1",
|
||||
loser_id="player-2",
|
||||
end_reason=GameEndReason.PRIZES_TAKEN,
|
||||
reason="All prizes taken",
|
||||
),
|
||||
)
|
||||
mock_engine.execute_action.return_value = ActionResult(
|
||||
success=True,
|
||||
message="Final prize taken!",
|
||||
win_result=WinResult(
|
||||
winner_id="player-1",
|
||||
loser_id="player-2",
|
||||
end_reason=GameEndReason.PRIZES_TAKEN,
|
||||
reason="All prizes taken",
|
||||
),
|
||||
)
|
||||
|
||||
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
|
||||
result = await game_service.execute_action(
|
||||
"game-123", "player-1", AttackAction(attack_index=0)
|
||||
)
|
||||
result = await game_service.execute_action(
|
||||
"game-123", "player-1", AttackAction(attack_index=0)
|
||||
)
|
||||
|
||||
assert result.game_over is True
|
||||
assert result.winner_id == "player-1"
|
||||
assert result.end_reason == GameEndReason.PRIZES_TAKEN
|
||||
mock_state_manager.persist_to_db.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_action_uses_game_rules(
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test that execute_action creates engine with game's rules.
|
||||
|
||||
The engine should be created using the rules stored in the game
|
||||
state, not any service-level defaults.
|
||||
"""
|
||||
mock_state_manager.load_state.return_value = sample_game_state
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.execute_action = AsyncMock(
|
||||
return_value=ActionResult(success=True, message="OK")
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
game_service, "_create_engine_for_game", return_value=mock_engine
|
||||
) as mock_create:
|
||||
await game_service.execute_action("game-123", "player-1", PassAction())
|
||||
|
||||
# Verify engine was created with the game state
|
||||
mock_create.assert_called_once_with(sample_game_state)
|
||||
|
||||
|
||||
class TestResignGame:
|
||||
"""Tests for the resign_game convenience method."""
|
||||
@ -569,6 +541,7 @@ class TestResignGame:
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_engine: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test that resign_game is a convenience wrapper for execute_action.
|
||||
@ -577,23 +550,18 @@ class TestResignGame:
|
||||
and call execute_action.
|
||||
"""
|
||||
mock_state_manager.load_state.return_value = sample_game_state
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.execute_action = AsyncMock(
|
||||
return_value=ActionResult(
|
||||
success=True,
|
||||
message="Player resigned",
|
||||
win_result=WinResult(
|
||||
winner_id="player-2",
|
||||
loser_id="player-1",
|
||||
end_reason=GameEndReason.RESIGNATION,
|
||||
reason="Player resigned",
|
||||
),
|
||||
)
|
||||
mock_engine.execute_action.return_value = ActionResult(
|
||||
success=True,
|
||||
message="Player resigned",
|
||||
win_result=WinResult(
|
||||
winner_id="player-2",
|
||||
loser_id="player-1",
|
||||
end_reason=GameEndReason.RESIGNATION,
|
||||
reason="Player resigned",
|
||||
),
|
||||
)
|
||||
|
||||
with patch.object(game_service, "_create_engine_for_game", return_value=mock_engine):
|
||||
result = await game_service.resign_game("game-123", "player-1")
|
||||
result = await game_service.resign_game("game-123", "player-1")
|
||||
|
||||
assert result.success is True
|
||||
assert result.action_type == "resign"
|
||||
@ -692,31 +660,39 @@ class TestCreateGame:
|
||||
assert "GS-002" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestCreateEngineForGame:
|
||||
"""Tests for the _create_engine_for_game method.
|
||||
class TestDefaultEngineFactory:
|
||||
"""Tests for the _default_engine_factory method.
|
||||
|
||||
This method is responsible for creating a GameEngine configured
|
||||
with the rules from a specific game's state.
|
||||
"""
|
||||
|
||||
def test_create_engine_uses_game_rules(
|
||||
def test_default_engine_uses_game_rules(
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_card_service: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test that engine is created with the game's rules.
|
||||
"""Test that default engine is created with the game's rules.
|
||||
|
||||
The engine should use the RulesConfig stored in the game state,
|
||||
not any default configuration.
|
||||
"""
|
||||
engine = game_service._create_engine_for_game(sample_game_state)
|
||||
# Create service without mock engine to test default factory
|
||||
service = GameService(
|
||||
state_manager=mock_state_manager,
|
||||
card_service=mock_card_service,
|
||||
)
|
||||
|
||||
engine = service._default_engine_factory(sample_game_state)
|
||||
|
||||
# Engine should have the game's rules
|
||||
assert engine.rules == sample_game_state.rules
|
||||
|
||||
def test_create_engine_with_rng_seed(
|
||||
def test_default_engine_with_rng_seed(
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_card_service: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test that engine uses seeded RNG when game has rng_seed.
|
||||
@ -724,17 +700,21 @@ class TestCreateEngineForGame:
|
||||
When a game has an rng_seed set, the engine should use a
|
||||
deterministic RNG for replay support.
|
||||
"""
|
||||
service = GameService(
|
||||
state_manager=mock_state_manager,
|
||||
card_service=mock_card_service,
|
||||
)
|
||||
sample_game_state.rng_seed = 12345
|
||||
|
||||
engine = game_service._create_engine_for_game(sample_game_state)
|
||||
engine = service._default_engine_factory(sample_game_state)
|
||||
|
||||
# Engine should have been created (we can't easily verify seed,
|
||||
# but we can verify it doesn't error)
|
||||
# Engine should have been created successfully
|
||||
assert engine is not None
|
||||
|
||||
def test_create_engine_without_rng_seed(
|
||||
def test_default_engine_without_rng_seed(
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_card_service: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test that engine uses secure RNG when no seed is set.
|
||||
@ -742,15 +722,20 @@ class TestCreateEngineForGame:
|
||||
Without an rng_seed, the engine should use cryptographically
|
||||
secure random number generation.
|
||||
"""
|
||||
service = GameService(
|
||||
state_manager=mock_state_manager,
|
||||
card_service=mock_card_service,
|
||||
)
|
||||
sample_game_state.rng_seed = None
|
||||
|
||||
engine = game_service._create_engine_for_game(sample_game_state)
|
||||
engine = service._default_engine_factory(sample_game_state)
|
||||
|
||||
assert engine is not None
|
||||
|
||||
def test_create_engine_derives_unique_seed_per_action(
|
||||
def test_default_engine_derives_unique_seed_per_action(
|
||||
self,
|
||||
game_service: GameService,
|
||||
mock_state_manager: AsyncMock,
|
||||
mock_card_service: MagicMock,
|
||||
sample_game_state: GameState,
|
||||
) -> None:
|
||||
"""Test that different action counts produce different RNG sequences.
|
||||
@ -758,15 +743,19 @@ class TestCreateEngineForGame:
|
||||
For deterministic replay, each action needs a unique but
|
||||
reproducible RNG seed based on game seed + action count.
|
||||
"""
|
||||
service = GameService(
|
||||
state_manager=mock_state_manager,
|
||||
card_service=mock_card_service,
|
||||
)
|
||||
sample_game_state.rng_seed = 12345
|
||||
|
||||
# Simulate first action (action_log is empty)
|
||||
sample_game_state.action_log = []
|
||||
engine1 = game_service._create_engine_for_game(sample_game_state)
|
||||
engine1 = service._default_engine_factory(sample_game_state)
|
||||
|
||||
# Simulate second action (one action in log)
|
||||
sample_game_state.action_log = [{"type": "pass"}]
|
||||
engine2 = game_service._create_engine_for_game(sample_game_state)
|
||||
engine2 = service._default_engine_factory(sample_game_state)
|
||||
|
||||
# Both engines should be created successfully
|
||||
# (They will have different seeds due to action count)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user