diff --git a/bot.py b/bot.py index 3e2f862..3500209 100644 --- a/bot.py +++ b/bot.py @@ -64,6 +64,44 @@ def setup_logging(): return logger +class MaintenanceAwareTree(discord.app_commands.CommandTree): + """ + CommandTree subclass that gates all interactions behind a maintenance mode check. + + When bot.maintenance_mode is True, non-administrator users receive an ephemeral + error message and the interaction is blocked. Administrators are always allowed + through. When maintenance_mode is False the check is a no-op and every + interaction proceeds normally. + + This is the correct way to register a global interaction_check for app commands + in discord.py — overriding the method on a CommandTree subclass passed via + tree_cls rather than attempting to assign a decorator to self.tree inside + setup_hook. + """ + + async def interaction_check(self, interaction: discord.Interaction) -> bool: + """Allow admins through; block everyone else when maintenance mode is active.""" + bot = interaction.client # type: ignore[assignment] + + # If maintenance mode is off, always allow. + if not getattr(bot, "maintenance_mode", False): + return True + + # Maintenance mode is on — let administrators through unconditionally. + if ( + isinstance(interaction.user, discord.Member) + and interaction.user.guild_permissions.administrator + ): + return True + + # Block non-admin users with an ephemeral notice. + await interaction.response.send_message( + "The bot is currently in maintenance mode. Please try again later.", + ephemeral=True, + ) + return False + + class SBABot(commands.Bot): """Custom bot class for SBA league management.""" @@ -77,8 +115,10 @@ class SBABot(commands.Bot): command_prefix="!", # Legacy prefix, primarily using slash commands intents=intents, description="Major Domo v2.0", + tree_cls=MaintenanceAwareTree, ) + self.maintenance_mode: bool = False self.logger = logging.getLogger("discord_bot_v2") self.maintenance_mode: bool = False diff --git a/commands/admin/management.py b/commands/admin/management.py index 8950c4c..b738e2c 100644 --- a/commands/admin/management.py +++ b/commands/admin/management.py @@ -490,7 +490,7 @@ class AdminCommands(commands.Cog): await interaction.response.defer() is_enabling = mode.lower() == "on" - self.bot.maintenance_mode = is_enabling + self.bot.maintenance_mode = is_enabling # type: ignore[attr-defined] self.logger.info( f"Maintenance mode {'enabled' if is_enabling else 'disabled'} by {interaction.user} (id={interaction.user.id})" ) diff --git a/commands/league/schedule.py b/commands/league/schedule.py index 26dd224..7644d84 100644 --- a/commands/league/schedule.py +++ b/commands/league/schedule.py @@ -3,6 +3,7 @@ League Schedule Commands Implements slash commands for displaying game schedules and results. """ + from typing import Optional import asyncio @@ -19,19 +20,16 @@ from views.embeds import EmbedColors, EmbedTemplate class ScheduleCommands(commands.Cog): """League schedule command handlers.""" - + def __init__(self, bot: commands.Bot): self.bot = bot - self.logger = get_contextual_logger(f'{__name__}.ScheduleCommands') - - @discord.app_commands.command( - name="schedule", - description="Display game schedule" - ) + self.logger = get_contextual_logger(f"{__name__}.ScheduleCommands") + + @discord.app_commands.command(name="schedule", description="Display game schedule") @discord.app_commands.describe( season="Season to show schedule for (defaults to current season)", week="Week number to show (optional)", - team="Team abbreviation to filter by (optional)" + team="Team abbreviation to filter by (optional)", ) @requires_team() @logged_command("/schedule") @@ -40,13 +38,13 @@ class ScheduleCommands(commands.Cog): interaction: discord.Interaction, season: Optional[int] = None, week: Optional[int] = None, - team: Optional[str] = None + team: Optional[str] = None, ): """Display game schedule for a week or team.""" await interaction.response.defer() - + search_season = season or get_config().sba_season - + if team: # Show team schedule await self._show_team_schedule(interaction, search_season, team, week) @@ -56,7 +54,7 @@ class ScheduleCommands(commands.Cog): else: # Show recent/upcoming games await self._show_current_schedule(interaction, search_season) - + # @discord.app_commands.command( # name="results", # description="Display recent game results" @@ -74,282 +72,316 @@ class ScheduleCommands(commands.Cog): # ): # """Display recent game results.""" # await interaction.response.defer() - + # search_season = season or get_config().sba_season - + # if week: # # Show specific week results # games = await schedule_service.get_week_schedule(search_season, week) # completed_games = [game for game in games if game.is_completed] - + # if not completed_games: # await interaction.followup.send( # f"❌ No completed games found for season {search_season}, week {week}.", # ephemeral=True # ) # return - + # embed = await self._create_week_results_embed(completed_games, search_season, week) # await interaction.followup.send(embed=embed) # else: # # Show recent results # recent_games = await schedule_service.get_recent_games(search_season) - + # if not recent_games: # await interaction.followup.send( # f"❌ No recent games found for season {search_season}.", # ephemeral=True # ) # return - + # embed = await self._create_recent_results_embed(recent_games, search_season) # await interaction.followup.send(embed=embed) - - async def _show_week_schedule(self, interaction: discord.Interaction, season: int, week: int): + + async def _show_week_schedule( + self, interaction: discord.Interaction, season: int, week: int + ): """Show schedule for a specific week.""" self.logger.debug("Fetching week schedule", season=season, week=week) - + games = await schedule_service.get_week_schedule(season, week) - + if not games: await interaction.followup.send( - f"❌ No games found for season {season}, week {week}.", - ephemeral=True + f"❌ No games found for season {season}, week {week}.", ephemeral=True ) return - + embed = await self._create_week_schedule_embed(games, season, week) await interaction.followup.send(embed=embed) - - async def _show_team_schedule(self, interaction: discord.Interaction, season: int, team: str, week: Optional[int]): + + async def _show_team_schedule( + self, + interaction: discord.Interaction, + season: int, + team: str, + week: Optional[int], + ): """Show schedule for a specific team.""" self.logger.debug("Fetching team schedule", season=season, team=team, week=week) - + if week: # Show team games for specific week week_games = await schedule_service.get_week_schedule(season, week) team_games = [ - game for game in week_games - if game.away_team.abbrev.upper() == team.upper() or game.home_team.abbrev.upper() == team.upper() + game + for game in week_games + if game.away_team.abbrev.upper() == team.upper() + or game.home_team.abbrev.upper() == team.upper() ] else: # Show team's recent/upcoming games (limited weeks) team_games = await schedule_service.get_team_schedule(season, team, weeks=4) - + if not team_games: week_text = f" for week {week}" if week else "" await interaction.followup.send( f"❌ No games found for team '{team}'{week_text} in season {season}.", - ephemeral=True + ephemeral=True, ) return - + embed = await self._create_team_schedule_embed(team_games, season, team, week) await interaction.followup.send(embed=embed) - - async def _show_current_schedule(self, interaction: discord.Interaction, season: int): + + async def _show_current_schedule( + self, interaction: discord.Interaction, season: int + ): """Show current schedule overview with recent and upcoming games.""" self.logger.debug("Fetching current schedule overview", season=season) - + # Get both recent and upcoming games recent_games, upcoming_games = await asyncio.gather( schedule_service.get_recent_games(season, weeks_back=1), - schedule_service.get_upcoming_games(season, weeks_ahead=1) + schedule_service.get_upcoming_games(season), ) - + if not recent_games and not upcoming_games: await interaction.followup.send( f"❌ No recent or upcoming games found for season {season}.", - ephemeral=True + ephemeral=True, ) return - - embed = await self._create_current_schedule_embed(recent_games, upcoming_games, season) + + embed = await self._create_current_schedule_embed( + recent_games, upcoming_games, season + ) await interaction.followup.send(embed=embed) - - async def _create_week_schedule_embed(self, games, season: int, week: int) -> discord.Embed: + + async def _create_week_schedule_embed( + self, games, season: int, week: int + ) -> discord.Embed: """Create an embed for a week's schedule.""" embed = EmbedTemplate.create_base_embed( title=f"📅 Week {week} Schedule - Season {season}", - color=EmbedColors.PRIMARY + color=EmbedColors.PRIMARY, ) - + # Group games by series series_games = schedule_service.group_games_by_series(games) - + schedule_lines = [] for (team1, team2), series in series_games.items(): series_summary = await self._format_series_summary(series) schedule_lines.append(f"**{team1} vs {team2}**\n{series_summary}") - + if schedule_lines: embed.add_field( - name="Games", - value="\n\n".join(schedule_lines), - inline=False + name="Games", value="\n\n".join(schedule_lines), inline=False ) - + # Add week summary completed = len([g for g in games if g.is_completed]) total = len(games) embed.add_field( name="Week Progress", value=f"{completed}/{total} games completed", - inline=True + inline=True, ) - + embed.set_footer(text=f"Season {season} • Week {week}") return embed - - async def _create_team_schedule_embed(self, games, season: int, team: str, week: Optional[int]) -> discord.Embed: + + async def _create_team_schedule_embed( + self, games, season: int, team: str, week: Optional[int] + ) -> discord.Embed: """Create an embed for a team's schedule.""" week_text = f" - Week {week}" if week else "" embed = EmbedTemplate.create_base_embed( title=f"📅 {team.upper()} Schedule{week_text} - Season {season}", - color=EmbedColors.PRIMARY + color=EmbedColors.PRIMARY, ) - + # Separate completed and upcoming games completed_games = [g for g in games if g.is_completed] upcoming_games = [g for g in games if not g.is_completed] - + if completed_games: recent_lines = [] for game in completed_games[-5:]: # Last 5 games - result = "W" if game.winner and game.winner.abbrev.upper() == team.upper() else "L" + result = ( + "W" + if game.winner and game.winner.abbrev.upper() == team.upper() + else "L" + ) if game.home_team.abbrev.upper() == team.upper(): # Team was home - recent_lines.append(f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})") + recent_lines.append( + f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})" + ) else: - # Team was away - recent_lines.append(f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})") - + # Team was away + recent_lines.append( + f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})" + ) + embed.add_field( name="Recent Results", value="\n".join(recent_lines) if recent_lines else "No recent games", - inline=False + inline=False, ) - + if upcoming_games: upcoming_lines = [] for game in upcoming_games[:5]: # Next 5 games if game.home_team.abbrev.upper() == team.upper(): # Team is home - upcoming_lines.append(f"Week {game.week}: vs {game.away_team.abbrev}") + upcoming_lines.append( + f"Week {game.week}: vs {game.away_team.abbrev}" + ) else: # Team is away - upcoming_lines.append(f"Week {game.week}: @ {game.home_team.abbrev}") - + upcoming_lines.append( + f"Week {game.week}: @ {game.home_team.abbrev}" + ) + embed.add_field( name="Upcoming Games", - value="\n".join(upcoming_lines) if upcoming_lines else "No upcoming games", - inline=False + value=( + "\n".join(upcoming_lines) if upcoming_lines else "No upcoming games" + ), + inline=False, ) - + embed.set_footer(text=f"Season {season} • {team.upper()}") return embed - - async def _create_week_results_embed(self, games, season: int, week: int) -> discord.Embed: + + async def _create_week_results_embed( + self, games, season: int, week: int + ) -> discord.Embed: """Create an embed for week results.""" embed = EmbedTemplate.create_base_embed( - title=f"🏆 Week {week} Results - Season {season}", - color=EmbedColors.SUCCESS + title=f"🏆 Week {week} Results - Season {season}", color=EmbedColors.SUCCESS ) - + # Group by series and show results series_games = schedule_service.group_games_by_series(games) - + results_lines = [] for (team1, team2), series in series_games.items(): # Count wins for each team - team1_wins = len([g for g in series if g.winner and g.winner.abbrev == team1]) - team2_wins = len([g for g in series if g.winner and g.winner.abbrev == team2]) - + team1_wins = len( + [g for g in series if g.winner and g.winner.abbrev == team1] + ) + team2_wins = len( + [g for g in series if g.winner and g.winner.abbrev == team2] + ) + # Series result series_result = f"**{team1} {team1_wins}-{team2_wins} {team2}**" - + # Individual games game_details = [] for game in series: if game.series_game_display: - game_details.append(f"{game.series_game_display}: {game.matchup_display}") - + game_details.append( + f"{game.series_game_display}: {game.matchup_display}" + ) + results_lines.append(f"{series_result}\n" + "\n".join(game_details)) - + if results_lines: embed.add_field( - name="Series Results", - value="\n\n".join(results_lines), - inline=False + name="Series Results", value="\n\n".join(results_lines), inline=False ) - - embed.set_footer(text=f"Season {season} • Week {week} • {len(games)} games completed") + + embed.set_footer( + text=f"Season {season} • Week {week} • {len(games)} games completed" + ) return embed - + async def _create_recent_results_embed(self, games, season: int) -> discord.Embed: """Create an embed for recent results.""" embed = EmbedTemplate.create_base_embed( - title=f"🏆 Recent Results - Season {season}", - color=EmbedColors.SUCCESS + title=f"🏆 Recent Results - Season {season}", color=EmbedColors.SUCCESS ) - + # Show most recent games recent_lines = [] for game in games[:10]: # Show last 10 games recent_lines.append(f"Week {game.week}: {game.matchup_display}") - + if recent_lines: embed.add_field( - name="Latest Games", - value="\n".join(recent_lines), - inline=False + name="Latest Games", value="\n".join(recent_lines), inline=False ) - + embed.set_footer(text=f"Season {season} • Last {len(games)} completed games") return embed - - async def _create_current_schedule_embed(self, recent_games, upcoming_games, season: int) -> discord.Embed: + + async def _create_current_schedule_embed( + self, recent_games, upcoming_games, season: int + ) -> discord.Embed: """Create an embed for current schedule overview.""" embed = EmbedTemplate.create_base_embed( - title=f"📅 Current Schedule - Season {season}", - color=EmbedColors.INFO + title=f"📅 Current Schedule - Season {season}", color=EmbedColors.INFO ) - + if recent_games: recent_lines = [] for game in recent_games[:5]: recent_lines.append(f"Week {game.week}: {game.matchup_display}") - + embed.add_field( - name="Recent Results", - value="\n".join(recent_lines), - inline=False + name="Recent Results", value="\n".join(recent_lines), inline=False ) - + if upcoming_games: upcoming_lines = [] for game in upcoming_games[:5]: upcoming_lines.append(f"Week {game.week}: {game.matchup_display}") - + embed.add_field( - name="Upcoming Games", - value="\n".join(upcoming_lines), - inline=False + name="Upcoming Games", value="\n".join(upcoming_lines), inline=False ) - + embed.set_footer(text=f"Season {season}") return embed - + async def _format_series_summary(self, series) -> str: """Format a series summary.""" lines = [] for game in series: - game_display = f"{game.series_game_display}: {game.matchup_display}" if game.series_game_display else game.matchup_display + game_display = ( + f"{game.series_game_display}: {game.matchup_display}" + if game.series_game_display + else game.matchup_display + ) lines.append(game_display) - + return "\n".join(lines) if lines else "No games" async def setup(bot: commands.Bot): """Load the schedule commands cog.""" - await bot.add_cog(ScheduleCommands(bot)) \ No newline at end of file + await bot.add_cog(ScheduleCommands(bot)) diff --git a/services/schedule_service.py b/services/schedule_service.py index 78ee51d..cb3c101 100644 --- a/services/schedule_service.py +++ b/services/schedule_service.py @@ -4,6 +4,7 @@ Schedule service for Discord Bot v2.0 Handles game schedule and results retrieval and processing. """ +import asyncio import logging from typing import Optional, List, Dict, Tuple @@ -102,10 +103,10 @@ class ScheduleService: # If weeks not specified, try a reasonable range (18 weeks typical) week_range = range(1, (weeks + 1) if weeks else 19) - for week in week_range: - week_games = await self.get_week_schedule(season, week) - - # Filter games involving this team + all_week_games = await asyncio.gather( + *[self.get_week_schedule(season, week) for week in week_range] + ) + for week_games in all_week_games: for game in week_games: if ( game.away_team.abbrev.upper() == team_abbrev_upper @@ -135,15 +136,13 @@ class ScheduleService: recent_games = [] # Get games from recent weeks - for week_offset in range(weeks_back): - # This is simplified - in production you'd want to determine current week - week = 10 - week_offset # Assuming we're around week 10 - if week <= 0: - break - - week_games = await self.get_week_schedule(season, week) - - # Only include completed games + weeks_to_fetch = [ + (10 - offset) for offset in range(weeks_back) if (10 - offset) > 0 + ] + all_week_games = await asyncio.gather( + *[self.get_week_schedule(season, week) for week in weeks_to_fetch] + ) + for week_games in all_week_games: completed_games = [game for game in week_games if game.is_completed] recent_games.extend(completed_games) @@ -157,13 +156,12 @@ class ScheduleService: logger.error(f"Error getting recent games: {e}") return [] - async def get_upcoming_games(self, season: int, weeks_ahead: int = 6) -> List[Game]: + async def get_upcoming_games(self, season: int) -> List[Game]: """ - Get upcoming scheduled games by scanning multiple weeks. + Get upcoming scheduled games by scanning all weeks. Args: season: Season number - weeks_ahead: Number of weeks to scan ahead (default 6) Returns: List of upcoming Game instances @@ -171,20 +169,16 @@ class ScheduleService: try: upcoming_games = [] - # Scan through weeks to find games without scores - for week in range(1, 19): # Standard season length - week_games = await self.get_week_schedule(season, week) - - # Find games without scores (not yet played) + # Fetch all weeks in parallel and filter for incomplete games + all_week_games = await asyncio.gather( + *[self.get_week_schedule(season, week) for week in range(1, 19)] + ) + for week_games in all_week_games: upcoming_games_week = [ game for game in week_games if not game.is_completed ] upcoming_games.extend(upcoming_games_week) - # If we found upcoming games, we can limit how many more weeks to check - if upcoming_games and len(upcoming_games) >= 20: # Reasonable limit - break - # Sort by week, then game number upcoming_games.sort(key=lambda x: (x.week, x.game_num or 0)) diff --git a/tests/test_bot_maintenance_tree.py b/tests/test_bot_maintenance_tree.py new file mode 100644 index 0000000..c6530f9 --- /dev/null +++ b/tests/test_bot_maintenance_tree.py @@ -0,0 +1,282 @@ +""" +Tests for MaintenanceAwareTree and the maintenance_mode attribute on SBABot. + +What: + Verifies that the CommandTree subclass correctly gates interactions behind + bot.maintenance_mode. When maintenance mode is off every interaction is + allowed through unconditionally. When maintenance mode is on, non-admin + users receive an ephemeral error message and the check returns False, while + administrators are always allowed through. + +Why: + The original code attempted to register an interaction_check via a decorator + on self.tree inside setup_hook. That is not a valid pattern in discord.py — + interaction_check is an overridable async method on CommandTree, not a + decorator. The broken assignment caused a RuntimeWarning (unawaited + coroutine) and silently made maintenance mode a no-op. These tests confirm + the correct subclass-based implementation behaves as specified. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +import discord + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _make_bot(maintenance_mode: bool = False) -> MagicMock: + """Return a minimal mock bot with a maintenance_mode attribute.""" + bot = MagicMock() + bot.maintenance_mode = maintenance_mode + return bot + + +def _make_interaction(is_admin: bool, bot: MagicMock) -> AsyncMock: + """ + Build a mock discord.Interaction. + + The interaction's .client is set to *bot* so that MaintenanceAwareTree + can read bot.maintenance_mode via interaction.client, mirroring how + discord.py wires things at runtime. + """ + interaction = AsyncMock(spec=discord.Interaction) + interaction.client = bot + + # Mock the user as a guild Member so that guild_permissions is accessible. + user = MagicMock(spec=discord.Member) + user.guild_permissions = MagicMock() + user.guild_permissions.administrator = is_admin + interaction.user = user + + # response.send_message must be awaitable. + interaction.response = AsyncMock() + interaction.response.send_message = AsyncMock() + + return interaction + + +# --------------------------------------------------------------------------- +# Import the class under test after mocks are available. +# We import here (not at module level) so that the conftest env-vars are set +# before any discord_bot_v2 modules are touched. +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _patch_discord_app_commands(monkeypatch): + """ + Prevent MaintenanceAwareTree.__init__ from calling discord internals that + need a real event loop / Discord connection. We test only the logic of + interaction_check, so we stub out the parent __init__. + """ + # Nothing extra to patch for the interaction_check itself; the parent + # CommandTree.__init__ is only called when constructing SBABot, which we + # don't do in these unit tests. + yield + + +# --------------------------------------------------------------------------- +# Tests for MaintenanceAwareTree.interaction_check +# --------------------------------------------------------------------------- + + +class TestMaintenanceAwareTree: + """Unit tests for MaintenanceAwareTree.interaction_check.""" + + @pytest.fixture + def tree_cls(self): + """Import and return the MaintenanceAwareTree class.""" + from bot import MaintenanceAwareTree + + return MaintenanceAwareTree + + # ------------------------------------------------------------------ + # Maintenance OFF + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_maintenance_off_allows_non_admin(self, tree_cls): + """ + When maintenance_mode is False, non-admin users are always allowed. + The check must return True without sending any message. + """ + bot = _make_bot(maintenance_mode=False) + interaction = _make_interaction(is_admin=False, bot=bot) + + # Instantiate tree without calling parent __init__ by testing the method + # directly on an unbound basis. + result = await tree_cls.interaction_check( + MagicMock(), # placeholder 'self' for the tree instance + interaction, + ) + + assert result is True + interaction.response.send_message.assert_not_called() + + @pytest.mark.asyncio + async def test_maintenance_off_allows_admin(self, tree_cls): + """ + When maintenance_mode is False, admin users are also always allowed. + """ + bot = _make_bot(maintenance_mode=False) + interaction = _make_interaction(is_admin=True, bot=bot) + + result = await tree_cls.interaction_check(MagicMock(), interaction) + + assert result is True + interaction.response.send_message.assert_not_called() + + # ------------------------------------------------------------------ + # Maintenance ON — non-admin + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_maintenance_on_blocks_non_admin(self, tree_cls): + """ + When maintenance_mode is True, non-admin users must be blocked. + The check must return False and send an ephemeral message. + """ + bot = _make_bot(maintenance_mode=True) + interaction = _make_interaction(is_admin=False, bot=bot) + + result = await tree_cls.interaction_check(MagicMock(), interaction) + + assert result is False + interaction.response.send_message.assert_called_once() + + # Confirm the call used ephemeral=True + _, kwargs = interaction.response.send_message.call_args + assert kwargs.get("ephemeral") is True + + @pytest.mark.asyncio + async def test_maintenance_on_message_has_no_emoji(self, tree_cls): + """ + The maintenance block message must not contain emoji characters. + The project style deliberately strips emoji from user-facing strings. + """ + import unicodedata + + bot = _make_bot(maintenance_mode=True) + interaction = _make_interaction(is_admin=False, bot=bot) + + await tree_cls.interaction_check(MagicMock(), interaction) + + args, _ = interaction.response.send_message.call_args + message_text = args[0] if args else "" + + for ch in message_text: + category = unicodedata.category(ch) + assert category != "So", ( + f"Unexpected emoji/symbol character {ch!r} (category {category!r}) " + f"found in maintenance message: {message_text!r}" + ) + + # ------------------------------------------------------------------ + # Maintenance ON — admin + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_maintenance_on_allows_admin(self, tree_cls): + """ + When maintenance_mode is True, administrator users must still be + allowed through. Admins should never be locked out of bot commands. + """ + bot = _make_bot(maintenance_mode=True) + interaction = _make_interaction(is_admin=True, bot=bot) + + result = await tree_cls.interaction_check(MagicMock(), interaction) + + assert result is True + interaction.response.send_message.assert_not_called() + + # ------------------------------------------------------------------ + # Edge case: non-Member user during maintenance + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_maintenance_on_blocks_non_member_user(self, tree_cls): + """ + When maintenance_mode is True and the user is not a guild Member + (e.g. interaction from a DM context), the check must still block them + because we cannot verify administrator status. + """ + bot = _make_bot(maintenance_mode=True) + interaction = AsyncMock(spec=discord.Interaction) + interaction.client = bot + + # Simulate a non-Member user (e.g. discord.User from DM context) + user = MagicMock(spec=discord.User) + # discord.User has no guild_permissions attribute + interaction.user = user + interaction.response = AsyncMock() + interaction.response.send_message = AsyncMock() + + result = await tree_cls.interaction_check(MagicMock(), interaction) + + assert result is False + interaction.response.send_message.assert_called_once() + + # ------------------------------------------------------------------ + # Missing attribute safety: bot without maintenance_mode attr + # ------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_missing_maintenance_mode_attr_defaults_to_allowed(self, tree_cls): + """ + If the bot object does not have a maintenance_mode attribute (e.g. + during testing or unusual startup), getattr fallback must treat it as + False and allow the interaction. + """ + bot = MagicMock() + # Deliberately do NOT set bot.maintenance_mode + del bot.maintenance_mode + + interaction = _make_interaction(is_admin=False, bot=bot) + + result = await tree_cls.interaction_check(MagicMock(), interaction) + + assert result is True + + +# --------------------------------------------------------------------------- +# Tests for SBABot.maintenance_mode attribute +# --------------------------------------------------------------------------- + + +class TestSBABotMaintenanceModeAttribute: + """ + Confirms that SBABot.__init__ always sets maintenance_mode = False. + + We avoid constructing a real SBABot (which requires a Discord event loop + and valid token infrastructure) by patching the entire commands.Bot.__init__ + and then calling SBABot.__init__ directly on a bare instance so that only + the SBABot-specific attribute assignments execute. + """ + + def test_maintenance_mode_default_is_false(self, monkeypatch): + """ + SBABot.__init__ must set self.maintenance_mode = False so that the + MaintenanceAwareTree has the attribute available from the very first + interaction, even before /admin-maintenance is ever called. + + Strategy: patch commands.Bot.__init__ to be a no-op so super().__init__ + succeeds without a real Discord connection, then call SBABot.__init__ + and assert the attribute is present with the correct default value. + """ + from unittest.mock import patch + from discord.ext import commands + from bot import SBABot + + with patch.object(commands.Bot, "__init__", return_value=None): + bot = SBABot.__new__(SBABot) + SBABot.__init__(bot) + + assert hasattr( + bot, "maintenance_mode" + ), "SBABot must define self.maintenance_mode in __init__" + assert ( + bot.maintenance_mode is False + ), "SBABot.maintenance_mode must default to False" diff --git a/tests/test_services_schedule.py b/tests/test_services_schedule.py new file mode 100644 index 0000000..70e60c7 --- /dev/null +++ b/tests/test_services_schedule.py @@ -0,0 +1,284 @@ +""" +Tests for schedule service functionality. + +Covers get_week_schedule, get_team_schedule, get_recent_games, +get_upcoming_games, and group_games_by_series — verifying the +asyncio.gather parallelization and post-fetch filtering logic. +""" + +import pytest +from unittest.mock import AsyncMock, patch + +from services.schedule_service import ScheduleService +from tests.factories import GameFactory, TeamFactory + + +def _game(game_id, week, away_abbrev, home_abbrev, **kwargs): + """Create a Game with distinct team IDs per matchup.""" + return GameFactory.create( + id=game_id, + week=week, + away_team=TeamFactory.create(id=game_id * 10, abbrev=away_abbrev), + home_team=TeamFactory.create(id=game_id * 10 + 1, abbrev=home_abbrev), + **kwargs, + ) + + +class TestGetWeekSchedule: + """Tests for ScheduleService.get_week_schedule — the HTTP layer.""" + + @pytest.fixture + def service(self): + svc = ScheduleService() + svc.get_client = AsyncMock() + return svc + + @pytest.mark.asyncio + async def test_success(self, service): + """get_week_schedule returns parsed Game objects on a normal response.""" + mock_client = AsyncMock() + mock_client.get.return_value = { + "games": [ + { + "id": 1, + "season": 12, + "week": 5, + "game_num": 1, + "season_type": "regular", + "away_team": { + "id": 10, + "abbrev": "NYY", + "sname": "NYY", + "lname": "New York", + "season": 12, + }, + "home_team": { + "id": 11, + "abbrev": "BOS", + "sname": "BOS", + "lname": "Boston", + "season": 12, + }, + "away_score": 4, + "home_score": 2, + } + ] + } + service.get_client.return_value = mock_client + + games = await service.get_week_schedule(12, 5) + + assert len(games) == 1 + assert games[0].away_team.abbrev == "NYY" + assert games[0].home_team.abbrev == "BOS" + assert games[0].is_completed + + @pytest.mark.asyncio + async def test_empty_response(self, service): + """get_week_schedule returns [] when the API has no games.""" + mock_client = AsyncMock() + mock_client.get.return_value = {"games": []} + service.get_client.return_value = mock_client + + games = await service.get_week_schedule(12, 99) + assert games == [] + + @pytest.mark.asyncio + async def test_api_error_returns_empty(self, service): + """get_week_schedule returns [] on API error (no exception raised).""" + service.get_client.side_effect = Exception("connection refused") + + games = await service.get_week_schedule(12, 1) + assert games == [] + + @pytest.mark.asyncio + async def test_missing_games_key(self, service): + """get_week_schedule returns [] when response lacks 'games' key.""" + mock_client = AsyncMock() + mock_client.get.return_value = {"status": "ok"} + service.get_client.return_value = mock_client + + games = await service.get_week_schedule(12, 1) + assert games == [] + + +class TestGetTeamSchedule: + """Tests for get_team_schedule — gather + team-abbrev filter.""" + + @pytest.fixture + def service(self): + return ScheduleService() + + @pytest.mark.asyncio + async def test_filters_by_team_case_insensitive(self, service): + """get_team_schedule returns only games involving the requested team, + regardless of abbreviation casing.""" + week1 = [ + _game(1, 1, "NYY", "BOS", away_score=3, home_score=1), + _game(2, 1, "LAD", "CHC", away_score=5, home_score=2), + ] + week2 = [ + _game(3, 2, "BOS", "NYY", away_score=2, home_score=4), + ] + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.side_effect = [week1, week2] + result = await service.get_team_schedule(12, "nyy", weeks=2) + + assert len(result) == 2 + assert all( + g.away_team.abbrev == "NYY" or g.home_team.abbrev == "NYY" for g in result + ) + + @pytest.mark.asyncio + async def test_full_season_fetches_18_weeks(self, service): + """When weeks is None, all 18 weeks are fetched via gather.""" + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [] + await service.get_team_schedule(12, "NYY") + + assert mock.call_count == 18 + + @pytest.mark.asyncio + async def test_limited_weeks(self, service): + """When weeks=5, only 5 weeks are fetched.""" + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [] + await service.get_team_schedule(12, "NYY", weeks=5) + + assert mock.call_count == 5 + + +class TestGetRecentGames: + """Tests for get_recent_games — gather + completed-only filter.""" + + @pytest.fixture + def service(self): + return ScheduleService() + + @pytest.mark.asyncio + async def test_returns_only_completed_games(self, service): + """get_recent_games filters out games without scores.""" + completed = GameFactory.completed(id=1, week=10) + incomplete = GameFactory.upcoming(id=2, week=10) + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [completed, incomplete] + result = await service.get_recent_games(12, weeks_back=1) + + assert len(result) == 1 + assert result[0].is_completed + + @pytest.mark.asyncio + async def test_sorted_descending_by_week_and_game_num(self, service): + """Recent games are sorted most-recent first.""" + game_w10 = GameFactory.completed(id=1, week=10, game_num=2) + game_w9 = GameFactory.completed(id=2, week=9, game_num=1) + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.side_effect = [[game_w10], [game_w9]] + result = await service.get_recent_games(12, weeks_back=2) + + assert result[0].week == 10 + assert result[1].week == 9 + + @pytest.mark.asyncio + async def test_skips_negative_weeks(self, service): + """Weeks that would be <= 0 are excluded from fetch.""" + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [] + await service.get_recent_games(12, weeks_back=15) + + # weeks_to_fetch = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] — only 10 valid weeks + assert mock.call_count == 10 + + +class TestGetUpcomingGames: + """Tests for get_upcoming_games — gather all 18 weeks + incomplete filter.""" + + @pytest.fixture + def service(self): + return ScheduleService() + + @pytest.mark.asyncio + async def test_returns_only_incomplete_games(self, service): + """get_upcoming_games filters out completed games.""" + completed = GameFactory.completed(id=1, week=5) + upcoming = GameFactory.upcoming(id=2, week=5) + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [completed, upcoming] + result = await service.get_upcoming_games(12) + + assert len(result) == 18 # 1 incomplete game per week × 18 weeks + assert all(not g.is_completed for g in result) + + @pytest.mark.asyncio + async def test_sorted_ascending_by_week_and_game_num(self, service): + """Upcoming games are sorted earliest first.""" + game_w3 = GameFactory.upcoming(id=1, week=3, game_num=1) + game_w1 = GameFactory.upcoming(id=2, week=1, game_num=2) + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + + def side_effect(season, week): + if week == 1: + return [game_w1] + if week == 3: + return [game_w3] + return [] + + mock.side_effect = side_effect + result = await service.get_upcoming_games(12) + + assert result[0].week == 1 + assert result[1].week == 3 + + @pytest.mark.asyncio + async def test_fetches_all_18_weeks(self, service): + """All 18 weeks are fetched in parallel (no early exit).""" + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [] + await service.get_upcoming_games(12) + + assert mock.call_count == 18 + + +class TestGroupGamesBySeries: + """Tests for group_games_by_series — synchronous grouping logic.""" + + @pytest.fixture + def service(self): + return ScheduleService() + + def test_groups_by_alphabetical_pairing(self, service): + """Games between the same two teams are grouped under one key, + with the alphabetically-first team first in the tuple.""" + games = [ + _game(1, 1, "NYY", "BOS", game_num=1), + _game(2, 1, "BOS", "NYY", game_num=2), + _game(3, 1, "LAD", "CHC", game_num=1), + ] + + result = service.group_games_by_series(games) + + assert ("BOS", "NYY") in result + assert len(result[("BOS", "NYY")]) == 2 + assert ("CHC", "LAD") in result + assert len(result[("CHC", "LAD")]) == 1 + + def test_sorted_by_game_num_within_series(self, service): + """Games within each series are sorted by game_num.""" + games = [ + _game(1, 1, "NYY", "BOS", game_num=3), + _game(2, 1, "NYY", "BOS", game_num=1), + _game(3, 1, "NYY", "BOS", game_num=2), + ] + + result = service.group_games_by_series(games) + series = result[("BOS", "NYY")] + assert [g.game_num for g in series] == [1, 2, 3] + + def test_empty_input(self, service): + """Empty games list returns empty dict.""" + assert service.group_games_by_series([]) == {}