fix: replace broken @self.tree.interaction_check with MaintenanceAwareTree subclass
All checks were successful
Build Docker Image / build (pull_request) Successful in 1m11s

The previous code attempted to register a maintenance mode gate via
@self.tree.interaction_check inside setup_hook.  That pattern is invalid
in discord.py — interaction_check is an overridable method on CommandTree,
not a decorator.  The assignment was silently dropped, making maintenance
mode a no-op and producing a RuntimeWarning about an unawaited coroutine.

Changes:
- Add MaintenanceAwareTree(discord.app_commands.CommandTree) that overrides
  interaction_check: blocks non-admins when bot.maintenance_mode is True,
  always passes admins through, no-op when maintenance mode is off
- Pass tree_cls=MaintenanceAwareTree to super().__init__() in SBABot.__init__
- Add self.maintenance_mode: bool = False to SBABot.__init__
- Update /admin-maintenance command to actually toggle bot.maintenance_mode
- Add tests/test_bot_maintenance_tree.py with 8 unit tests covering all
  maintenance mode states, admin pass-through, DM context, and missing attr

Closes #82

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-03-17 12:25:01 -05:00
parent 2f7b82e377
commit d295f27afe
3 changed files with 576 additions and 250 deletions

40
bot.py
View File

@ -65,6 +65,44 @@ def setup_logging():
return logger return logger
class MaintenanceAwareTree(discord.app_commands.CommandTree):
"""
CommandTree subclass that gates all interactions behind a maintenance mode check.
When bot.maintenance_mode is True, non-administrator users receive an ephemeral
error message and the interaction is blocked. Administrators are always allowed
through. When maintenance_mode is False the check is a no-op and every
interaction proceeds normally.
This is the correct way to register a global interaction_check for app commands
in discord.py overriding the method on a CommandTree subclass passed via
tree_cls rather than attempting to assign a decorator to self.tree inside
setup_hook.
"""
async def interaction_check(self, interaction: discord.Interaction) -> bool:
"""Allow admins through; block everyone else when maintenance mode is active."""
bot = interaction.client # type: ignore[assignment]
# If maintenance mode is off, always allow.
if not getattr(bot, "maintenance_mode", False):
return True
# Maintenance mode is on — let administrators through unconditionally.
if (
isinstance(interaction.user, discord.Member)
and interaction.user.guild_permissions.administrator
):
return True
# Block non-admin users with an ephemeral notice.
await interaction.response.send_message(
"The bot is currently in maintenance mode. Please try again later.",
ephemeral=True,
)
return False
class SBABot(commands.Bot): class SBABot(commands.Bot):
"""Custom bot class for SBA league management.""" """Custom bot class for SBA league management."""
@ -78,8 +116,10 @@ class SBABot(commands.Bot):
command_prefix="!", # Legacy prefix, primarily using slash commands command_prefix="!", # Legacy prefix, primarily using slash commands
intents=intents, intents=intents,
description="Major Domo v2.0", description="Major Domo v2.0",
tree_cls=MaintenanceAwareTree,
) )
self.maintenance_mode: bool = False
self.logger = logging.getLogger("discord_bot_v2") self.logger = logging.getLogger("discord_bot_v2")
async def setup_hook(self): async def setup_hook(self):

View File

@ -3,6 +3,7 @@ Admin Management Commands
Administrative commands for league management and bot maintenance. Administrative commands for league management and bot maintenance.
""" """
import asyncio import asyncio
from typing import List, Dict, Any from typing import List, Dict, Any
@ -26,29 +27,27 @@ class AdminCommands(commands.Cog):
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.AdminCommands') self.logger = get_contextual_logger(f"{__name__}.AdminCommands")
async def interaction_check(self, interaction: discord.Interaction) -> bool: async def interaction_check(self, interaction: discord.Interaction) -> bool:
"""Check if user has admin permissions.""" """Check if user has admin permissions."""
# Check if interaction is from a guild and user is a Member # Check if interaction is from a guild and user is a Member
if not isinstance(interaction.user, discord.Member): if not isinstance(interaction.user, discord.Member):
await interaction.response.send_message( await interaction.response.send_message(
"❌ Admin commands can only be used in a server.", "❌ Admin commands can only be used in a server.", ephemeral=True
ephemeral=True
) )
return False return False
if not interaction.user.guild_permissions.administrator: if not interaction.user.guild_permissions.administrator:
await interaction.response.send_message( await interaction.response.send_message(
"❌ You need administrator permissions to use admin commands.", "❌ You need administrator permissions to use admin commands.",
ephemeral=True ephemeral=True,
) )
return False return False
return True return True
@app_commands.command( @app_commands.command(
name="admin-status", name="admin-status", description="Display bot status and system information"
description="Display bot status and system information"
) )
@league_admin_only() @league_admin_only()
@logged_command("/admin-status") @logged_command("/admin-status")
@ -62,12 +61,13 @@ class AdminCommands(commands.Cog):
commands_count = len([cmd for cmd in self.bot.tree.walk_commands()]) commands_count = len([cmd for cmd in self.bot.tree.walk_commands()])
# Bot uptime calculation # Bot uptime calculation
uptime = discord.utils.utcnow() - self.bot.user.created_at if self.bot.user else None uptime = (
discord.utils.utcnow() - self.bot.user.created_at if self.bot.user else None
)
uptime_str = f"{uptime.days} days" if uptime else "Unknown" uptime_str = f"{uptime.days} days" if uptime else "Unknown"
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="🤖 Bot Status - Admin Panel", title="🤖 Bot Status - Admin Panel", color=EmbedColors.INFO
color=EmbedColors.INFO
) )
# System Stats # System Stats
@ -77,7 +77,7 @@ class AdminCommands(commands.Cog):
f"**Users:** {users_count:,}\n" f"**Users:** {users_count:,}\n"
f"**Commands:** {commands_count}\n" f"**Commands:** {commands_count}\n"
f"**Uptime:** {uptime_str}", f"**Uptime:** {uptime_str}",
inline=True inline=True,
) )
# Bot Information # Bot Information
@ -86,16 +86,20 @@ class AdminCommands(commands.Cog):
value=f"**Latency:** {round(self.bot.latency * 1000)}ms\n" value=f"**Latency:** {round(self.bot.latency * 1000)}ms\n"
f"**Version:** Discord.py {discord.__version__}\n" f"**Version:** Discord.py {discord.__version__}\n"
f"**Current Season:** {get_config().sba_season}", f"**Current Season:** {get_config().sba_season}",
inline=True inline=True,
) )
# Cog Status # Cog Status
loaded_cogs = list(self.bot.cogs.keys()) loaded_cogs = list(self.bot.cogs.keys())
embed.add_field( embed.add_field(
name="Loaded Cogs", name="Loaded Cogs",
value="\n".join([f"{cog}" for cog in loaded_cogs[:10]]) + value="\n".join([f"{cog}" for cog in loaded_cogs[:10]])
(f"\n... and {len(loaded_cogs) - 10} more" if len(loaded_cogs) > 10 else ""), + (
inline=False f"\n... and {len(loaded_cogs) - 10} more"
if len(loaded_cogs) > 10
else ""
),
inline=False,
) )
embed.set_footer(text="Admin Status • Use /admin-help for more commands") embed.set_footer(text="Admin Status • Use /admin-help for more commands")
@ -103,7 +107,7 @@ class AdminCommands(commands.Cog):
@app_commands.command( @app_commands.command(
name="admin-help", name="admin-help",
description="Display available admin commands and their usage" description="Display available admin commands and their usage",
) )
@league_admin_only() @league_admin_only()
@logged_command("/admin-help") @logged_command("/admin-help")
@ -114,7 +118,7 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="🛠️ Admin Commands - Help", title="🛠️ Admin Commands - Help",
description="Available administrative commands for league management", description="Available administrative commands for league management",
color=EmbedColors.PRIMARY color=EmbedColors.PRIMARY,
) )
# System Commands # System Commands
@ -125,7 +129,7 @@ class AdminCommands(commands.Cog):
"**`/admin-sync`** - Sync application commands\n" "**`/admin-sync`** - Sync application commands\n"
"**`/admin-clear <count>`** - Clear messages from channel\n" "**`/admin-clear <count>`** - Clear messages from channel\n"
"**`/admin-clear-scorecards`** - Clear live scorebug channel and hide it", "**`/admin-clear-scorecards`** - Clear live scorebug channel and hide it",
inline=False inline=False,
) )
# League Management # League Management
@ -135,7 +139,7 @@ class AdminCommands(commands.Cog):
"**`/admin-announce <message>`** - Send announcement to channel\n" "**`/admin-announce <message>`** - Send announcement to channel\n"
"**`/admin-maintenance <on/off>`** - Toggle maintenance mode\n" "**`/admin-maintenance <on/off>`** - Toggle maintenance mode\n"
"**`/admin-process-transactions [week]`** - Manually process weekly transactions", "**`/admin-process-transactions [week]`** - Manually process weekly transactions",
inline=False inline=False,
) )
# User Management # User Management
@ -144,7 +148,7 @@ class AdminCommands(commands.Cog):
value="**`/admin-timeout <user> <duration>`** - Timeout a user\n" value="**`/admin-timeout <user> <duration>`** - Timeout a user\n"
"**`/admin-kick <user> <reason>`** - Kick a user\n" "**`/admin-kick <user> <reason>`** - Kick a user\n"
"**`/admin-ban <user> <reason>`** - Ban a user", "**`/admin-ban <user> <reason>`** - Ban a user",
inline=False inline=False,
) )
embed.add_field( embed.add_field(
@ -152,16 +156,13 @@ class AdminCommands(commands.Cog):
value="• All admin commands require Administrator permissions\n" value="• All admin commands require Administrator permissions\n"
"• Commands are logged for audit purposes\n" "• Commands are logged for audit purposes\n"
"• Use with caution - some actions are irreversible", "• Use with caution - some actions are irreversible",
inline=False inline=False,
) )
embed.set_footer(text="Administrator Permissions Required") embed.set_footer(text="Administrator Permissions Required")
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
@app_commands.command( @app_commands.command(name="admin-reload", description="Reload a specific bot cog")
name="admin-reload",
description="Reload a specific bot cog"
)
@app_commands.describe( @app_commands.describe(
cog="Name of the cog to reload (e.g., 'commands.players.info')" cog="Name of the cog to reload (e.g., 'commands.players.info')"
) )
@ -178,7 +179,7 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="✅ Cog Reloaded Successfully", title="✅ Cog Reloaded Successfully",
description=f"Successfully reloaded `{cog}`", description=f"Successfully reloaded `{cog}`",
color=EmbedColors.SUCCESS color=EmbedColors.SUCCESS,
) )
embed.add_field( embed.add_field(
@ -186,37 +187,36 @@ class AdminCommands(commands.Cog):
value=f"**Cog:** {cog}\n" value=f"**Cog:** {cog}\n"
f"**Status:** Successfully reloaded\n" f"**Status:** Successfully reloaded\n"
f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}", f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}",
inline=False inline=False,
) )
except commands.ExtensionNotFound: except commands.ExtensionNotFound:
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="❌ Cog Not Found", title="❌ Cog Not Found",
description=f"Could not find cog: `{cog}`", description=f"Could not find cog: `{cog}`",
color=EmbedColors.ERROR color=EmbedColors.ERROR,
) )
except commands.ExtensionNotLoaded: except commands.ExtensionNotLoaded:
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="❌ Cog Not Loaded", title="❌ Cog Not Loaded",
description=f"Cog `{cog}` is not currently loaded", description=f"Cog `{cog}` is not currently loaded",
color=EmbedColors.ERROR color=EmbedColors.ERROR,
) )
except Exception as e: except Exception as e:
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="❌ Reload Failed", title="❌ Reload Failed",
description=f"Failed to reload `{cog}`: {str(e)}", description=f"Failed to reload `{cog}`: {str(e)}",
color=EmbedColors.ERROR color=EmbedColors.ERROR,
) )
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
@app_commands.command( @app_commands.command(
name="admin-sync", name="admin-sync", description="Sync application commands with Discord"
description="Sync application commands with Discord"
) )
@app_commands.describe( @app_commands.describe(
local="Sync to this guild only (fast, for development)", local="Sync to this guild only (fast, for development)",
clear_local="Clear locally synced commands (does not sync after clearing)" clear_local="Clear locally synced commands (does not sync after clearing)",
) )
@league_admin_only() @league_admin_only()
@logged_command("/admin-sync") @logged_command("/admin-sync")
@ -224,7 +224,7 @@ class AdminCommands(commands.Cog):
self, self,
interaction: discord.Interaction, interaction: discord.Interaction,
local: bool = False, local: bool = False,
clear_local: bool = False clear_local: bool = False,
): ):
"""Sync slash commands with Discord API.""" """Sync slash commands with Discord API."""
await interaction.response.defer() await interaction.response.defer()
@ -235,20 +235,24 @@ class AdminCommands(commands.Cog):
if not interaction.guild_id: if not interaction.guild_id:
raise ValueError("Cannot clear local commands outside of a guild") raise ValueError("Cannot clear local commands outside of a guild")
self.logger.info(f"Clearing local commands for guild {interaction.guild_id}") self.logger.info(
self.bot.tree.clear_commands(guild=discord.Object(id=interaction.guild_id)) f"Clearing local commands for guild {interaction.guild_id}"
)
self.bot.tree.clear_commands(
guild=discord.Object(id=interaction.guild_id)
)
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="✅ Local Commands Cleared", title="✅ Local Commands Cleared",
description=f"Cleared all commands synced to this guild", description=f"Cleared all commands synced to this guild",
color=EmbedColors.SUCCESS color=EmbedColors.SUCCESS,
) )
embed.add_field( embed.add_field(
name="Clear Details", name="Clear Details",
value=f"**Guild ID:** {interaction.guild_id}\n" value=f"**Guild ID:** {interaction.guild_id}\n"
f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}\n" f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}\n"
f"**Note:** Commands not synced after clearing", f"**Note:** Commands not synced after clearing",
inline=False inline=False,
) )
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
return return
@ -270,16 +274,20 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="✅ Commands Synced Successfully", title="✅ Commands Synced Successfully",
description=f"Synced {len(synced_commands)} application commands {sync_type}", description=f"Synced {len(synced_commands)} application commands {sync_type}",
color=EmbedColors.SUCCESS color=EmbedColors.SUCCESS,
) )
# Show some of the synced commands # Show some of the synced commands
command_names = [cmd.name for cmd in synced_commands[:10]] command_names = [cmd.name for cmd in synced_commands[:10]]
embed.add_field( embed.add_field(
name="Synced Commands", name="Synced Commands",
value="\n".join([f"• /{name}" for name in command_names]) + value="\n".join([f"• /{name}" for name in command_names])
(f"\n... and {len(synced_commands) - 10} more" if len(synced_commands) > 10 else ""), + (
inline=False f"\n... and {len(synced_commands) - 10} more"
if len(synced_commands) > 10
else ""
),
inline=False,
) )
embed.add_field( embed.add_field(
@ -288,7 +296,7 @@ class AdminCommands(commands.Cog):
f"**Sync Type:** {sync_type.title()}\n" f"**Sync Type:** {sync_type.title()}\n"
f"**Guild ID:** {interaction.guild_id or 'N/A'}\n" f"**Guild ID:** {interaction.guild_id or 'N/A'}\n"
f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}", f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}",
inline=False inline=False,
) )
except Exception as e: except Exception as e:
@ -296,7 +304,7 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="❌ Sync Failed", title="❌ Sync Failed",
description=f"Failed to sync commands: {str(e)}", description=f"Failed to sync commands: {str(e)}",
color=EmbedColors.ERROR color=EmbedColors.ERROR,
) )
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
@ -310,7 +318,9 @@ class AdminCommands(commands.Cog):
Use this when slash commands aren't synced yet and you can't access /admin-sync. Use this when slash commands aren't synced yet and you can't access /admin-sync.
Syncs to the current guild only (for multi-bot scenarios). Syncs to the current guild only (for multi-bot scenarios).
""" """
self.logger.info(f"Prefix command !admin-sync invoked by {ctx.author} in {ctx.guild}") self.logger.info(
f"Prefix command !admin-sync invoked by {ctx.author} in {ctx.guild}"
)
try: try:
# Sync to current guild only (not globally) for multi-bot scenarios # Sync to current guild only (not globally) for multi-bot scenarios
@ -319,16 +329,20 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="✅ Commands Synced Successfully", title="✅ Commands Synced Successfully",
description=f"Synced {len(synced_commands)} application commands", description=f"Synced {len(synced_commands)} application commands",
color=EmbedColors.SUCCESS color=EmbedColors.SUCCESS,
) )
# Show some of the synced commands # Show some of the synced commands
command_names = [cmd.name for cmd in synced_commands[:10]] command_names = [cmd.name for cmd in synced_commands[:10]]
embed.add_field( embed.add_field(
name="Synced Commands", name="Synced Commands",
value="\n".join([f"• /{name}" for name in command_names]) + value="\n".join([f"• /{name}" for name in command_names])
(f"\n... and {len(synced_commands) - 10} more" if len(synced_commands) > 10 else ""), + (
inline=False f"\n... and {len(synced_commands) - 10} more"
if len(synced_commands) > 10
else ""
),
inline=False,
) )
embed.add_field( embed.add_field(
@ -337,7 +351,7 @@ class AdminCommands(commands.Cog):
f"**Sync Type:** Local Guild\n" f"**Sync Type:** Local Guild\n"
f"**Guild ID:** {ctx.guild.id}\n" f"**Guild ID:** {ctx.guild.id}\n"
f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}", f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}",
inline=False inline=False,
) )
embed.set_footer(text="💡 Use /admin-sync local:True for guild-only sync") embed.set_footer(text="💡 Use /admin-sync local:True for guild-only sync")
@ -347,36 +361,39 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="❌ Sync Failed", title="❌ Sync Failed",
description=f"Failed to sync commands: {str(e)}", description=f"Failed to sync commands: {str(e)}",
color=EmbedColors.ERROR color=EmbedColors.ERROR,
) )
await ctx.send(embed=embed) await ctx.send(embed=embed)
@app_commands.command( @app_commands.command(
name="admin-clear", name="admin-clear", description="Clear messages from the current channel"
description="Clear messages from the current channel"
)
@app_commands.describe(
count="Number of messages to delete (1-100)"
) )
@app_commands.describe(count="Number of messages to delete (1-100)")
@league_admin_only() @league_admin_only()
@logged_command("/admin-clear") @logged_command("/admin-clear")
async def admin_clear(self, interaction: discord.Interaction, count: int): async def admin_clear(self, interaction: discord.Interaction, count: int):
"""Clear a specified number of messages from the channel.""" """Clear a specified number of messages from the channel."""
if count < 1 or count > 100: if count < 1 or count > 100:
await interaction.response.send_message( await interaction.response.send_message(
"❌ Count must be between 1 and 100.", "❌ Count must be between 1 and 100.", ephemeral=True
ephemeral=True
) )
return return
await interaction.response.defer() await interaction.response.defer()
# Verify channel type supports purge # Verify channel type supports purge
if not isinstance(interaction.channel, (discord.TextChannel, discord.Thread, discord.VoiceChannel, discord.StageChannel)): if not isinstance(
interaction.channel,
(
discord.TextChannel,
discord.Thread,
discord.VoiceChannel,
discord.StageChannel,
),
):
await interaction.followup.send( await interaction.followup.send(
"❌ Cannot purge messages in this channel type.", "❌ Cannot purge messages in this channel type.", ephemeral=True
ephemeral=True
) )
return return
@ -386,7 +403,7 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="🗑️ Messages Cleared", title="🗑️ Messages Cleared",
description=f"Successfully deleted {len(deleted)} messages", description=f"Successfully deleted {len(deleted)} messages",
color=EmbedColors.SUCCESS color=EmbedColors.SUCCESS,
) )
embed.add_field( embed.add_field(
@ -395,7 +412,7 @@ class AdminCommands(commands.Cog):
f"**Channel:** {interaction.channel.mention}\n" f"**Channel:** {interaction.channel.mention}\n"
f"**Requested:** {count} messages\n" f"**Requested:** {count} messages\n"
f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}", f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}",
inline=False inline=False,
) )
# Send confirmation and auto-delete after 5 seconds # Send confirmation and auto-delete after 5 seconds
@ -409,22 +426,19 @@ class AdminCommands(commands.Cog):
except discord.Forbidden: except discord.Forbidden:
await interaction.followup.send( await interaction.followup.send(
"❌ Missing permissions to delete messages.", "❌ Missing permissions to delete messages.", ephemeral=True
ephemeral=True
) )
except Exception as e: except Exception as e:
await interaction.followup.send( await interaction.followup.send(
f"❌ Failed to clear messages: {str(e)}", f"❌ Failed to clear messages: {str(e)}", ephemeral=True
ephemeral=True
) )
@app_commands.command( @app_commands.command(
name="admin-announce", name="admin-announce", description="Send an announcement to the current channel"
description="Send an announcement to the current channel"
) )
@app_commands.describe( @app_commands.describe(
message="Announcement message to send", message="Announcement message to send",
mention_everyone="Whether to mention @everyone (default: False)" mention_everyone="Whether to mention @everyone (default: False)",
) )
@league_admin_only() @league_admin_only()
@logged_command("/admin-announce") @logged_command("/admin-announce")
@ -432,7 +446,7 @@ class AdminCommands(commands.Cog):
self, self,
interaction: discord.Interaction, interaction: discord.Interaction,
message: str, message: str,
mention_everyone: bool = False mention_everyone: bool = False,
): ):
"""Send an official announcement to the channel.""" """Send an official announcement to the channel."""
await interaction.response.defer() await interaction.response.defer()
@ -440,12 +454,12 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title="📢 League Announcement", title="📢 League Announcement",
description=message, description=message,
color=EmbedColors.PRIMARY color=EmbedColors.PRIMARY,
) )
embed.set_footer( embed.set_footer(
text=f"Announcement by {interaction.user.display_name}", text=f"Announcement by {interaction.user.display_name}",
icon_url=interaction.user.display_avatar.url icon_url=interaction.user.display_avatar.url,
) )
# Send with or without mention based on flag # Send with or without mention based on flag
@ -460,33 +474,30 @@ class AdminCommands(commands.Cog):
) )
@app_commands.command( @app_commands.command(
name="admin-maintenance", name="admin-maintenance", description="Toggle maintenance mode for the bot"
description="Toggle maintenance mode for the bot"
) )
@app_commands.describe( @app_commands.describe(mode="Turn maintenance mode on or off")
mode="Turn maintenance mode on or off" @app_commands.choices(
) mode=[
@app_commands.choices(mode=[
app_commands.Choice(name="On", value="on"), app_commands.Choice(name="On", value="on"),
app_commands.Choice(name="Off", value="off") app_commands.Choice(name="Off", value="off"),
]) ]
)
@league_admin_only() @league_admin_only()
@logged_command("/admin-maintenance") @logged_command("/admin-maintenance")
async def admin_maintenance(self, interaction: discord.Interaction, mode: str): async def admin_maintenance(self, interaction: discord.Interaction, mode: str):
"""Toggle maintenance mode to prevent normal command usage.""" """Toggle maintenance mode to prevent normal command usage."""
await interaction.response.defer() await interaction.response.defer()
# This would typically set a global flag or database value
# For now, we'll just show the interface
is_enabling = mode.lower() == "on" is_enabling = mode.lower() == "on"
self.bot.maintenance_mode = is_enabling # type: ignore[attr-defined]
status_text = "enabled" if is_enabling else "disabled" status_text = "enabled" if is_enabling else "disabled"
color = EmbedColors.WARNING if is_enabling else EmbedColors.SUCCESS color = EmbedColors.WARNING if is_enabling else EmbedColors.SUCCESS
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"🔧 Maintenance Mode {status_text.title()}", title=f"🔧 Maintenance Mode {status_text.title()}",
description=f"Maintenance mode has been **{status_text}**", description=f"Maintenance mode has been **{status_text}**",
color=color color=color,
) )
if is_enabling: if is_enabling:
@ -495,7 +506,7 @@ class AdminCommands(commands.Cog):
value="• Normal commands are disabled\n" value="• Normal commands are disabled\n"
"• Only admin commands are available\n" "• Only admin commands are available\n"
"• Users will see maintenance message", "• Users will see maintenance message",
inline=False inline=False,
) )
else: else:
embed.add_field( embed.add_field(
@ -503,7 +514,7 @@ class AdminCommands(commands.Cog):
value="• All commands are now available\n" value="• All commands are now available\n"
"• Normal bot operation resumed\n" "• Normal bot operation resumed\n"
"• Users can access all features", "• Users can access all features",
inline=False inline=False,
) )
embed.add_field( embed.add_field(
@ -511,14 +522,14 @@ class AdminCommands(commands.Cog):
value=f"**Changed by:** {interaction.user.mention}\n" value=f"**Changed by:** {interaction.user.mention}\n"
f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}\n" f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}\n"
f"**Mode:** {status_text.title()}", f"**Mode:** {status_text.title()}",
inline=False inline=False,
) )
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
@app_commands.command( @app_commands.command(
name="admin-clear-scorecards", name="admin-clear-scorecards",
description="Manually clear the live scorebug channel and hide it from members" description="Manually clear the live scorebug channel and hide it from members",
) )
@league_admin_only() @league_admin_only()
@logged_command("/admin-clear-scorecards") @logged_command("/admin-clear-scorecards")
@ -539,17 +550,17 @@ class AdminCommands(commands.Cog):
if not guild: if not guild:
await interaction.followup.send( await interaction.followup.send(
"❌ Could not find guild. Check configuration.", "❌ Could not find guild. Check configuration.", ephemeral=True
ephemeral=True
) )
return return
live_scores_channel = discord.utils.get(guild.text_channels, name='live-sba-scores') live_scores_channel = discord.utils.get(
guild.text_channels, name="live-sba-scores"
)
if not live_scores_channel: if not live_scores_channel:
await interaction.followup.send( await interaction.followup.send(
"❌ Could not find #live-sba-scores channel.", "❌ Could not find #live-sba-scores channel.", ephemeral=True
ephemeral=True
) )
return return
@ -569,7 +580,7 @@ class AdminCommands(commands.Cog):
visibility_success = await set_channel_visibility( visibility_success = await set_channel_visibility(
live_scores_channel, live_scores_channel,
visible=False, visible=False,
reason="Admin manual clear via /admin-clear-scorecards" reason="Admin manual clear via /admin-clear-scorecards",
) )
if visibility_success: if visibility_success:
@ -580,7 +591,7 @@ class AdminCommands(commands.Cog):
# Create success embed # Create success embed
embed = EmbedTemplate.success( embed = EmbedTemplate.success(
title="Live Scorebug Channel Cleared", title="Live Scorebug Channel Cleared",
description="Successfully cleared the #live-sba-scores channel" description="Successfully cleared the #live-sba-scores channel",
) )
embed.add_field( embed.add_field(
@ -589,7 +600,7 @@ class AdminCommands(commands.Cog):
f"**Messages Deleted:** {deleted_count}\n" f"**Messages Deleted:** {deleted_count}\n"
f"**Visibility:** {visibility_status}\n" f"**Visibility:** {visibility_status}\n"
f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}", f"**Time:** {discord.utils.utcnow().strftime('%H:%M:%S UTC')}",
inline=False inline=False,
) )
embed.add_field( embed.add_field(
@ -598,7 +609,7 @@ class AdminCommands(commands.Cog):
"• Bot retains full access to the channel\n" "• Bot retains full access to the channel\n"
"• Channel will auto-show when games are published\n" "• Channel will auto-show when games are published\n"
"• Live scorebug tracker runs every 3 minutes", "• Live scorebug tracker runs every 3 minutes",
inline=False inline=False,
) )
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
@ -606,18 +617,17 @@ class AdminCommands(commands.Cog):
except discord.Forbidden: except discord.Forbidden:
await interaction.followup.send( await interaction.followup.send(
"❌ Missing permissions to clear messages or modify channel permissions.", "❌ Missing permissions to clear messages or modify channel permissions.",
ephemeral=True ephemeral=True,
) )
except Exception as e: except Exception as e:
self.logger.error(f"Error clearing scorecards: {e}", exc_info=True) self.logger.error(f"Error clearing scorecards: {e}", exc_info=True)
await interaction.followup.send( await interaction.followup.send(
f"❌ Failed to clear channel: {str(e)}", f"❌ Failed to clear channel: {str(e)}", ephemeral=True
ephemeral=True
) )
@app_commands.command( @app_commands.command(
name="admin-process-transactions", name="admin-process-transactions",
description="[ADMIN] Manually process all transactions for the current week (or specified week)" description="[ADMIN] Manually process all transactions for the current week (or specified week)",
) )
@app_commands.describe( @app_commands.describe(
week="Week number to process (optional, defaults to current week)" week="Week number to process (optional, defaults to current week)"
@ -625,9 +635,7 @@ class AdminCommands(commands.Cog):
@league_admin_only() @league_admin_only()
@logged_command("/admin-process-transactions") @logged_command("/admin-process-transactions")
async def admin_process_transactions( async def admin_process_transactions(
self, self, interaction: discord.Interaction, week: int | None = None
interaction: discord.Interaction,
week: int | None = None
): ):
""" """
Manually process all transactions for the current week. Manually process all transactions for the current week.
@ -649,7 +657,7 @@ class AdminCommands(commands.Cog):
if not current: if not current:
await interaction.followup.send( await interaction.followup.send(
"❌ Could not get current league state from the API.", "❌ Could not get current league state from the API.",
ephemeral=True ephemeral=True,
) )
return return
@ -659,22 +667,24 @@ class AdminCommands(commands.Cog):
self.logger.info( self.logger.info(
f"Processing transactions for week {target_week}, season {target_season}", f"Processing transactions for week {target_week}, season {target_season}",
requested_by=str(interaction.user) requested_by=str(interaction.user),
) )
# Get all non-frozen, non-cancelled transactions for the target week using service layer # Get all non-frozen, non-cancelled transactions for the target week using service layer
transactions = await transaction_service.get_all_items(params=[ transactions = await transaction_service.get_all_items(
('season', str(target_season)), params=[
('week_start', str(target_week)), ("season", str(target_season)),
('week_end', str(target_week)), ("week_start", str(target_week)),
('frozen', 'false'), ("week_end", str(target_week)),
('cancelled', 'false') ("frozen", "false"),
]) ("cancelled", "false"),
]
)
if not transactions: if not transactions:
embed = EmbedTemplate.info( embed = EmbedTemplate.info(
title="No Transactions to Process", title="No Transactions to Process",
description=f"No non-frozen, non-cancelled transactions found for Week {target_week}" description=f"No non-frozen, non-cancelled transactions found for Week {target_week}",
) )
embed.add_field( embed.add_field(
@ -683,7 +693,7 @@ class AdminCommands(commands.Cog):
f"**Week:** {target_week}\n" f"**Week:** {target_week}\n"
f"**Frozen:** No\n" f"**Frozen:** No\n"
f"**Cancelled:** No", f"**Cancelled:** No",
inline=False inline=False,
) )
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
@ -692,7 +702,9 @@ class AdminCommands(commands.Cog):
# Count total transactions # Count total transactions
total_count = len(transactions) total_count = len(transactions)
self.logger.info(f"Found {total_count} transactions to process for week {target_week}") self.logger.info(
f"Found {total_count} transactions to process for week {target_week}"
)
# Process each transaction # Process each transaction
success_count = 0 success_count = 0
@ -702,12 +714,10 @@ class AdminCommands(commands.Cog):
# Create initial status embed # Create initial status embed
processing_embed = EmbedTemplate.loading( processing_embed = EmbedTemplate.loading(
title="Processing Transactions", title="Processing Transactions",
description=f"Processing {total_count} transactions for Week {target_week}..." description=f"Processing {total_count} transactions for Week {target_week}...",
) )
processing_embed.add_field( processing_embed.add_field(
name="Progress", name="Progress", value="Starting...", inline=False
value="Starting...",
inline=False
) )
status_message = await interaction.followup.send(embed=processing_embed) status_message = await interaction.followup.send(embed=processing_embed)
@ -718,7 +728,7 @@ class AdminCommands(commands.Cog):
await self._execute_player_update( await self._execute_player_update(
player_id=transaction.player.id, player_id=transaction.player.id,
new_team_id=transaction.newteam.id, new_team_id=transaction.newteam.id,
player_name=transaction.player.name player_name=transaction.player.name,
) )
success_count += 1 success_count += 1
@ -731,7 +741,7 @@ class AdminCommands(commands.Cog):
value=f"Processed {idx}/{total_count} transactions\n" value=f"Processed {idx}/{total_count} transactions\n"
f"✅ Successful: {success_count}\n" f"✅ Successful: {success_count}\n"
f"❌ Failed: {failure_count}", f"❌ Failed: {failure_count}",
inline=False inline=False,
) )
await status_message.edit(embed=processing_embed) # type: ignore await status_message.edit(embed=processing_embed) # type: ignore
@ -741,35 +751,35 @@ class AdminCommands(commands.Cog):
except Exception as e: except Exception as e:
failure_count += 1 failure_count += 1
error_info = { error_info = {
'player': transaction.player.name, "player": transaction.player.name,
'player_id': transaction.player.id, "player_id": transaction.player.id,
'new_team': transaction.newteam.abbrev, "new_team": transaction.newteam.abbrev,
'error': str(e) "error": str(e),
} }
errors.append(error_info) errors.append(error_info)
self.logger.error( self.logger.error(
f"Failed to execute transaction for {error_info['player']}", f"Failed to execute transaction for {error_info['player']}",
player_id=error_info['player_id'], player_id=error_info["player_id"],
new_team=error_info['new_team'], new_team=error_info["new_team"],
error=e error=e,
) )
# Create completion embed # Create completion embed
if failure_count == 0: if failure_count == 0:
completion_embed = EmbedTemplate.success( completion_embed = EmbedTemplate.success(
title="Transactions Processed Successfully", title="Transactions Processed Successfully",
description=f"All {total_count} transactions for Week {target_week} have been processed." description=f"All {total_count} transactions for Week {target_week} have been processed.",
) )
elif success_count == 0: elif success_count == 0:
completion_embed = EmbedTemplate.error( completion_embed = EmbedTemplate.error(
title="Transaction Processing Failed", title="Transaction Processing Failed",
description=f"Failed to process all {total_count} transactions for Week {target_week}." description=f"Failed to process all {total_count} transactions for Week {target_week}.",
) )
else: else:
completion_embed = EmbedTemplate.warning( completion_embed = EmbedTemplate.warning(
title="Transactions Partially Processed", title="Transactions Partially Processed",
description=f"Some transactions for Week {target_week} failed to process." description=f"Some transactions for Week {target_week} failed to process.",
) )
completion_embed.add_field( completion_embed.add_field(
@ -779,7 +789,7 @@ class AdminCommands(commands.Cog):
f"**❌ Failed:** {failure_count}\n" f"**❌ Failed:** {failure_count}\n"
f"**Week:** {target_week}\n" f"**Week:** {target_week}\n"
f"**Season:** {target_season}", f"**Season:** {target_season}",
inline=False inline=False,
) )
# Add error details if there were failures # Add error details if there were failures
@ -792,9 +802,7 @@ class AdminCommands(commands.Cog):
error_text += f"\n... and {len(errors) - 5} more errors" error_text += f"\n... and {len(errors) - 5} more errors"
completion_embed.add_field( completion_embed.add_field(
name="Errors", name="Errors", value=error_text, inline=False
value=error_text,
inline=False
) )
completion_embed.add_field( completion_embed.add_field(
@ -802,7 +810,7 @@ class AdminCommands(commands.Cog):
value="• Verify transactions in the database\n" value="• Verify transactions in the database\n"
"• Check #transaction-log channel for posted moves\n" "• Check #transaction-log channel for posted moves\n"
"• Review any errors and retry if necessary", "• Review any errors and retry if necessary",
inline=False inline=False,
) )
completion_embed.set_footer( completion_embed.set_footer(
@ -816,7 +824,7 @@ class AdminCommands(commands.Cog):
f"Transaction processing complete for week {target_week}", f"Transaction processing complete for week {target_week}",
success=success_count, success=success_count,
failures=failure_count, failures=failure_count,
total=total_count total=total_count,
) )
except Exception as e: except Exception as e:
@ -824,16 +832,13 @@ class AdminCommands(commands.Cog):
embed = EmbedTemplate.error( embed = EmbedTemplate.error(
title="Transaction Processing Failed", title="Transaction Processing Failed",
description=f"An error occurred while processing transactions: {str(e)}" description=f"An error occurred while processing transactions: {str(e)}",
) )
await interaction.followup.send(embed=embed, ephemeral=True) await interaction.followup.send(embed=embed, ephemeral=True)
async def _execute_player_update( async def _execute_player_update(
self, self, player_id: int, new_team_id: int, player_name: str
player_id: int,
new_team_id: int,
player_name: str
) -> bool: ) -> bool:
""" """
Execute a player roster update via service layer. Execute a player roster update via service layer.
@ -854,13 +859,12 @@ class AdminCommands(commands.Cog):
f"Updating player roster", f"Updating player roster",
player_id=player_id, player_id=player_id,
player_name=player_name, player_name=player_name,
new_team_id=new_team_id new_team_id=new_team_id,
) )
# Execute player team update via service layer # Execute player team update via service layer
updated_player = await player_service.update_player_team( updated_player = await player_service.update_player_team(
player_id=player_id, player_id=player_id, new_team_id=new_team_id
new_team_id=new_team_id
) )
# Verify update was successful # Verify update was successful
@ -869,7 +873,7 @@ class AdminCommands(commands.Cog):
f"Successfully updated player roster", f"Successfully updated player roster",
player_id=player_id, player_id=player_id,
player_name=player_name, player_name=player_name,
new_team_id=new_team_id new_team_id=new_team_id,
) )
return True return True
else: else:
@ -877,7 +881,7 @@ class AdminCommands(commands.Cog):
f"Player update returned no response", f"Player update returned no response",
player_id=player_id, player_id=player_id,
player_name=player_name, player_name=player_name,
new_team_id=new_team_id new_team_id=new_team_id,
) )
return False return False
@ -888,7 +892,7 @@ class AdminCommands(commands.Cog):
player_name=player_name, player_name=player_name,
new_team_id=new_team_id, new_team_id=new_team_id,
error=e, error=e,
exc_info=True exc_info=True,
) )
raise raise

View File

@ -0,0 +1,282 @@
"""
Tests for MaintenanceAwareTree and the maintenance_mode attribute on SBABot.
What:
Verifies that the CommandTree subclass correctly gates interactions behind
bot.maintenance_mode. When maintenance mode is off every interaction is
allowed through unconditionally. When maintenance mode is on, non-admin
users receive an ephemeral error message and the check returns False, while
administrators are always allowed through.
Why:
The original code attempted to register an interaction_check via a decorator
on self.tree inside setup_hook. That is not a valid pattern in discord.py
interaction_check is an overridable async method on CommandTree, not a
decorator. The broken assignment caused a RuntimeWarning (unawaited
coroutine) and silently made maintenance mode a no-op. These tests confirm
the correct subclass-based implementation behaves as specified.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
import discord
# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------
def _make_bot(maintenance_mode: bool = False) -> MagicMock:
"""Return a minimal mock bot with a maintenance_mode attribute."""
bot = MagicMock()
bot.maintenance_mode = maintenance_mode
return bot
def _make_interaction(is_admin: bool, bot: MagicMock) -> AsyncMock:
"""
Build a mock discord.Interaction.
The interaction's .client is set to *bot* so that MaintenanceAwareTree
can read bot.maintenance_mode via interaction.client, mirroring how
discord.py wires things at runtime.
"""
interaction = AsyncMock(spec=discord.Interaction)
interaction.client = bot
# Mock the user as a guild Member so that guild_permissions is accessible.
user = MagicMock(spec=discord.Member)
user.guild_permissions = MagicMock()
user.guild_permissions.administrator = is_admin
interaction.user = user
# response.send_message must be awaitable.
interaction.response = AsyncMock()
interaction.response.send_message = AsyncMock()
return interaction
# ---------------------------------------------------------------------------
# Import the class under test after mocks are available.
# We import here (not at module level) so that the conftest env-vars are set
# before any discord_bot_v2 modules are touched.
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _patch_discord_app_commands(monkeypatch):
"""
Prevent MaintenanceAwareTree.__init__ from calling discord internals that
need a real event loop / Discord connection. We test only the logic of
interaction_check, so we stub out the parent __init__.
"""
# Nothing extra to patch for the interaction_check itself; the parent
# CommandTree.__init__ is only called when constructing SBABot, which we
# don't do in these unit tests.
yield
# ---------------------------------------------------------------------------
# Tests for MaintenanceAwareTree.interaction_check
# ---------------------------------------------------------------------------
class TestMaintenanceAwareTree:
"""Unit tests for MaintenanceAwareTree.interaction_check."""
@pytest.fixture
def tree_cls(self):
"""Import and return the MaintenanceAwareTree class."""
from bot import MaintenanceAwareTree
return MaintenanceAwareTree
# ------------------------------------------------------------------
# Maintenance OFF
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_off_allows_non_admin(self, tree_cls):
"""
When maintenance_mode is False, non-admin users are always allowed.
The check must return True without sending any message.
"""
bot = _make_bot(maintenance_mode=False)
interaction = _make_interaction(is_admin=False, bot=bot)
# Instantiate tree without calling parent __init__ by testing the method
# directly on an unbound basis.
result = await tree_cls.interaction_check(
MagicMock(), # placeholder 'self' for the tree instance
interaction,
)
assert result is True
interaction.response.send_message.assert_not_called()
@pytest.mark.asyncio
async def test_maintenance_off_allows_admin(self, tree_cls):
"""
When maintenance_mode is False, admin users are also always allowed.
"""
bot = _make_bot(maintenance_mode=False)
interaction = _make_interaction(is_admin=True, bot=bot)
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is True
interaction.response.send_message.assert_not_called()
# ------------------------------------------------------------------
# Maintenance ON — non-admin
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_on_blocks_non_admin(self, tree_cls):
"""
When maintenance_mode is True, non-admin users must be blocked.
The check must return False and send an ephemeral message.
"""
bot = _make_bot(maintenance_mode=True)
interaction = _make_interaction(is_admin=False, bot=bot)
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is False
interaction.response.send_message.assert_called_once()
# Confirm the call used ephemeral=True
_, kwargs = interaction.response.send_message.call_args
assert kwargs.get("ephemeral") is True
@pytest.mark.asyncio
async def test_maintenance_on_message_has_no_emoji(self, tree_cls):
"""
The maintenance block message must not contain emoji characters.
The project style deliberately strips emoji from user-facing strings.
"""
import unicodedata
bot = _make_bot(maintenance_mode=True)
interaction = _make_interaction(is_admin=False, bot=bot)
await tree_cls.interaction_check(MagicMock(), interaction)
args, _ = interaction.response.send_message.call_args
message_text = args[0] if args else ""
for ch in message_text:
category = unicodedata.category(ch)
assert category != "So", (
f"Unexpected emoji/symbol character {ch!r} (category {category!r}) "
f"found in maintenance message: {message_text!r}"
)
# ------------------------------------------------------------------
# Maintenance ON — admin
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_on_allows_admin(self, tree_cls):
"""
When maintenance_mode is True, administrator users must still be
allowed through. Admins should never be locked out of bot commands.
"""
bot = _make_bot(maintenance_mode=True)
interaction = _make_interaction(is_admin=True, bot=bot)
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is True
interaction.response.send_message.assert_not_called()
# ------------------------------------------------------------------
# Edge case: non-Member user during maintenance
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_maintenance_on_blocks_non_member_user(self, tree_cls):
"""
When maintenance_mode is True and the user is not a guild Member
(e.g. interaction from a DM context), the check must still block them
because we cannot verify administrator status.
"""
bot = _make_bot(maintenance_mode=True)
interaction = AsyncMock(spec=discord.Interaction)
interaction.client = bot
# Simulate a non-Member user (e.g. discord.User from DM context)
user = MagicMock(spec=discord.User)
# discord.User has no guild_permissions attribute
interaction.user = user
interaction.response = AsyncMock()
interaction.response.send_message = AsyncMock()
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is False
interaction.response.send_message.assert_called_once()
# ------------------------------------------------------------------
# Missing attribute safety: bot without maintenance_mode attr
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_missing_maintenance_mode_attr_defaults_to_allowed(self, tree_cls):
"""
If the bot object does not have a maintenance_mode attribute (e.g.
during testing or unusual startup), getattr fallback must treat it as
False and allow the interaction.
"""
bot = MagicMock()
# Deliberately do NOT set bot.maintenance_mode
del bot.maintenance_mode
interaction = _make_interaction(is_admin=False, bot=bot)
result = await tree_cls.interaction_check(MagicMock(), interaction)
assert result is True
# ---------------------------------------------------------------------------
# Tests for SBABot.maintenance_mode attribute
# ---------------------------------------------------------------------------
class TestSBABotMaintenanceModeAttribute:
"""
Confirms that SBABot.__init__ always sets maintenance_mode = False.
We avoid constructing a real SBABot (which requires a Discord event loop
and valid token infrastructure) by patching the entire commands.Bot.__init__
and then calling SBABot.__init__ directly on a bare instance so that only
the SBABot-specific attribute assignments execute.
"""
def test_maintenance_mode_default_is_false(self, monkeypatch):
"""
SBABot.__init__ must set self.maintenance_mode = False so that the
MaintenanceAwareTree has the attribute available from the very first
interaction, even before /admin-maintenance is ever called.
Strategy: patch commands.Bot.__init__ to be a no-op so super().__init__
succeeds without a real Discord connection, then call SBABot.__init__
and assert the attribute is present with the correct default value.
"""
from unittest.mock import patch
from discord.ext import commands
from bot import SBABot
with patch.object(commands.Bot, "__init__", return_value=None):
bot = SBABot.__new__(SBABot)
SBABot.__init__(bot)
assert hasattr(
bot, "maintenance_mode"
), "SBABot must define self.maintenance_mode in __init__"
assert (
bot.maintenance_mode is False
), "SBABot.maintenance_mode must default to False"