mantimon-tcg/backend/tests/unit/schemas/test_ws_messages.py
Cal Corum 0c810e5b30 Add Phase 4 WebSocket infrastructure (WS-001 through GS-001)
WebSocket Message Schemas (WS-002):
- Add Pydantic models for all client/server WebSocket messages
- Implement discriminated unions for message type parsing
- Include JoinGame, Action, Resign, Heartbeat client messages
- Include GameState, ActionResult, Error, TurnStart server messages

Connection Manager (WS-003):
- Add Redis-backed WebSocket connection tracking
- Implement user-to-sid mapping with TTL management
- Support game room association and opponent lookup
- Add heartbeat tracking for connection health

Socket.IO Authentication (WS-004):
- Add JWT-based authentication middleware
- Support token extraction from multiple formats
- Implement session setup with ConnectionManager integration
- Add require_auth helper for event handlers

Socket.IO Server Setup (WS-001):
- Configure AsyncServer with ASGI mode
- Register /game namespace with event handlers
- Integrate with FastAPI via ASGIApp wrapper
- Configure CORS from application settings

Game Service (GS-001):
- Add stateless GameService for game lifecycle orchestration
- Create engine per-operation using rules from GameState
- Implement action-based RNG seeding for deterministic replay
- Add rng_seed field to GameState for replay support

Architecture verified:
- Core module independence (no forbidden imports)
- Config from request pattern (rules in GameState)
- Dependency injection (constructor deps, method config)
- All 1090 tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 22:21:20 -06:00

702 lines
22 KiB
Python

"""Tests for WebSocket message schemas.
This module tests the Pydantic models for WebSocket communication, verifying
discriminated union parsing, field validation, and serialization behavior.
"""
import json
from datetime import UTC, datetime

import pytest
from pydantic import TypeAdapter, ValidationError

from app.core.enums import GameEndReason, TurnPhase
from app.core.models.actions import AttackAction, PassAction
from app.core.visibility import VisibleGameState, VisiblePlayerState
from app.schemas.ws_messages import (
    ActionMessage,
    ActionResultMessage,
    ClientMessage,
    ConnectionStatus,
    ErrorMessage,
    GameOverMessage,
    GameStateMessage,
    HeartbeatAckMessage,
    HeartbeatMessage,
    JoinGameMessage,
    OpponentStatusMessage,
    ResignMessage,
    ServerMessage,
    TurnStartMessage,
    TurnTimeoutMessage,
    WSErrorCode,
    parse_client_message,
    parse_server_message,
)


class TestClientMessageDiscriminator:
    """Check that client payloads resolve to the correct message model.

    Every test hands a plain dictionary to the parser and inspects which
    concrete class the ``type`` discriminator selected, so that handlers
    never have to branch on the raw ``type`` string themselves.
    """

    def test_parse_join_game_message(self) -> None:
        """A ``join_game`` payload becomes a JoinGameMessage.

        No ``last_event_id`` is supplied, so the parsed message is expected
        to report it as ``None``.
        """
        joined = parse_client_message(
            {
                "type": "join_game",
                "message_id": "test-123",
                "game_id": "game-456",
            }
        )

        assert isinstance(joined, JoinGameMessage)
        assert joined.game_id == "game-456"
        assert joined.message_id == "test-123"
        assert joined.last_event_id is None

    def test_parse_join_game_with_last_event_id(self) -> None:
        """A reconnecting ``join_game`` payload keeps its ``last_event_id``.

        Clients send the ID of the last event they saw when rejoining; the
        parsed message must carry that value through unchanged.
        """
        rejoined = parse_client_message(
            {
                "type": "join_game",
                "message_id": "test-123",
                "game_id": "game-456",
                "last_event_id": "event-789",
            }
        )

        assert isinstance(rejoined, JoinGameMessage)
        assert rejoined.last_event_id == "event-789"

    def test_parse_action_message(self) -> None:
        """An ``action`` payload with a nested attack parses on both levels.

        The outer discriminator must pick ActionMessage and the nested one
        must pick AttackAction.
        """
        attack_payload = {"type": "attack", "attack_index": 0}
        envelope = {
            "type": "action",
            "message_id": "test-123",
            "game_id": "game-456",
            "action": attack_payload,
        }

        parsed = parse_client_message(envelope)

        assert isinstance(parsed, ActionMessage)
        assert parsed.game_id == "game-456"
        assert isinstance(parsed.action, AttackAction)
        assert parsed.action.attack_index == 0

    def test_parse_action_message_with_pass(self) -> None:
        """An ``action`` payload whose nested action is just ``pass`` parses.

        This covers the smallest possible action body: only a type tag.
        """
        envelope = {
            "type": "action",
            "message_id": "test-123",
            "game_id": "game-456",
            "action": {"type": "pass"},
        }

        parsed = parse_client_message(envelope)

        assert isinstance(parsed, ActionMessage)
        assert isinstance(parsed.action, PassAction)

    def test_parse_resign_message(self) -> None:
        """A ``resign`` payload becomes a ResignMessage for the given game.

        Resignation is expressed as its own WebSocket message rather than
        as a nested game action.
        """
        resignation = parse_client_message(
            {
                "type": "resign",
                "message_id": "test-123",
                "game_id": "game-456",
            }
        )

        assert isinstance(resignation, ResignMessage)
        assert resignation.game_id == "game-456"

    def test_parse_heartbeat_message(self) -> None:
        """A ``heartbeat`` payload becomes a HeartbeatMessage.

        The payload carries nothing beyond the type tag and a message ID.
        """
        beat = parse_client_message({"type": "heartbeat", "message_id": "test-123"})

        assert isinstance(beat, HeartbeatMessage)

    def test_unknown_message_type_raises_error(self) -> None:
        """An unrecognised ``type`` value is rejected with ValidationError.

        Unknown messages must fail loudly instead of being accepted.
        """
        bogus = {"type": "unknown_type", "message_id": "test-123"}

        with pytest.raises(ValidationError):
            parse_client_message(bogus)

    def test_message_id_auto_generated(self) -> None:
        """Omitting ``message_id`` still yields a non-empty identifier.

        The test only requires that some non-empty value is filled in.
        """
        beat = parse_client_message({"type": "heartbeat"})

        assert beat.message_id is not None
        assert len(beat.message_id) > 0

    def test_empty_message_id_rejected(self) -> None:
        """An explicitly empty ``message_id`` is rejected by validation.

        An empty identifier cannot be used to track a message.
        """
        blank_id = {"type": "heartbeat", "message_id": ""}

        with pytest.raises(ValidationError):
            parse_client_message(blank_id)

    def test_type_adapter_client_message(self) -> None:
        """The ``ClientMessage`` union validates through a TypeAdapter.

        Validating a heartbeat dictionary via the adapter must produce the
        same concrete class the parser function would.
        """
        client_adapter = TypeAdapter(ClientMessage)

        beat = client_adapter.validate_python(
            {"type": "heartbeat", "message_id": "test"}
        )

        assert isinstance(beat, HeartbeatMessage)
class TestServerMessageDiscriminator:
    """Check that dumped server messages parse back to their own class.

    Most tests build a concrete server message, dump it to a dictionary,
    and feed that dictionary to ``parse_server_message`` to confirm the
    discriminator restores the original type and fields.
    """

    @staticmethod
    def _reparse(message):
        """Dump *message* to a dict and parse it as a server message."""
        return parse_server_message(message.model_dump())

    def test_parse_game_state_message(self) -> None:
        """A GameStateMessage with a two-player visible state round-trips.

        The nested state must keep its viewer and an event ID must be set.
        """
        seats = ("player-1", "player-2")
        snapshot = VisibleGameState(
            game_id="game-456",
            viewer_id="player-1",
            players={seat: VisiblePlayerState(player_id=seat) for seat in seats},
            current_player_id="player-1",
            turn_number=1,
            phase=TurnPhase.MAIN,
            is_my_turn=True,
        )

        parsed = self._reparse(GameStateMessage(game_id="game-456", state=snapshot))

        assert isinstance(parsed, GameStateMessage)
        assert parsed.game_id == "game-456"
        assert parsed.state.viewer_id == "player-1"
        assert parsed.event_id is not None

    def test_parse_action_result_success(self) -> None:
        """A successful ActionResultMessage keeps its change set.

        No error code is expected on the success path.
        """
        outcome = ActionResultMessage(
            game_id="game-456",
            request_message_id="action-123",
            success=True,
            action_type="attack",
            changes={"damage_dealt": 30, "defender_hp": 70},
        )

        parsed = self._reparse(outcome)

        assert isinstance(parsed, ActionResultMessage)
        assert parsed.success is True
        assert parsed.changes["damage_dealt"] == 30
        assert parsed.error_code is None

    def test_parse_action_result_failure(self) -> None:
        """A failed ActionResultMessage keeps its error code.

        The code and human-readable message travel with the result.
        """
        outcome = ActionResultMessage(
            game_id="game-456",
            request_message_id="action-123",
            success=False,
            action_type="attack",
            error_code=WSErrorCode.NOT_YOUR_TURN,
            error_message="It's not your turn",
        )

        parsed = self._reparse(outcome)

        assert isinstance(parsed, ActionResultMessage)
        assert parsed.success is False
        assert parsed.error_code == WSErrorCode.NOT_YOUR_TURN

    def test_parse_error_message(self) -> None:
        """An ErrorMessage keeps its code and free-form details.

        This message type is not tied to any particular action request.
        """
        failure = ErrorMessage(
            code=WSErrorCode.AUTHENTICATION_FAILED,
            message="Invalid token",
            details={"reason": "expired"},
        )

        parsed = self._reparse(failure)

        assert isinstance(parsed, ErrorMessage)
        assert parsed.code == WSErrorCode.AUTHENTICATION_FAILED
        assert parsed.details["reason"] == "expired"

    def test_parse_turn_start_message(self) -> None:
        """A TurnStartMessage keeps its turn number and active player."""
        announcement = TurnStartMessage(
            game_id="game-456",
            player_id="player-1",
            turn_number=5,
        )

        parsed = self._reparse(announcement)

        assert isinstance(parsed, TurnStartMessage)
        assert parsed.turn_number == 5
        assert parsed.player_id == "player-1"

    def test_parse_turn_timeout_warning(self) -> None:
        """A TurnTimeoutMessage flagged as a warning keeps time remaining."""
        warning = TurnTimeoutMessage(
            game_id="game-456",
            remaining_seconds=30,
            is_warning=True,
            player_id="player-1",
        )

        parsed = self._reparse(warning)

        assert isinstance(parsed, TurnTimeoutMessage)
        assert parsed.is_warning is True
        assert parsed.remaining_seconds == 30

    def test_parse_turn_timeout_expired(self) -> None:
        """A TurnTimeoutMessage with ``is_warning=False`` and zero seconds.

        This is the "time is up" form of the same message type.
        """
        expiry = TurnTimeoutMessage(
            game_id="game-456",
            remaining_seconds=0,
            is_warning=False,
            player_id="player-1",
        )

        parsed = self._reparse(expiry)

        assert isinstance(parsed, TurnTimeoutMessage)
        assert parsed.is_warning is False
        assert parsed.remaining_seconds == 0

    def test_parse_game_over_with_winner(self) -> None:
        """A GameOverMessage naming a winner keeps winner and end reason.

        The final visible state is embedded in the message.
        """
        closing_state = VisibleGameState(
            game_id="game-456",
            viewer_id="player-1",
            winner_id="player-1",
            end_reason=GameEndReason.PRIZES_TAKEN,
        )
        verdict = GameOverMessage(
            game_id="game-456",
            winner_id="player-1",
            end_reason=GameEndReason.PRIZES_TAKEN,
            final_state=closing_state,
        )

        parsed = self._reparse(verdict)

        assert isinstance(parsed, GameOverMessage)
        assert parsed.winner_id == "player-1"
        assert parsed.end_reason == GameEndReason.PRIZES_TAKEN

    def test_parse_game_over_draw(self) -> None:
        """A drawn GameOverMessage has no winner and a DRAW end reason."""
        closing_state = VisibleGameState(
            game_id="game-456",
            viewer_id="player-1",
            end_reason=GameEndReason.DRAW,
        )
        verdict = GameOverMessage(
            game_id="game-456",
            winner_id=None,
            end_reason=GameEndReason.DRAW,
            final_state=closing_state,
        )

        parsed = self._reparse(verdict)

        assert isinstance(parsed, GameOverMessage)
        assert parsed.winner_id is None
        assert parsed.end_reason == GameEndReason.DRAW

    def test_parse_opponent_status_connected(self) -> None:
        """An OpponentStatusMessage reporting CONNECTED keeps that status."""
        notice = OpponentStatusMessage(
            game_id="game-456",
            opponent_id="player-2",
            status=ConnectionStatus.CONNECTED,
        )

        parsed = self._reparse(notice)

        assert isinstance(parsed, OpponentStatusMessage)
        assert parsed.status == ConnectionStatus.CONNECTED

    def test_parse_opponent_status_disconnected(self) -> None:
        """An OpponentStatusMessage reporting DISCONNECTED keeps that status."""
        notice = OpponentStatusMessage(
            game_id="game-456",
            opponent_id="player-2",
            status=ConnectionStatus.DISCONNECTED,
        )

        parsed = self._reparse(notice)

        assert isinstance(parsed, OpponentStatusMessage)
        assert parsed.status == ConnectionStatus.DISCONNECTED

    def test_parse_heartbeat_ack(self) -> None:
        """A default-constructed HeartbeatAckMessage parses back as itself."""
        parsed = self._reparse(HeartbeatAckMessage())

        assert isinstance(parsed, HeartbeatAckMessage)

    def test_server_message_has_timestamp(self) -> None:
        """A freshly built server message carries a UTC ``datetime`` stamp.

        HeartbeatAckMessage is used because it needs no arguments.
        """
        stamp = HeartbeatAckMessage().timestamp

        assert stamp is not None
        assert isinstance(stamp, datetime)
        # Should be UTC
        assert stamp.tzinfo == UTC

    def test_type_adapter_server_message(self) -> None:
        """The ``ServerMessage`` union validates through a TypeAdapter.

        A hand-written error dictionary, with an ISO timestamp and string
        error code, must resolve to ErrorMessage.
        """
        server_adapter = TypeAdapter(ServerMessage)
        raw_error = {
            "type": "error",
            "message_id": "test",
            "timestamp": datetime.now(UTC).isoformat(),
            "code": "game_not_found",
            "message": "Game not found",
        }

        parsed = server_adapter.validate_python(raw_error)

        assert isinstance(parsed, ErrorMessage)
class TestMessageSerialization:
    """Tests for message JSON serialization and round-trip behavior.

    Messages must serialize cleanly to JSON for WebSocket transmission
    and deserialize back to the same logical structure. Each test goes
    through real JSON text (``model_dump_json`` followed by ``json.loads``)
    rather than a Python-level dump, so JSON-specific encoding is exercised.
    """

    def test_client_message_round_trip(self) -> None:
        """Test client message JSON round-trip.

        An ActionMessage wrapping an AttackAction is encoded to JSON,
        decoded, and re-parsed; the message ID and the nested action's
        fields must be unchanged.
        """
        original = ActionMessage(
            message_id="test-123",
            game_id="game-456",
            action=AttackAction(attack_index=1, targets=["target-1"]),
        )

        data = json.loads(original.model_dump_json())
        restored = parse_client_message(data)

        assert isinstance(restored, ActionMessage)
        assert restored.message_id == original.message_id
        assert isinstance(restored.action, AttackAction)
        assert restored.action.attack_index == 1
        assert restored.action.targets == ["target-1"]

    def test_server_message_round_trip(self) -> None:
        """Test server message JSON round-trip.

        A GameStateMessage with nested visible state is encoded to JSON,
        decoded, and re-parsed; the turn number and phase of the nested
        state must be unchanged.
        """
        visible_state = VisibleGameState(
            game_id="game-456",
            viewer_id="player-1",
            turn_number=3,
            phase=TurnPhase.ATTACK,
        )
        original = GameStateMessage(
            message_id="test-123",
            game_id="game-456",
            state=visible_state,
        )

        data = json.loads(original.model_dump_json())
        restored = parse_server_message(data)

        assert isinstance(restored, GameStateMessage)
        assert restored.state.turn_number == 3
        assert restored.state.phase == TurnPhase.ATTACK

    def test_error_code_serializes_as_string(self) -> None:
        """Test that WSErrorCode serializes as its string value.

        The JSON output must contain the plain string ``"game_not_found"``
        under ``code`` so clients can compare it directly.
        """
        message = ErrorMessage(
            code=WSErrorCode.GAME_NOT_FOUND,
            message="Game not found",
        )

        data = json.loads(message.model_dump_json())

        assert data["code"] == "game_not_found"
class TestMessageValidation:
    """Check that malformed messages are rejected with ValidationError.

    Each test leaves out a required field or supplies an out-of-range
    value and expects construction or parsing to fail.
    """

    def test_join_game_requires_game_id(self) -> None:
        """A ``join_game`` payload without ``game_id`` is rejected."""
        # game_id deliberately left out
        incomplete = {"type": "join_game", "message_id": "test-123"}

        with pytest.raises(ValidationError):
            parse_client_message(incomplete)

    def test_action_message_requires_action(self) -> None:
        """An ``action`` payload without the nested ``action`` is rejected."""
        # action deliberately left out
        incomplete = {
            "type": "action",
            "message_id": "test-123",
            "game_id": "game-456",
        }

        with pytest.raises(ValidationError):
            parse_client_message(incomplete)

    def test_turn_number_must_be_positive(self) -> None:
        """A TurnStartMessage with ``turn_number=0`` is rejected.

        Turn numbering starts at 1.
        """
        bad_fields = {
            "game_id": "game-456",
            "player_id": "player-1",
            "turn_number": 0,  # Invalid
        }

        with pytest.raises(ValidationError):
            TurnStartMessage(**bad_fields)

    def test_remaining_seconds_non_negative(self) -> None:
        """A TurnTimeoutMessage with negative ``remaining_seconds`` fails."""
        bad_fields = {
            "game_id": "game-456",
            "remaining_seconds": -1,  # Invalid
            "is_warning": True,
            "player_id": "player-1",
        }

        with pytest.raises(ValidationError):
            TurnTimeoutMessage(**bad_fields)

    def test_action_result_requires_request_id(self) -> None:
        """An ActionResultMessage without ``request_message_id`` fails.

        Without that field a result cannot be matched to its request.
        """
        # request_message_id deliberately left out
        incomplete_fields = {
            "game_id": "game-456",
            "success": True,
            "action_type": "attack",
        }

        with pytest.raises(ValidationError):
            ActionResultMessage(**incomplete_fields)
class TestWSErrorCode:
    """Tests for WebSocket error code enumeration.

    Error codes provide machine-readable error classification for
    programmatic error handling on the client.
    """

    @staticmethod
    def _is_lower_snake_case(value: str) -> bool:
        """Return True if *value* is ASCII lowercase words joined by ``_``.

        Splitting on ``_`` and requiring every piece to be non-empty also
        rejects leading, trailing, and doubled underscores, because
        ``"".isalnum()`` is False.
        """
        return all(
            word.isascii() and word.isalnum() and word == word.lower()
            for word in value.split("_")
        )

    def test_all_error_codes_are_strings(self) -> None:
        """Test that every error code's value is a ``str``.

        String values are what JSON serialization emits for the codes.
        """
        for code in WSErrorCode:
            assert isinstance(code.value, str)

    def test_error_codes_are_lowercase_snake_case(self) -> None:
        """Test the error code naming convention.

        Every value must be lowercase_snake_case: ASCII lowercase letters
        and digits in words separated by single underscores. Hyphens,
        spaces, uppercase letters, and stray underscores all fail.
        """
        for code in WSErrorCode:
            assert self._is_lower_snake_case(
                code.value
            ), f"{code!r} is not lowercase_snake_case"
class TestConnectionStatus:
    """Checks on the connection status enumeration."""

    def test_all_statuses_are_strings(self) -> None:
        """Every connection status member has a ``str`` value."""
        for member in ConnectionStatus:
            assert isinstance(member.value, str)

    def test_expected_statuses_exist(self) -> None:
        """The connected, disconnected and reconnecting members exist.

        Looking each one up fails the test if the member is missing.
        """
        expected_members = (
            ConnectionStatus.CONNECTED,
            ConnectionStatus.DISCONNECTED,
            ConnectionStatus.RECONNECTING,
        )

        for member in expected_members:
            assert member