Merge pull request 'perf: parallelize schedule_service API calls with asyncio.gather' (#103) from ai/major-domo-v2-88 into next-release
All checks were successful
Build Docker Image / build (push) Successful in 1m16s

Reviewed-on: #103
This commit is contained in:
cal 2026-03-20 15:16:40 +00:00
commit 52fa56cb69
6 changed files with 775 additions and 143 deletions

40
bot.py
View File

@ -64,6 +64,44 @@ def setup_logging():
return logger
class MaintenanceAwareTree(discord.app_commands.CommandTree):
"""
CommandTree subclass that gates all interactions behind a maintenance mode check.
When bot.maintenance_mode is True, non-administrator users receive an ephemeral
error message and the interaction is blocked. Administrators are always allowed
through. When maintenance_mode is False the check is a no-op and every
interaction proceeds normally.
This is the correct way to register a global interaction_check for app commands
in discord.py overriding the method on a CommandTree subclass passed via
tree_cls rather than attempting to assign a decorator to self.tree inside
setup_hook.
"""
async def interaction_check(self, interaction: discord.Interaction) -> bool:
"""Allow admins through; block everyone else when maintenance mode is active."""
bot = interaction.client # type: ignore[assignment]
# If maintenance mode is off, always allow.
if not getattr(bot, "maintenance_mode", False):
return True
# Maintenance mode is on — let administrators through unconditionally.
if (
isinstance(interaction.user, discord.Member)
and interaction.user.guild_permissions.administrator
):
return True
# Block non-admin users with an ephemeral notice.
await interaction.response.send_message(
"The bot is currently in maintenance mode. Please try again later.",
ephemeral=True,
)
return False
class SBABot(commands.Bot):
"""Custom bot class for SBA league management."""
@ -77,8 +115,10 @@ class SBABot(commands.Bot):
command_prefix="!", # Legacy prefix, primarily using slash commands
intents=intents,
description="Major Domo v2.0",
tree_cls=MaintenanceAwareTree,
)
self.maintenance_mode: bool = False
self.logger = logging.getLogger("discord_bot_v2")
self.maintenance_mode: bool = False

View File

@ -490,7 +490,7 @@ class AdminCommands(commands.Cog):
await interaction.response.defer()
is_enabling = mode.lower() == "on"
self.bot.maintenance_mode = is_enabling
self.bot.maintenance_mode = is_enabling # type: ignore[attr-defined]
self.logger.info(
f"Maintenance mode {'enabled' if is_enabling else 'disabled'} by {interaction.user} (id={interaction.user.id})"
)

View File

@ -3,6 +3,7 @@ League Schedule Commands
Implements slash commands for displaying game schedules and results.
"""
from typing import Optional
import asyncio
@ -19,19 +20,16 @@ from views.embeds import EmbedColors, EmbedTemplate
class ScheduleCommands(commands.Cog):
"""League schedule command handlers."""
def __init__(self, bot: commands.Bot):
self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.ScheduleCommands')
@discord.app_commands.command(
name="schedule",
description="Display game schedule"
)
self.logger = get_contextual_logger(f"{__name__}.ScheduleCommands")
@discord.app_commands.command(name="schedule", description="Display game schedule")
@discord.app_commands.describe(
season="Season to show schedule for (defaults to current season)",
week="Week number to show (optional)",
team="Team abbreviation to filter by (optional)"
team="Team abbreviation to filter by (optional)",
)
@requires_team()
@logged_command("/schedule")
@ -40,13 +38,13 @@ class ScheduleCommands(commands.Cog):
interaction: discord.Interaction,
season: Optional[int] = None,
week: Optional[int] = None,
team: Optional[str] = None
team: Optional[str] = None,
):
"""Display game schedule for a week or team."""
await interaction.response.defer()
search_season = season or get_config().sba_season
if team:
# Show team schedule
await self._show_team_schedule(interaction, search_season, team, week)
@ -56,7 +54,7 @@ class ScheduleCommands(commands.Cog):
else:
# Show recent/upcoming games
await self._show_current_schedule(interaction, search_season)
# @discord.app_commands.command(
# name="results",
# description="Display recent game results"
@ -74,282 +72,316 @@ class ScheduleCommands(commands.Cog):
# ):
# """Display recent game results."""
# await interaction.response.defer()
# search_season = season or get_config().sba_season
# if week:
# # Show specific week results
# games = await schedule_service.get_week_schedule(search_season, week)
# completed_games = [game for game in games if game.is_completed]
# if not completed_games:
# await interaction.followup.send(
# f"❌ No completed games found for season {search_season}, week {week}.",
# ephemeral=True
# )
# return
# embed = await self._create_week_results_embed(completed_games, search_season, week)
# await interaction.followup.send(embed=embed)
# else:
# # Show recent results
# recent_games = await schedule_service.get_recent_games(search_season)
# if not recent_games:
# await interaction.followup.send(
# f"❌ No recent games found for season {search_season}.",
# ephemeral=True
# )
# return
# embed = await self._create_recent_results_embed(recent_games, search_season)
# await interaction.followup.send(embed=embed)
async def _show_week_schedule(self, interaction: discord.Interaction, season: int, week: int):
async def _show_week_schedule(
self, interaction: discord.Interaction, season: int, week: int
):
"""Show schedule for a specific week."""
self.logger.debug("Fetching week schedule", season=season, week=week)
games = await schedule_service.get_week_schedule(season, week)
if not games:
await interaction.followup.send(
f"❌ No games found for season {season}, week {week}.",
ephemeral=True
f"❌ No games found for season {season}, week {week}.", ephemeral=True
)
return
embed = await self._create_week_schedule_embed(games, season, week)
await interaction.followup.send(embed=embed)
async def _show_team_schedule(self, interaction: discord.Interaction, season: int, team: str, week: Optional[int]):
async def _show_team_schedule(
self,
interaction: discord.Interaction,
season: int,
team: str,
week: Optional[int],
):
"""Show schedule for a specific team."""
self.logger.debug("Fetching team schedule", season=season, team=team, week=week)
if week:
# Show team games for specific week
week_games = await schedule_service.get_week_schedule(season, week)
team_games = [
game for game in week_games
if game.away_team.abbrev.upper() == team.upper() or game.home_team.abbrev.upper() == team.upper()
game
for game in week_games
if game.away_team.abbrev.upper() == team.upper()
or game.home_team.abbrev.upper() == team.upper()
]
else:
# Show team's recent/upcoming games (limited weeks)
team_games = await schedule_service.get_team_schedule(season, team, weeks=4)
if not team_games:
week_text = f" for week {week}" if week else ""
await interaction.followup.send(
f"❌ No games found for team '{team}'{week_text} in season {season}.",
ephemeral=True
ephemeral=True,
)
return
embed = await self._create_team_schedule_embed(team_games, season, team, week)
await interaction.followup.send(embed=embed)
async def _show_current_schedule(self, interaction: discord.Interaction, season: int):
async def _show_current_schedule(
self, interaction: discord.Interaction, season: int
):
"""Show current schedule overview with recent and upcoming games."""
self.logger.debug("Fetching current schedule overview", season=season)
# Get both recent and upcoming games
recent_games, upcoming_games = await asyncio.gather(
schedule_service.get_recent_games(season, weeks_back=1),
schedule_service.get_upcoming_games(season, weeks_ahead=1)
schedule_service.get_upcoming_games(season),
)
if not recent_games and not upcoming_games:
await interaction.followup.send(
f"❌ No recent or upcoming games found for season {season}.",
ephemeral=True
ephemeral=True,
)
return
embed = await self._create_current_schedule_embed(recent_games, upcoming_games, season)
embed = await self._create_current_schedule_embed(
recent_games, upcoming_games, season
)
await interaction.followup.send(embed=embed)
async def _create_week_schedule_embed(self, games, season: int, week: int) -> discord.Embed:
async def _create_week_schedule_embed(
self, games, season: int, week: int
) -> discord.Embed:
"""Create an embed for a week's schedule."""
embed = EmbedTemplate.create_base_embed(
title=f"📅 Week {week} Schedule - Season {season}",
color=EmbedColors.PRIMARY
color=EmbedColors.PRIMARY,
)
# Group games by series
series_games = schedule_service.group_games_by_series(games)
schedule_lines = []
for (team1, team2), series in series_games.items():
series_summary = await self._format_series_summary(series)
schedule_lines.append(f"**{team1} vs {team2}**\n{series_summary}")
if schedule_lines:
embed.add_field(
name="Games",
value="\n\n".join(schedule_lines),
inline=False
name="Games", value="\n\n".join(schedule_lines), inline=False
)
# Add week summary
completed = len([g for g in games if g.is_completed])
total = len(games)
embed.add_field(
name="Week Progress",
value=f"{completed}/{total} games completed",
inline=True
inline=True,
)
embed.set_footer(text=f"Season {season} • Week {week}")
return embed
async def _create_team_schedule_embed(self, games, season: int, team: str, week: Optional[int]) -> discord.Embed:
async def _create_team_schedule_embed(
self, games, season: int, team: str, week: Optional[int]
) -> discord.Embed:
"""Create an embed for a team's schedule."""
week_text = f" - Week {week}" if week else ""
embed = EmbedTemplate.create_base_embed(
title=f"📅 {team.upper()} Schedule{week_text} - Season {season}",
color=EmbedColors.PRIMARY
color=EmbedColors.PRIMARY,
)
# Separate completed and upcoming games
completed_games = [g for g in games if g.is_completed]
upcoming_games = [g for g in games if not g.is_completed]
if completed_games:
recent_lines = []
for game in completed_games[-5:]: # Last 5 games
result = "W" if game.winner and game.winner.abbrev.upper() == team.upper() else "L"
result = (
"W"
if game.winner and game.winner.abbrev.upper() == team.upper()
else "L"
)
if game.home_team.abbrev.upper() == team.upper():
# Team was home
recent_lines.append(f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})")
recent_lines.append(
f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})"
)
else:
# Team was away
recent_lines.append(f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})")
# Team was away
recent_lines.append(
f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})"
)
embed.add_field(
name="Recent Results",
value="\n".join(recent_lines) if recent_lines else "No recent games",
inline=False
inline=False,
)
if upcoming_games:
upcoming_lines = []
for game in upcoming_games[:5]: # Next 5 games
if game.home_team.abbrev.upper() == team.upper():
# Team is home
upcoming_lines.append(f"Week {game.week}: vs {game.away_team.abbrev}")
upcoming_lines.append(
f"Week {game.week}: vs {game.away_team.abbrev}"
)
else:
# Team is away
upcoming_lines.append(f"Week {game.week}: @ {game.home_team.abbrev}")
upcoming_lines.append(
f"Week {game.week}: @ {game.home_team.abbrev}"
)
embed.add_field(
name="Upcoming Games",
value="\n".join(upcoming_lines) if upcoming_lines else "No upcoming games",
inline=False
value=(
"\n".join(upcoming_lines) if upcoming_lines else "No upcoming games"
),
inline=False,
)
embed.set_footer(text=f"Season {season}{team.upper()}")
return embed
async def _create_week_results_embed(self, games, season: int, week: int) -> discord.Embed:
async def _create_week_results_embed(
self, games, season: int, week: int
) -> discord.Embed:
"""Create an embed for week results."""
embed = EmbedTemplate.create_base_embed(
title=f"🏆 Week {week} Results - Season {season}",
color=EmbedColors.SUCCESS
title=f"🏆 Week {week} Results - Season {season}", color=EmbedColors.SUCCESS
)
# Group by series and show results
series_games = schedule_service.group_games_by_series(games)
results_lines = []
for (team1, team2), series in series_games.items():
# Count wins for each team
team1_wins = len([g for g in series if g.winner and g.winner.abbrev == team1])
team2_wins = len([g for g in series if g.winner and g.winner.abbrev == team2])
team1_wins = len(
[g for g in series if g.winner and g.winner.abbrev == team1]
)
team2_wins = len(
[g for g in series if g.winner and g.winner.abbrev == team2]
)
# Series result
series_result = f"**{team1} {team1_wins}-{team2_wins} {team2}**"
# Individual games
game_details = []
for game in series:
if game.series_game_display:
game_details.append(f"{game.series_game_display}: {game.matchup_display}")
game_details.append(
f"{game.series_game_display}: {game.matchup_display}"
)
results_lines.append(f"{series_result}\n" + "\n".join(game_details))
if results_lines:
embed.add_field(
name="Series Results",
value="\n\n".join(results_lines),
inline=False
name="Series Results", value="\n\n".join(results_lines), inline=False
)
embed.set_footer(text=f"Season {season} • Week {week}{len(games)} games completed")
embed.set_footer(
text=f"Season {season} • Week {week}{len(games)} games completed"
)
return embed
async def _create_recent_results_embed(self, games, season: int) -> discord.Embed:
"""Create an embed for recent results."""
embed = EmbedTemplate.create_base_embed(
title=f"🏆 Recent Results - Season {season}",
color=EmbedColors.SUCCESS
title=f"🏆 Recent Results - Season {season}", color=EmbedColors.SUCCESS
)
# Show most recent games
recent_lines = []
for game in games[:10]: # Show last 10 games
recent_lines.append(f"Week {game.week}: {game.matchup_display}")
if recent_lines:
embed.add_field(
name="Latest Games",
value="\n".join(recent_lines),
inline=False
name="Latest Games", value="\n".join(recent_lines), inline=False
)
embed.set_footer(text=f"Season {season} • Last {len(games)} completed games")
return embed
async def _create_current_schedule_embed(self, recent_games, upcoming_games, season: int) -> discord.Embed:
async def _create_current_schedule_embed(
self, recent_games, upcoming_games, season: int
) -> discord.Embed:
"""Create an embed for current schedule overview."""
embed = EmbedTemplate.create_base_embed(
title=f"📅 Current Schedule - Season {season}",
color=EmbedColors.INFO
title=f"📅 Current Schedule - Season {season}", color=EmbedColors.INFO
)
if recent_games:
recent_lines = []
for game in recent_games[:5]:
recent_lines.append(f"Week {game.week}: {game.matchup_display}")
embed.add_field(
name="Recent Results",
value="\n".join(recent_lines),
inline=False
name="Recent Results", value="\n".join(recent_lines), inline=False
)
if upcoming_games:
upcoming_lines = []
for game in upcoming_games[:5]:
upcoming_lines.append(f"Week {game.week}: {game.matchup_display}")
embed.add_field(
name="Upcoming Games",
value="\n".join(upcoming_lines),
inline=False
name="Upcoming Games", value="\n".join(upcoming_lines), inline=False
)
embed.set_footer(text=f"Season {season}")
return embed
async def _format_series_summary(self, series) -> str:
"""Format a series summary."""
lines = []
for game in series:
game_display = f"{game.series_game_display}: {game.matchup_display}" if game.series_game_display else game.matchup_display
game_display = (
f"{game.series_game_display}: {game.matchup_display}"
if game.series_game_display
else game.matchup_display
)
lines.append(game_display)
return "\n".join(lines) if lines else "No games"
async def setup(bot: commands.Bot):
"""Load the schedule commands cog."""
await bot.add_cog(ScheduleCommands(bot))
await bot.add_cog(ScheduleCommands(bot))

View File

@ -4,6 +4,7 @@ Schedule service for Discord Bot v2.0
Handles game schedule and results retrieval and processing.
"""
import asyncio
import logging
from typing import Optional, List, Dict, Tuple
@ -102,10 +103,10 @@ class ScheduleService:
# If weeks not specified, try a reasonable range (18 weeks typical)
week_range = range(1, (weeks + 1) if weeks else 19)
for week in week_range:
week_games = await self.get_week_schedule(season, week)
# Filter games involving this team
all_week_games = await asyncio.gather(
*[self.get_week_schedule(season, week) for week in week_range]
)
for week_games in all_week_games:
for game in week_games:
if (
game.away_team.abbrev.upper() == team_abbrev_upper
@ -135,15 +136,13 @@ class ScheduleService:
recent_games = []
# Get games from recent weeks
for week_offset in range(weeks_back):
# This is simplified - in production you'd want to determine current week
week = 10 - week_offset # Assuming we're around week 10
if week <= 0:
break
week_games = await self.get_week_schedule(season, week)
# Only include completed games
weeks_to_fetch = [
(10 - offset) for offset in range(weeks_back) if (10 - offset) > 0
]
all_week_games = await asyncio.gather(
*[self.get_week_schedule(season, week) for week in weeks_to_fetch]
)
for week_games in all_week_games:
completed_games = [game for game in week_games if game.is_completed]
recent_games.extend(completed_games)
@ -157,13 +156,12 @@ class ScheduleService:
logger.error(f"Error getting recent games: {e}")
return []
async def get_upcoming_games(self, season: int, weeks_ahead: int = 6) -> List[Game]:
async def get_upcoming_games(self, season: int) -> List[Game]:
"""
Get upcoming scheduled games by scanning multiple weeks.
Get upcoming scheduled games by scanning all weeks.
Args:
season: Season number
weeks_ahead: Number of weeks to scan ahead (default 6)
Returns:
List of upcoming Game instances
@ -171,20 +169,16 @@ class ScheduleService:
try:
upcoming_games = []
# Scan through weeks to find games without scores
for week in range(1, 19): # Standard season length
week_games = await self.get_week_schedule(season, week)
# Find games without scores (not yet played)
# Fetch all weeks in parallel and filter for incomplete games
all_week_games = await asyncio.gather(
*[self.get_week_schedule(season, week) for week in range(1, 19)]
)
for week_games in all_week_games:
upcoming_games_week = [
game for game in week_games if not game.is_completed
]
upcoming_games.extend(upcoming_games_week)
# If we found upcoming games, we can limit how many more weeks to check
if upcoming_games and len(upcoming_games) >= 20: # Reasonable limit
break
# Sort by week, then game number
upcoming_games.sort(key=lambda x: (x.week, x.game_num or 0))

View File

@ -0,0 +1,282 @@
"""
Tests for MaintenanceAwareTree and the maintenance_mode attribute on SBABot.
What:
Verifies that the CommandTree subclass correctly gates interactions behind
bot.maintenance_mode. When maintenance mode is off every interaction is
allowed through unconditionally. When maintenance mode is on, non-admin
users receive an ephemeral error message and the check returns False, while
administrators are always allowed through.
Why:
The original code attempted to register an interaction_check via a decorator
on self.tree inside setup_hook. That is not a valid pattern in discord.py
interaction_check is an overridable async method on CommandTree, not a
decorator. The broken assignment caused a RuntimeWarning (unawaited
coroutine) and silently made maintenance mode a no-op. These tests confirm
the correct subclass-based implementation behaves as specified.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
import discord
# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------
def _make_bot(maintenance_mode: bool = False) -> MagicMock:
"""Return a minimal mock bot with a maintenance_mode attribute."""
bot = MagicMock()
bot.maintenance_mode = maintenance_mode
return bot
def _make_interaction(is_admin: bool, bot: MagicMock) -> AsyncMock:
"""
Build a mock discord.Interaction.
The interaction's .client is set to *bot* so that MaintenanceAwareTree
can read bot.maintenance_mode via interaction.client, mirroring how
discord.py wires things at runtime.
"""
interaction = AsyncMock(spec=discord.Interaction)
interaction.client = bot
# Mock the user as a guild Member so that guild_permissions is accessible.
user = MagicMock(spec=discord.Member)
user.guild_permissions = MagicMock()
user.guild_permissions.administrator = is_admin
interaction.user = user
# response.send_message must be awaitable.
interaction.response = AsyncMock()
interaction.response.send_message = AsyncMock()
return interaction
# ---------------------------------------------------------------------------
# Import the class under test after mocks are available.
# We import here (not at module level) so that the conftest env-vars are set
# before any discord_bot_v2 modules are touched.
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _patch_discord_app_commands(monkeypatch):
"""
Prevent MaintenanceAwareTree.__init__ from calling discord internals that
need a real event loop / Discord connection. We test only the logic of
interaction_check, so we stub out the parent __init__.
"""
# Nothing extra to patch for the interaction_check itself; the parent
# CommandTree.__init__ is only called when constructing SBABot, which we
# don't do in these unit tests.
yield
# ---------------------------------------------------------------------------
# Tests for MaintenanceAwareTree.interaction_check
# ---------------------------------------------------------------------------
class TestMaintenanceAwareTree:
"""Unit tests for MaintenanceAwareTree.interaction_check."""
@pytest.fixture
def tree_cls(self):
"""Import and return the MaintenanceAwareTree class."""
from bot import MaintenanceAwareTree
return MaintenanceAwareTree
# ------------------------------------------------------------------
# Maintenance OFF
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_off_allows_non_admin(self, tree_cls):
"""
When maintenance_mode is False, non-admin users are always allowed.
The check must return True without sending any message.
"""
bot = _make_bot(maintenance_mode=False)
interaction = _make_interaction(is_admin=False, bot=bot)
# Instantiate tree without calling parent __init__ by testing the method
# directly on an unbound basis.
result = await tree_cls.interaction_check(
MagicMock(), # placeholder 'self' for the tree instance
interaction,
)
assert result is True
interaction.response.send_message.assert_not_called()
@pytest.mark.asyncio
async def test_maintenance_off_allows_admin(self, tree_cls):
"""
When maintenance_mode is False, admin users are also always allowed.
"""
bot = _make_bot(maintenance_mode=False)
interaction = _make_interaction(is_admin=True, bot=bot)
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is True
interaction.response.send_message.assert_not_called()
# ------------------------------------------------------------------
# Maintenance ON — non-admin
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_on_blocks_non_admin(self, tree_cls):
"""
When maintenance_mode is True, non-admin users must be blocked.
The check must return False and send an ephemeral message.
"""
bot = _make_bot(maintenance_mode=True)
interaction = _make_interaction(is_admin=False, bot=bot)
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is False
interaction.response.send_message.assert_called_once()
# Confirm the call used ephemeral=True
_, kwargs = interaction.response.send_message.call_args
assert kwargs.get("ephemeral") is True
@pytest.mark.asyncio
async def test_maintenance_on_message_has_no_emoji(self, tree_cls):
"""
The maintenance block message must not contain emoji characters.
The project style deliberately strips emoji from user-facing strings.
"""
import unicodedata
bot = _make_bot(maintenance_mode=True)
interaction = _make_interaction(is_admin=False, bot=bot)
await tree_cls.interaction_check(MagicMock(), interaction)
args, _ = interaction.response.send_message.call_args
message_text = args[0] if args else ""
for ch in message_text:
category = unicodedata.category(ch)
assert category != "So", (
f"Unexpected emoji/symbol character {ch!r} (category {category!r}) "
f"found in maintenance message: {message_text!r}"
)
# ------------------------------------------------------------------
# Maintenance ON — admin
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_on_allows_admin(self, tree_cls):
"""
When maintenance_mode is True, administrator users must still be
allowed through. Admins should never be locked out of bot commands.
"""
bot = _make_bot(maintenance_mode=True)
interaction = _make_interaction(is_admin=True, bot=bot)
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is True
interaction.response.send_message.assert_not_called()
# ------------------------------------------------------------------
# Edge case: non-Member user during maintenance
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_on_blocks_non_member_user(self, tree_cls):
"""
When maintenance_mode is True and the user is not a guild Member
(e.g. interaction from a DM context), the check must still block them
because we cannot verify administrator status.
"""
bot = _make_bot(maintenance_mode=True)
interaction = AsyncMock(spec=discord.Interaction)
interaction.client = bot
# Simulate a non-Member user (e.g. discord.User from DM context)
user = MagicMock(spec=discord.User)
# discord.User has no guild_permissions attribute
interaction.user = user
interaction.response = AsyncMock()
interaction.response.send_message = AsyncMock()
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is False
interaction.response.send_message.assert_called_once()
# ------------------------------------------------------------------
# Missing attribute safety: bot without maintenance_mode attr
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_missing_maintenance_mode_attr_defaults_to_allowed(self, tree_cls):
"""
If the bot object does not have a maintenance_mode attribute (e.g.
during testing or unusual startup), getattr fallback must treat it as
False and allow the interaction.
"""
bot = MagicMock()
# Deliberately do NOT set bot.maintenance_mode
del bot.maintenance_mode
interaction = _make_interaction(is_admin=False, bot=bot)
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is True
# ---------------------------------------------------------------------------
# Tests for SBABot.maintenance_mode attribute
# ---------------------------------------------------------------------------
class TestSBABotMaintenanceModeAttribute:
"""
Confirms that SBABot.__init__ always sets maintenance_mode = False.
We avoid constructing a real SBABot (which requires a Discord event loop
and valid token infrastructure) by patching the entire commands.Bot.__init__
and then calling SBABot.__init__ directly on a bare instance so that only
the SBABot-specific attribute assignments execute.
"""
def test_maintenance_mode_default_is_false(self, monkeypatch):
"""
SBABot.__init__ must set self.maintenance_mode = False so that the
MaintenanceAwareTree has the attribute available from the very first
interaction, even before /admin-maintenance is ever called.
Strategy: patch commands.Bot.__init__ to be a no-op so super().__init__
succeeds without a real Discord connection, then call SBABot.__init__
and assert the attribute is present with the correct default value.
"""
from unittest.mock import patch
from discord.ext import commands
from bot import SBABot
with patch.object(commands.Bot, "__init__", return_value=None):
bot = SBABot.__new__(SBABot)
SBABot.__init__(bot)
assert hasattr(
bot, "maintenance_mode"
), "SBABot must define self.maintenance_mode in __init__"
assert (
bot.maintenance_mode is False
), "SBABot.maintenance_mode must default to False"

View File

@ -0,0 +1,284 @@
"""
Tests for schedule service functionality.
Covers get_week_schedule, get_team_schedule, get_recent_games,
get_upcoming_games, and group_games_by_series verifying the
asyncio.gather parallelization and post-fetch filtering logic.
"""
import pytest
from unittest.mock import AsyncMock, patch
from services.schedule_service import ScheduleService
from tests.factories import GameFactory, TeamFactory
def _game(game_id, week, away_abbrev, home_abbrev, **kwargs):
"""Create a Game with distinct team IDs per matchup."""
return GameFactory.create(
id=game_id,
week=week,
away_team=TeamFactory.create(id=game_id * 10, abbrev=away_abbrev),
home_team=TeamFactory.create(id=game_id * 10 + 1, abbrev=home_abbrev),
**kwargs,
)
class TestGetWeekSchedule:
"""Tests for ScheduleService.get_week_schedule — the HTTP layer."""
@pytest.fixture
def service(self):
svc = ScheduleService()
svc.get_client = AsyncMock()
return svc
@pytest.mark.asyncio
async def test_success(self, service):
"""get_week_schedule returns parsed Game objects on a normal response."""
mock_client = AsyncMock()
mock_client.get.return_value = {
"games": [
{
"id": 1,
"season": 12,
"week": 5,
"game_num": 1,
"season_type": "regular",
"away_team": {
"id": 10,
"abbrev": "NYY",
"sname": "NYY",
"lname": "New York",
"season": 12,
},
"home_team": {
"id": 11,
"abbrev": "BOS",
"sname": "BOS",
"lname": "Boston",
"season": 12,
},
"away_score": 4,
"home_score": 2,
}
]
}
service.get_client.return_value = mock_client
games = await service.get_week_schedule(12, 5)
assert len(games) == 1
assert games[0].away_team.abbrev == "NYY"
assert games[0].home_team.abbrev == "BOS"
assert games[0].is_completed
@pytest.mark.asyncio
async def test_empty_response(self, service):
"""get_week_schedule returns [] when the API has no games."""
mock_client = AsyncMock()
mock_client.get.return_value = {"games": []}
service.get_client.return_value = mock_client
games = await service.get_week_schedule(12, 99)
assert games == []
@pytest.mark.asyncio
async def test_api_error_returns_empty(self, service):
"""get_week_schedule returns [] on API error (no exception raised)."""
service.get_client.side_effect = Exception("connection refused")
games = await service.get_week_schedule(12, 1)
assert games == []
@pytest.mark.asyncio
async def test_missing_games_key(self, service):
"""get_week_schedule returns [] when response lacks 'games' key."""
mock_client = AsyncMock()
mock_client.get.return_value = {"status": "ok"}
service.get_client.return_value = mock_client
games = await service.get_week_schedule(12, 1)
assert games == []
class TestGetTeamSchedule:
"""Tests for get_team_schedule — gather + team-abbrev filter."""
@pytest.fixture
def service(self):
return ScheduleService()
@pytest.mark.asyncio
async def test_filters_by_team_case_insensitive(self, service):
"""get_team_schedule returns only games involving the requested team,
regardless of abbreviation casing."""
week1 = [
_game(1, 1, "NYY", "BOS", away_score=3, home_score=1),
_game(2, 1, "LAD", "CHC", away_score=5, home_score=2),
]
week2 = [
_game(3, 2, "BOS", "NYY", away_score=2, home_score=4),
]
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
mock.side_effect = [week1, week2]
result = await service.get_team_schedule(12, "nyy", weeks=2)
assert len(result) == 2
assert all(
g.away_team.abbrev == "NYY" or g.home_team.abbrev == "NYY" for g in result
)
@pytest.mark.asyncio
async def test_full_season_fetches_18_weeks(self, service):
"""When weeks is None, all 18 weeks are fetched via gather."""
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
mock.return_value = []
await service.get_team_schedule(12, "NYY")
assert mock.call_count == 18
@pytest.mark.asyncio
async def test_limited_weeks(self, service):
"""When weeks=5, only 5 weeks are fetched."""
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
mock.return_value = []
await service.get_team_schedule(12, "NYY", weeks=5)
assert mock.call_count == 5
class TestGetRecentGames:
"""Tests for get_recent_games — gather + completed-only filter."""
@pytest.fixture
def service(self):
return ScheduleService()
@pytest.mark.asyncio
async def test_returns_only_completed_games(self, service):
"""get_recent_games filters out games without scores."""
completed = GameFactory.completed(id=1, week=10)
incomplete = GameFactory.upcoming(id=2, week=10)
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
mock.return_value = [completed, incomplete]
result = await service.get_recent_games(12, weeks_back=1)
assert len(result) == 1
assert result[0].is_completed
@pytest.mark.asyncio
async def test_sorted_descending_by_week_and_game_num(self, service):
"""Recent games are sorted most-recent first."""
game_w10 = GameFactory.completed(id=1, week=10, game_num=2)
game_w9 = GameFactory.completed(id=2, week=9, game_num=1)
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
mock.side_effect = [[game_w10], [game_w9]]
result = await service.get_recent_games(12, weeks_back=2)
assert result[0].week == 10
assert result[1].week == 9
@pytest.mark.asyncio
async def test_skips_negative_weeks(self, service):
"""Weeks that would be <= 0 are excluded from fetch."""
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
mock.return_value = []
await service.get_recent_games(12, weeks_back=15)
# weeks_to_fetch = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] — only 10 valid weeks
assert mock.call_count == 10
class TestGetUpcomingGames:
"""Tests for get_upcoming_games — gather all 18 weeks + incomplete filter."""
@pytest.fixture
def service(self):
return ScheduleService()
@pytest.mark.asyncio
async def test_returns_only_incomplete_games(self, service):
"""get_upcoming_games filters out completed games."""
completed = GameFactory.completed(id=1, week=5)
upcoming = GameFactory.upcoming(id=2, week=5)
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
mock.return_value = [completed, upcoming]
result = await service.get_upcoming_games(12)
assert len(result) == 18 # 1 incomplete game per week × 18 weeks
assert all(not g.is_completed for g in result)
@pytest.mark.asyncio
async def test_sorted_ascending_by_week_and_game_num(self, service):
"""Upcoming games are sorted earliest first."""
game_w3 = GameFactory.upcoming(id=1, week=3, game_num=1)
game_w1 = GameFactory.upcoming(id=2, week=1, game_num=2)
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
def side_effect(season, week):
if week == 1:
return [game_w1]
if week == 3:
return [game_w3]
return []
mock.side_effect = side_effect
result = await service.get_upcoming_games(12)
assert result[0].week == 1
assert result[1].week == 3
@pytest.mark.asyncio
async def test_fetches_all_18_weeks(self, service):
"""All 18 weeks are fetched in parallel (no early exit)."""
with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock:
mock.return_value = []
await service.get_upcoming_games(12)
assert mock.call_count == 18
class TestGroupGamesBySeries:
"""Tests for group_games_by_series — synchronous grouping logic."""
@pytest.fixture
def service(self):
return ScheduleService()
def test_groups_by_alphabetical_pairing(self, service):
"""Games between the same two teams are grouped under one key,
with the alphabetically-first team first in the tuple."""
games = [
_game(1, 1, "NYY", "BOS", game_num=1),
_game(2, 1, "BOS", "NYY", game_num=2),
_game(3, 1, "LAD", "CHC", game_num=1),
]
result = service.group_games_by_series(games)
assert ("BOS", "NYY") in result
assert len(result[("BOS", "NYY")]) == 2
assert ("CHC", "LAD") in result
assert len(result[("CHC", "LAD")]) == 1
def test_sorted_by_game_num_within_series(self, service):
"""Games within each series are sorted by game_num."""
games = [
_game(1, 1, "NYY", "BOS", game_num=3),
_game(2, 1, "NYY", "BOS", game_num=1),
_game(3, 1, "NYY", "BOS", game_num=2),
]
result = service.group_games_by_series(games)
series = result[("BOS", "NYY")]
assert [g.game_num for g in series] == [1, 2, 3]
def test_empty_input(self, service):
"""Empty games list returns empty dict."""
assert service.group_games_by_series([]) == {}