"""
Unit tests for AIService.

Tests AI decision-making business logic extracted from ManagerAi model.
"""
from unittest.mock import Mock, MagicMock

import pytest
from sqlmodel import Session

from app.services.ai_service import AIService
from app.models.manager_ai import ManagerAi
from app.models.ai_responses import (
    JumpResponse,
    TagResponse,
    ThrowResponse,
    UncappedRunResponse,
    DefenseResponse,
    RunResponse,
)


@pytest.fixture
def mock_session():
    """Create mock database session."""
    return Mock(spec=Session)


@pytest.fixture
def ai_service(mock_session):
    """Create AIService instance with mocked session."""
    return AIService(mock_session)


@pytest.fixture
def balanced_ai():
    """Create balanced ManagerAi configuration (every tendency set to 5)."""
    return ManagerAi(
        name="Balanced",
        steal=5,
        running=5,
        hold=5,
        catcher_throw=5,
        uncapped_home=5,
        uncapped_third=5,
        uncapped_trail=5,
        bullpen_matchup=5,
        behind_aggression=5,
        ahead_aggression=5,
        decide_throw=5,
    )


@pytest.fixture
def aggressive_ai():
    """Create aggressive ManagerAi configuration."""
    return ManagerAi(
        name="Yolo",
        steal=10,
        running=10,
        hold=5,
        catcher_throw=10,
        uncapped_home=10,
        uncapped_third=10,
        uncapped_trail=10,
        bullpen_matchup=3,
        behind_aggression=10,
        ahead_aggression=10,
        decide_throw=10,
    )


@pytest.fixture
def conservative_ai():
    """Create conservative ManagerAi configuration."""
    return ManagerAi(
        name="Safe",
        steal=3,
        running=3,
        hold=8,
        catcher_throw=5,
        uncapped_home=5,
        uncapped_third=3,
        uncapped_trail=5,
        bullpen_matchup=8,
        behind_aggression=5,
        ahead_aggression=1,
        decide_throw=1,
    )


@pytest.fixture
def mock_game():
    """Create mock game object with the AI managing the home team."""
    game = Mock()
    game.id = 1
    game.ai_team = 'home'
    return game


@pytest.fixture
def mock_play():
    """Create mock play object for a tied, 0-out, 5th-inning situation.

    All three runner slots are populated with mocks; individual tests
    override ``on_base_code``, the runners, or the score/out state as needed.
    """
    play = Mock()
    play.starting_outs = 0
    play.outs = 0
    play.away_score = 3
    play.home_score = 3
    play.inning_num = 5
    play.on_base_code = 1
    play.ai_run_diff = 0
    play.could_walkoff = False
    play.is_new_inning = False

    # Mock runners
    play.on_first = Mock()
    play.on_first.player.name = "Runner One"
    play.on_first.card.batterscouting.battingcard.steal_auto = False
    play.on_first.card.batterscouting.battingcard.steal_high = 15
    play.on_first.card.batterscouting.battingcard.steal_low = 12

    play.on_second = Mock()
    play.on_second.player.name = "Runner Two"
    play.on_second.card.batterscouting.battingcard.steal_auto = False
    play.on_second.card.batterscouting.battingcard.steal_low = 10

    play.on_third = Mock()
    play.on_third.player.name = "Runner Three"
    play.on_third.card.batterscouting.battingcard.steal_low = 8

    # Mock pitcher and catcher
    play.pitcher.card.pitcherscouting.pitchingcard.hold = 3
    play.catcher.player_id = 100
    play.catcher.card.variant = 0

    return play


class TestAIServiceInitialization:
    """Test AIService initialization and basic functionality."""

    def test_initialization(self, mock_session):
        """Test AIService initializes correctly."""
        service = AIService(mock_session)
        assert service.session == mock_session
        assert service.logger is not None

    def test_inherits_from_base_service(self, ai_service):
        """Test AIService inherits BaseService functionality."""
        assert hasattr(ai_service, '_log_operation')
        assert hasattr(ai_service, '_log_error')
        assert hasattr(ai_service, '_validate_required_fields')


class TestCheckStealOpportunity:
    """Test check_steal_opportunity method."""

    def test_steal_to_second_aggressive(self, ai_service, aggressive_ai, mock_game, mock_play):
        """Test steal decision to second base with aggressive AI."""
        mock_game.current_play_or_none.return_value = mock_play
        mock_catcher_defense = Mock()
        mock_catcher_defense.arm = 5
        ai_service.session.exec.return_value.one.return_value = mock_catcher_defense

        result = ai_service.check_steal_opportunity(aggressive_ai, mock_game, 2)

        assert isinstance(result, JumpResponse)
        assert result.min_safe == 12  # 12 + 0 outs for steal=10
        assert result.run_if_auto_jump is True  # steal > 7

    def test_steal_to_second_conservative(self, ai_service, conservative_ai, mock_game, mock_play):
        """Test steal decision to second base with conservative AI."""
        mock_game.current_play_or_none.return_value = mock_play
        mock_catcher_defense = Mock()
        mock_catcher_defense.arm = 5
        ai_service.session.exec.return_value.one.return_value = mock_catcher_defense

        result = ai_service.check_steal_opportunity(conservative_ai, mock_game, 2)

        assert isinstance(result, JumpResponse)
        assert result.min_safe == 16  # 16 + 0 outs for steal=3
        assert result.must_auto_jump is True  # steal < 5

    def test_steal_to_third(self, ai_service, aggressive_ai, mock_game, mock_play):
        """Test steal decision to third base."""
        mock_game.current_play_or_none.return_value = mock_play
        mock_catcher_defense = Mock()
        mock_catcher_defense.arm = 5
        ai_service.session.exec.return_value.one.return_value = mock_catcher_defense

        result = ai_service.check_steal_opportunity(aggressive_ai, mock_game, 3)

        assert isinstance(result, JumpResponse)
        assert result.min_safe == 12  # 12 + 0 outs for steal=10
        assert result.run_if_auto_jump is True

    def test_no_current_play_raises_error(self, ai_service, balanced_ai, mock_game):
        """Test that missing current play raises ValueError."""
        mock_game.current_play_or_none.return_value = None

        with pytest.raises(ValueError, match="No game found while checking for steal"):
            ai_service.check_steal_opportunity(balanced_ai, mock_game, 2)

    def test_no_runner_on_first_raises_error(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test that missing runner on first raises ValueError."""
        mock_play.on_first = None
        mock_game.current_play_or_none.return_value = mock_play
        mock_catcher_defense = Mock()
        mock_catcher_defense.arm = 5
        ai_service.session.exec.return_value.one.return_value = mock_catcher_defense

        with pytest.raises(ValueError, match="no runner found on first"):
            ai_service.check_steal_opportunity(balanced_ai, mock_game, 2)


class TestTagDecisions:
    """Test tag-up decision methods."""

    def test_tag_from_second_aggressive(self, ai_service, aggressive_ai, mock_game, mock_play):
        """Test tag from second with aggressive AI."""
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.check_tag_from_second(aggressive_ai, mock_game)

        assert isinstance(result, TagResponse)
        # aggressive_ai.running=10 + aggression_mod=5 = 15 >= 8, so min_safe=4
        # starting_outs=0 != 1, so +2, final=6
        assert result.min_safe == 6

    def test_tag_from_second_conservative(self, ai_service, conservative_ai, mock_game, mock_play):
        """Test tag from second with conservative AI."""
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.check_tag_from_second(conservative_ai, mock_game)

        assert isinstance(result, TagResponse)
        # conservative_ai.running=3 + aggression_mod=4 = 7 < 8, so min_safe=10
        # starting_outs=0 != 1, so +2, final=12
        assert result.min_safe == 12

    def test_tag_from_third_one_out(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test tag from third with one out."""
        mock_play.starting_outs = 1
        mock_play.ai_run_diff = 2  # Not in [-1, 0] range to avoid extra -2
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.check_tag_from_third(balanced_ai, mock_game)

        assert isinstance(result, TagResponse)
        # balanced_ai.running=5 + aggression_mod=0 = 5 < 8, so min_safe=10
        # starting_outs=1, so -2, final=8
        assert result.min_safe == 8


class TestThrowDecisions:
    """Test throw target decision methods."""

    def test_throw_decision_big_lead(self, ai_service, aggressive_ai, mock_game, mock_play):
        """Test throw decision when AI has big lead."""
        mock_play.ai_run_diff = 6  # Big lead
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.decide_throw_target(aggressive_ai, mock_game)

        assert isinstance(result, ThrowResponse)
        assert result.at_trail_runner is True
        assert result.trail_max_safe_delta == -4  # -4 + 0 current_outs

    def test_throw_decision_close_game(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test throw decision in close game."""
        mock_play.ai_run_diff = 0  # Tied game
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.decide_throw_target(balanced_ai, mock_game)

        assert isinstance(result, ThrowResponse)
        # Default values for close game with balanced AI
        assert result.at_trail_runner is False
        assert result.cutoff is False


class TestRunnerAdvanceDecisions:
    """Test runner advance decision methods."""

    def test_uncapped_advance_to_home(self, ai_service, aggressive_ai, mock_game, mock_play):
        """Test uncapped advance decision for runner going home."""
        mock_play.ai_run_diff = 2
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.decide_runner_advance(aggressive_ai, mock_game, 4, 3)

        assert isinstance(result, UncappedRunResponse)
        # ai_rd=2, lead_base=4: min_safe = 12 - 0 - 5 = 7
        assert result.min_safe == 7
        assert result.send_trail is True

    def test_uncapped_advance_bounds_checking(self, ai_service, aggressive_ai, mock_game, mock_play):
        """Test that advance decisions respect bounds."""
        mock_play.ai_run_diff = -10  # Way behind
        mock_play.starting_outs = 2
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.decide_runner_advance(aggressive_ai, mock_game, 4, 3)

        assert isinstance(result, UncappedRunResponse)
        # Should be bounded between 1 and 20
        assert 1 <= result.min_safe <= 20
        assert 1 <= result.trail_min_safe <= 20


class TestDefensiveAlignment:
    """Test defensive alignment decisions."""

    def test_defense_with_runner_on_third_walkoff(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test defensive alignment with walkoff situation."""
        mock_play.on_third = Mock()
        mock_play.on_third.player.name = "Walkoff Runner"
        mock_play.could_walkoff = True
        mock_play.starting_outs = 1
        # NOTE(review): on_base_code stays 1 (from the fixture) here; confirm
        # set_defensive_alignment keys off on_third rather than on_base_code.
        mock_game.current_play_or_none.return_value = mock_play
        mock_catcher_defense = Mock()
        mock_catcher_defense.arm = 5
        ai_service.session.exec.return_value.one.return_value = mock_catcher_defense

        result = ai_service.set_defensive_alignment(balanced_ai, mock_game)

        assert isinstance(result, DefenseResponse)
        assert result.outfield_in is True
        assert result.infield_in is True
        assert "play the outfield and infield in" in result.ai_note

    def test_defense_two_outs_hold_runners(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test defensive holds with two outs."""
        mock_play.starting_outs = 2
        mock_play.on_base_code = 1  # Runner on first
        mock_game.current_play_or_none.return_value = mock_play
        mock_catcher_defense = Mock()
        mock_catcher_defense.arm = 5
        ai_service.session.exec.return_value.one.return_value = mock_catcher_defense

        result = ai_service.set_defensive_alignment(balanced_ai, mock_game)

        assert isinstance(result, DefenseResponse)
        assert result.hold_first is True
        assert "hold Runner One on 1st" in result.ai_note


class TestGroundballDecisions:
    """Test groundball-specific decisions."""

    def test_groundball_running_decision(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test groundball running decision."""
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.decide_groundball_running(balanced_ai, mock_game)

        assert isinstance(result, RunResponse)
        # min_safe = 15 - aggression(0) = 15
        assert result.min_safe == 15

    def test_groundball_throw_decision(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test groundball throw decision."""
        mock_game.current_play_or_none.return_value = mock_play

        result = ai_service.decide_groundball_throw(balanced_ai, mock_game, 10, 3)

        assert isinstance(result, ThrowResponse)
        # (10 - 4 + 3) = 9 <= (10 + 0) = 10, so at_lead_runner=True
        assert result.at_lead_runner is True


class TestPitcherReplacement:
    """Test pitcher replacement decisions."""

    def test_should_replace_fatigued_starter(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test pitcher replacement for fatigued starter."""
        mock_play.pitcher.replacing_id = None  # This is a starter
        mock_play.pitcher.is_fatigued = True
        mock_play.on_base_code = 2  # Runners on base
        mock_play.pitcher.card.pitcherscouting.pitchingcard.starter_rating = 5
        mock_game.current_play_or_none.return_value = mock_play

        # Mock database queries: successive .one() calls return these values
        ai_service.session.exec.return_value.one.side_effect = [18, 6]  # 18 outs, 6 allowed runners

        result = ai_service.should_replace_pitcher(balanced_ai, mock_game)

        assert result is True  # Fatigued starter with runners should be replaced

    def test_should_keep_effective_starter(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test keeping effective starter."""
        mock_play.pitcher.replacing_id = None  # This is a starter
        mock_play.pitcher.is_fatigued = False
        mock_play.on_base_code = 0  # No runners
        mock_play.pitcher.card.pitcherscouting.pitchingcard.starter_rating = 6
        mock_game.current_play_or_none.return_value = mock_play

        # Mock database queries - effective pitcher
        ai_service.session.exec.return_value.one.side_effect = [15, 2]  # 15 outs, 2 allowed runners

        result = ai_service.should_replace_pitcher(balanced_ai, mock_game)

        assert result is False  # Effective starter should stay in

    def test_should_replace_overworked_reliever(self, ai_service, balanced_ai, mock_game, mock_play):
        """Test replacing overworked reliever."""
        mock_play.pitcher.replacing_id = 123  # This is a reliever
        mock_play.pitcher.card.pitcherscouting.pitchingcard.relief_rating = 3
        mock_game.current_play_or_none.return_value = mock_play

        # Mock database queries - overworked reliever
        ai_service.session.exec.return_value.one.side_effect = [12, 4]  # 12 outs (4 IP), 4 allowed runners

        result = ai_service.should_replace_pitcher(balanced_ai, mock_game)

        assert result is True  # Overworked reliever should be replaced


class TestErrorHandling:
    """Test error handling in AIService methods."""

    # Each entry is (AIService method name, positional args after
    # (manager_ai, game)). Parametrizing instead of looping means every method
    # is reported as its own test case, and one failure does not hide others.
    @pytest.mark.parametrize(
        ("method_name", "extra_args"),
        [
            ("check_steal_opportunity", (2,)),
            ("check_tag_from_second", ()),
            ("check_tag_from_third", ()),
            ("decide_throw_target", ()),
            ("decide_runner_advance", (4, 3)),
            ("set_defensive_alignment", ()),
            ("decide_groundball_running", ()),
            ("decide_groundball_throw", (10, 3)),
            ("should_replace_pitcher", ()),
        ],
    )
    def test_methods_handle_no_current_play(
        self, ai_service, balanced_ai, mock_game, method_name, extra_args
    ):
        """Test that each decision method raises ValueError when there is no current play."""
        mock_game.current_play_or_none.return_value = None
        method = getattr(ai_service, method_name)

        with pytest.raises(ValueError, match="No game found"):
            method(balanced_ai, mock_game, *extra_args)