""" Test play locking idempotency guard. """ import pytest from unittest.mock import MagicMock, AsyncMock, patch from command_logic.logic_gameplay import checks_log_interaction from exceptions import PlayLockedException @pytest.fixture def mock_session(): """Create a mock SQLAlchemy session.""" session = MagicMock() return session @pytest.fixture def mock_interaction(): """Create a mock Discord interaction.""" interaction = MagicMock() interaction.user.name = "TestUser" interaction.user.id = 12345 interaction.channel.name = "test-game-channel" interaction.channel_id = 98765 interaction.response = MagicMock() interaction.response.defer = AsyncMock() return interaction @pytest.fixture def mock_game(mock_session): """Create a mock game with current play.""" game = MagicMock() game.id = 100 game.current_play_or_none = MagicMock() return game @pytest.fixture def mock_team(): """Create a mock team.""" team = MagicMock() team.id = 50 team.abbrev = "TEST" return team @pytest.fixture def mock_play_unlocked(): """Create an unlocked mock play.""" play = MagicMock() play.id = 200 play.locked = False return play @pytest.fixture def mock_play_locked(): """Create a locked mock play.""" play = MagicMock() play.id = 200 play.locked = True return play @pytest.mark.asyncio async def test_unlocked_play_locks_successfully( mock_session, mock_interaction, mock_game, mock_team, mock_play_unlocked ): """Verify unlocked play can be locked and processed.""" mock_game.current_play_or_none.return_value = mock_play_unlocked mock_session.exec.return_value.one.return_value = mock_team with patch( "command_logic.logic_gameplay.get_channel_game_or_none", return_value=mock_game ): with patch( "command_logic.logic_gameplay.get_team_or_none", return_value=mock_team ): result_game, result_team, result_play = await checks_log_interaction( mock_session, mock_interaction, command_name="log xcheck" ) assert result_play.locked is True assert result_play.id == mock_play_unlocked.id mock_session.commit.assert_called_once() @pytest.mark.asyncio async def test_locked_play_rejects_duplicate_interaction( mock_session, mock_interaction, mock_game, mock_team, mock_play_locked ): """Verify duplicate command on locked play raises PlayLockedException.""" mock_game.current_play_or_none.return_value = mock_play_locked mock_session.exec.return_value.one.return_value = mock_team with patch( "command_logic.logic_gameplay.get_channel_game_or_none", return_value=mock_game ): with patch( "command_logic.logic_gameplay.get_team_or_none", return_value=mock_team ): with pytest.raises(PlayLockedException) as exc_info: await checks_log_interaction( mock_session, mock_interaction, command_name="log xcheck" ) assert "already being processed" in str(exc_info.value) assert "wait" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_locked_play_logs_warning( mock_session, mock_interaction, mock_game, mock_team, mock_play_locked, caplog ): """Verify locked play attempt is logged with warning.""" mock_game.current_play_or_none.return_value = mock_play_locked mock_session.exec.return_value.one.return_value = mock_team with patch( "command_logic.logic_gameplay.get_channel_game_or_none", return_value=mock_game ): with patch( "command_logic.logic_gameplay.get_team_or_none", return_value=mock_team ): try: await checks_log_interaction( mock_session, mock_interaction, command_name="log xcheck" ) except PlayLockedException: pass assert any( "attempted log xcheck on locked play" in record.message for record in caplog.records ) @pytest.mark.asyncio async def test_lock_released_after_successful_completion(mock_session): """ Verify play lock is released after successful command completion. Tests the complete_play() function to ensure it: - Sets play.locked = False - Sets play.complete = True - Commits changes to database """ from command_logic.logic_gameplay import complete_play # Create mock play that's currently locked mock_play = MagicMock() mock_play.id = 300 mock_play.locked = True mock_play.complete = False mock_play.game = MagicMock() mock_play.game.id = 100 mock_play.inning_num = 1 mock_play.inning_half = "top" mock_play.starting_outs = 0 mock_play.away_score = 0 mock_play.home_score = 0 mock_play.on_base_code = 0 mock_play.batter = MagicMock() mock_play.batter.team = MagicMock() mock_play.pitcher = MagicMock() mock_play.pitcher.team = MagicMock() mock_play.managerai = MagicMock() # Mock the session.exec queries mock_session.exec.return_value.one.return_value = MagicMock() # Execute complete_play with patch("command_logic.logic_gameplay.get_one_lineup"): with patch("command_logic.logic_gameplay.get_re24", return_value=0.0): with patch("command_logic.logic_gameplay.get_wpa", return_value=0.0): complete_play(mock_session, mock_play) # Verify lock was released and play marked complete assert mock_play.locked is False assert mock_play.complete is True # Verify changes were committed mock_session.add.assert_called() mock_session.commit.assert_called_once() @pytest.mark.asyncio async def test_different_commands_racing_on_locked_play( mock_session, mock_interaction, mock_game, mock_team, mock_play_locked ): """ Verify different command types are blocked when play is locked. Tests that the lock prevents ANY command from processing, not just duplicates of the same command. This prevents race conditions where different users try different commands simultaneously. """ mock_game.current_play_or_none.return_value = mock_play_locked mock_session.exec.return_value.one.return_value = mock_team # Test different command types all raise PlayLockedException command_types = ["log walk", "log strikeout", "log single", "log xcheck"] for command_name in command_types: with patch( "command_logic.logic_gameplay.get_channel_game_or_none", return_value=mock_game, ): with patch( "command_logic.logic_gameplay.get_team_or_none", return_value=mock_team, ): with pytest.raises(PlayLockedException) as exc_info: await checks_log_interaction( mock_session, mock_interaction, command_name=command_name ) # Verify exception message is consistent assert "already being processed" in str(exc_info.value) assert "wait" in str(exc_info.value).lower()