feat: add is_admin() helper to utils/permissions.py (#55)
All checks were successful
Build Docker Image / build (pull_request) Successful in 1m9s

Add centralized `is_admin(interaction)` helper that includes the
`isinstance(interaction.user, discord.Member)` guard, preventing
AttributeError in DM contexts.

Use it in `can_edit_player_image()` which previously accessed
`guild_permissions.administrator` directly without the isinstance
guard. Update the corresponding test to mock the user with
`spec=discord.Member` so the isinstance check passes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-03-04 22:33:31 -06:00
parent f7a65706a1
commit ed40b532b5
3 changed files with 154 additions and 108 deletions

View File

@ -4,6 +4,7 @@ Player Image Management Commands
Allows users to update player fancy card and headshot images for players Allows users to update player fancy card and headshot images for players
on teams they own. Admins can update any player's images. on teams they own. Admins can update any player's images.
""" """
from typing import List, Tuple from typing import List, Tuple
import asyncio import asyncio
import aiohttp import aiohttp
@ -20,10 +21,11 @@ from utils.decorators import logged_command
from views.embeds import EmbedColors, EmbedTemplate from views.embeds import EmbedColors, EmbedTemplate
from views.base import BaseView from views.base import BaseView
from models.player import Player from models.player import Player
from utils.permissions import is_admin
# URL Validation Functions # URL Validation Functions
def validate_url_format(url: str) -> Tuple[bool, str]: def validate_url_format(url: str) -> Tuple[bool, str]:
""" """
Validate URL format for image links. Validate URL format for image links.
@ -40,17 +42,20 @@ def validate_url_format(url: str) -> Tuple[bool, str]:
return False, "URL too long (max 500 characters)" return False, "URL too long (max 500 characters)"
# Protocol check # Protocol check
if not url.startswith(('http://', 'https://')): if not url.startswith(("http://", "https://")):
return False, "URL must start with http:// or https://" return False, "URL must start with http:// or https://"
# Image extension check # Image extension check
valid_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp') valid_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp")
url_lower = url.lower() url_lower = url.lower()
# Check if URL ends with valid extension (ignore query params) # Check if URL ends with valid extension (ignore query params)
base_url = url_lower.split('?')[0] # Remove query parameters base_url = url_lower.split("?")[0] # Remove query parameters
if not any(base_url.endswith(ext) for ext in valid_extensions): if not any(base_url.endswith(ext) for ext in valid_extensions):
return False, f"URL must end with a valid image extension: {', '.join(valid_extensions)}" return (
False,
f"URL must end with a valid image extension: {', '.join(valid_extensions)}",
)
return True, "" return True, ""
@ -68,14 +73,19 @@ async def check_url_accessibility(url: str) -> Tuple[bool, str]:
""" """
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.head(url, timeout=aiohttp.ClientTimeout(total=5)) as response: async with session.head(
url, timeout=aiohttp.ClientTimeout(total=5)
) as response:
if response.status != 200: if response.status != 200:
return False, f"URL returned status {response.status}" return False, f"URL returned status {response.status}"
# Check content-type header # Check content-type header
content_type = response.headers.get('content-type', '').lower() content_type = response.headers.get("content-type", "").lower()
if content_type and not content_type.startswith('image/'): if content_type and not content_type.startswith("image/"):
return False, f"URL does not return an image (content-type: {content_type})" return (
False,
f"URL does not return an image (content-type: {content_type})",
)
return True, "" return True, ""
@ -89,11 +99,9 @@ async def check_url_accessibility(url: str) -> Tuple[bool, str]:
# Permission Checking # Permission Checking
async def can_edit_player_image( async def can_edit_player_image(
interaction: discord.Interaction, interaction: discord.Interaction, player: Player, season: int, logger
player: Player,
season: int,
logger
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
""" """
Check if user can edit player's image. Check if user can edit player's image.
@ -109,7 +117,7 @@ async def can_edit_player_image(
If has permission, error_message is empty string If has permission, error_message is empty string
""" """
# Admins can edit anyone # Admins can edit anyone
if interaction.user.guild_permissions.administrator: if is_admin(interaction):
logger.debug("User is admin, granting permission", user_id=interaction.user.id) logger.debug("User is admin, granting permission", user_id=interaction.user.id)
return True, "" return True, ""
@ -130,7 +138,7 @@ async def can_edit_player_image(
"User owns organization, granting permission", "User owns organization, granting permission",
user_id=interaction.user.id, user_id=interaction.user.id,
user_team=user_team.abbrev, user_team=user_team.abbrev,
player_team=player.team.abbrev player_team=player.team.abbrev,
) )
return True, "" return True, ""
@ -141,6 +149,7 @@ async def can_edit_player_image(
# Confirmation View # Confirmation View
class ImageUpdateConfirmView(BaseView): class ImageUpdateConfirmView(BaseView):
"""Confirmation view showing image preview before updating.""" """Confirmation view showing image preview before updating."""
@ -151,27 +160,33 @@ class ImageUpdateConfirmView(BaseView):
self.image_type = image_type self.image_type = image_type
self.confirmed = False self.confirmed = False
@discord.ui.button(label="Confirm Update", style=discord.ButtonStyle.success, emoji="") @discord.ui.button(
async def confirm_button(self, interaction: discord.Interaction, button: discord.ui.Button): label="Confirm Update", style=discord.ButtonStyle.success, emoji=""
)
async def confirm_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Confirm the image update.""" """Confirm the image update."""
self.confirmed = True self.confirmed = True
# Disable all buttons # Disable all buttons
for item in self.children: for item in self.children:
if hasattr(item, 'disabled'): if hasattr(item, "disabled"):
item.disabled = True # type: ignore item.disabled = True # type: ignore
await interaction.response.edit_message(view=self) await interaction.response.edit_message(view=self)
self.stop() self.stop()
@discord.ui.button(label="Cancel", style=discord.ButtonStyle.secondary, emoji="") @discord.ui.button(label="Cancel", style=discord.ButtonStyle.secondary, emoji="")
async def cancel_button(self, interaction: discord.Interaction, button: discord.ui.Button): async def cancel_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Cancel the image update.""" """Cancel the image update."""
self.confirmed = False self.confirmed = False
# Disable all buttons # Disable all buttons
for item in self.children: for item in self.children:
if hasattr(item, 'disabled'): if hasattr(item, "disabled"):
item.disabled = True # type: ignore item.disabled = True # type: ignore
await interaction.response.edit_message(view=self) await interaction.response.edit_message(view=self)
@ -180,6 +195,7 @@ class ImageUpdateConfirmView(BaseView):
# Autocomplete # Autocomplete
async def player_name_autocomplete( async def player_name_autocomplete(
interaction: discord.Interaction, interaction: discord.Interaction,
current: str, current: str,
@ -191,6 +207,7 @@ async def player_name_autocomplete(
try: try:
# Use the shared autocomplete utility with team prioritization # Use the shared autocomplete utility with team prioritization
from utils.autocomplete import player_autocomplete from utils.autocomplete import player_autocomplete
return await player_autocomplete(interaction, current) return await player_autocomplete(interaction, current)
except Exception: except Exception:
# Return empty list on error to avoid breaking autocomplete # Return empty list on error to avoid breaking autocomplete
@ -199,27 +216,29 @@ async def player_name_autocomplete(
# Main Command Cog # Main Command Cog
class ImageCommands(commands.Cog): class ImageCommands(commands.Cog):
"""Player image management command handlers.""" """Player image management command handlers."""
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.ImageCommands') self.logger = get_contextual_logger(f"{__name__}.ImageCommands")
self.logger.info("ImageCommands cog initialized") self.logger.info("ImageCommands cog initialized")
@app_commands.command( @app_commands.command(
name="set-image", name="set-image", description="Update a player's fancy card or headshot image"
description="Update a player's fancy card or headshot image"
) )
@app_commands.describe( @app_commands.describe(
image_type="Type of image to update", image_type="Type of image to update",
player_name="Player name (use autocomplete)", player_name="Player name (use autocomplete)",
image_url="Direct URL to the image file" image_url="Direct URL to the image file",
)
@app_commands.choices(
image_type=[
app_commands.Choice(name="Fancy Card", value="fancy-card"),
app_commands.Choice(name="Headshot", value="headshot"),
]
) )
@app_commands.choices(image_type=[
app_commands.Choice(name="Fancy Card", value="fancy-card"),
app_commands.Choice(name="Headshot", value="headshot")
])
@app_commands.autocomplete(player_name=player_name_autocomplete) @app_commands.autocomplete(player_name=player_name_autocomplete)
@logged_command("/set-image") @logged_command("/set-image")
async def set_image( async def set_image(
@ -227,7 +246,7 @@ class ImageCommands(commands.Cog):
interaction: discord.Interaction, interaction: discord.Interaction,
image_type: app_commands.Choice[str], image_type: app_commands.Choice[str],
player_name: str, player_name: str,
image_url: str image_url: str,
): ):
"""Update a player's image (fancy card or headshot).""" """Update a player's image (fancy card or headshot)."""
# Defer response for potentially slow operations # Defer response for potentially slow operations
@ -242,7 +261,7 @@ class ImageCommands(commands.Cog):
"Image update requested", "Image update requested",
user_id=interaction.user.id, user_id=interaction.user.id,
player_name=player_name, player_name=player_name,
image_type=img_type image_type=img_type,
) )
# Step 1: Validate URL format # Step 1: Validate URL format
@ -252,10 +271,10 @@ class ImageCommands(commands.Cog):
embed = EmbedTemplate.error( embed = EmbedTemplate.error(
title="Invalid URL Format", title="Invalid URL Format",
description=f"{format_error}\n\n" description=f"{format_error}\n\n"
f"**Requirements:**\n" f"**Requirements:**\n"
f"• Must start with `http://` or `https://`\n" f"• Must start with `http://` or `https://`\n"
f"• Must end with `.jpg`, `.jpeg`, `.png`, `.gif`, or `.webp`\n" f"• Must end with `.jpg`, `.jpeg`, `.png`, `.gif`, or `.webp`\n"
f"• Maximum 500 characters" f"• Maximum 500 characters",
) )
await interaction.followup.send(embed=embed, ephemeral=True) await interaction.followup.send(embed=embed, ephemeral=True)
return return
@ -268,24 +287,26 @@ class ImageCommands(commands.Cog):
embed = EmbedTemplate.error( embed = EmbedTemplate.error(
title="URL Not Accessible", title="URL Not Accessible",
description=f"{access_error}\n\n" description=f"{access_error}\n\n"
f"**Please check:**\n" f"**Please check:**\n"
f"• URL is correct and not expired\n" f"• URL is correct and not expired\n"
f"• Image host is online\n" f"• Image host is online\n"
f"• URL points directly to an image file\n" f"• URL points directly to an image file\n"
f"• URL is publicly accessible" f"• URL is publicly accessible",
) )
await interaction.followup.send(embed=embed, ephemeral=True) await interaction.followup.send(embed=embed, ephemeral=True)
return return
# Step 3: Find player # Step 3: Find player
self.logger.debug("Searching for player", player_name=player_name) self.logger.debug("Searching for player", player_name=player_name)
players = await player_service.get_players_by_name(player_name, get_config().sba_season) players = await player_service.get_players_by_name(
player_name, get_config().sba_season
)
if not players: if not players:
self.logger.warning("Player not found", player_name=player_name) self.logger.warning("Player not found", player_name=player_name)
embed = EmbedTemplate.error( embed = EmbedTemplate.error(
title="Player Not Found", title="Player Not Found",
description=f"❌ No player found matching `{player_name}` in the current season." description=f"❌ No player found matching `{player_name}` in the current season.",
) )
await interaction.followup.send(embed=embed, ephemeral=True) await interaction.followup.send(embed=embed, ephemeral=True)
return return
@ -303,11 +324,13 @@ class ImageCommands(commands.Cog):
if player is None: if player is None:
# Multiple candidates, ask user to be more specific # Multiple candidates, ask user to be more specific
player_list = "\n".join([f"{p.name} ({p.primary_position})" for p in players[:10]]) player_list = "\n".join(
[f"{p.name} ({p.primary_position})" for p in players[:10]]
)
embed = EmbedTemplate.info( embed = EmbedTemplate.info(
title="Multiple Players Found", title="Multiple Players Found",
description=f"🔍 Multiple players match `{player_name}`:\n\n{player_list}\n\n" description=f"🔍 Multiple players match `{player_name}`:\n\n{player_list}\n\n"
f"Please use the exact name from autocomplete." f"Please use the exact name from autocomplete.",
) )
await interaction.followup.send(embed=embed, ephemeral=True) await interaction.followup.send(embed=embed, ephemeral=True)
return return
@ -324,12 +347,12 @@ class ImageCommands(commands.Cog):
"Permission denied", "Permission denied",
user_id=interaction.user.id, user_id=interaction.user.id,
player_id=player.id, player_id=player.id,
error=permission_error error=permission_error,
) )
embed = EmbedTemplate.error( embed = EmbedTemplate.error(
title="Permission Denied", title="Permission Denied",
description=f"{permission_error}\n\n" description=f"{permission_error}\n\n"
f"You can only update images for players on teams you own." f"You can only update images for players on teams you own.",
) )
await interaction.followup.send(embed=embed, ephemeral=True) await interaction.followup.send(embed=embed, ephemeral=True)
return return
@ -339,52 +362,46 @@ class ImageCommands(commands.Cog):
preview_embed = EmbedTemplate.create_base_embed( preview_embed = EmbedTemplate.create_base_embed(
title=f"🖼️ Update {display_name} for {player.name}", title=f"🖼️ Update {display_name} for {player.name}",
description=f"Preview the new {display_name.lower()} below. Click **Confirm Update** to save this change.", description=f"Preview the new {display_name.lower()} below. Click **Confirm Update** to save this change.",
color=EmbedColors.INFO color=EmbedColors.INFO,
) )
# Add current image info # Add current image info
current_image = getattr(player, field_name, None) current_image = getattr(player, field_name, None)
if current_image: if current_image:
preview_embed.add_field( preview_embed.add_field(
name="Current Image", name="Current Image", value="Will be replaced", inline=True
value="Will be replaced",
inline=True
) )
else: else:
preview_embed.add_field( preview_embed.add_field(name="Current Image", value="None set", inline=True)
name="Current Image",
value="None set",
inline=True
)
# Add player info # Add player info
preview_embed.add_field( preview_embed.add_field(
name="Player", name="Player",
value=f"{player.name} ({player.primary_position})", value=f"{player.name} ({player.primary_position})",
inline=True inline=True,
) )
if hasattr(player, 'team') and player.team: if hasattr(player, "team") and player.team:
preview_embed.add_field( preview_embed.add_field(name="Team", value=player.team.abbrev, inline=True)
name="Team",
value=player.team.abbrev,
inline=True
)
# Set the new image as thumbnail for preview # Set the new image as thumbnail for preview
preview_embed.set_thumbnail(url=image_url) preview_embed.set_thumbnail(url=image_url)
preview_embed.set_footer(text="This preview shows how the image will appear. Confirm to save.") preview_embed.set_footer(
text="This preview shows how the image will appear. Confirm to save."
)
# Create confirmation view # Create confirmation view
confirm_view = ImageUpdateConfirmView( confirm_view = ImageUpdateConfirmView(
player=player, player=player,
image_url=image_url, image_url=image_url,
image_type=img_type, image_type=img_type,
user_id=interaction.user.id user_id=interaction.user.id,
) )
await interaction.followup.send(embed=preview_embed, view=confirm_view, ephemeral=True) await interaction.followup.send(
embed=preview_embed, view=confirm_view, ephemeral=True
)
# Wait for confirmation # Wait for confirmation
await confirm_view.wait() await confirm_view.wait()
@ -393,7 +410,7 @@ class ImageCommands(commands.Cog):
self.logger.info("Image update cancelled by user", player_id=player.id) self.logger.info("Image update cancelled by user", player_id=player.id)
cancelled_embed = EmbedTemplate.info( cancelled_embed = EmbedTemplate.info(
title="Update Cancelled", title="Update Cancelled",
description=f"No changes were made to {player.name}'s {display_name.lower()}." description=f"No changes were made to {player.name}'s {display_name.lower()}.",
) )
await interaction.edit_original_response(embed=cancelled_embed, view=None) await interaction.edit_original_response(embed=cancelled_embed, view=None)
return return
@ -403,7 +420,7 @@ class ImageCommands(commands.Cog):
"Updating player image", "Updating player image",
player_id=player.id, player_id=player.id,
field=field_name, field=field_name,
url_length=len(image_url) url_length=len(image_url),
) )
update_data = {field_name: image_url} update_data = {field_name: image_url}
@ -413,7 +430,7 @@ class ImageCommands(commands.Cog):
self.logger.error("Failed to update player", player_id=player.id) self.logger.error("Failed to update player", player_id=player.id)
error_embed = EmbedTemplate.error( error_embed = EmbedTemplate.error(
title="Update Failed", title="Update Failed",
description="❌ An error occurred while updating the player's image. Please try again." description="❌ An error occurred while updating the player's image. Please try again.",
) )
await interaction.edit_original_response(embed=error_embed, view=None) await interaction.edit_original_response(embed=error_embed, view=None)
return return
@ -423,32 +440,24 @@ class ImageCommands(commands.Cog):
"Player image updated successfully", "Player image updated successfully",
player_id=player.id, player_id=player.id,
field=field_name, field=field_name,
user_id=interaction.user.id user_id=interaction.user.id,
) )
success_embed = EmbedTemplate.success( success_embed = EmbedTemplate.success(
title="Image Updated Successfully!", title="Image Updated Successfully!",
description=f"**{display_name}** for **{player.name}** has been updated." description=f"**{display_name}** for **{player.name}** has been updated.",
) )
success_embed.add_field( success_embed.add_field(
name="Player", name="Player",
value=f"{player.name} ({player.primary_position})", value=f"{player.name} ({player.primary_position})",
inline=True inline=True,
) )
if hasattr(player, 'team') and player.team: if hasattr(player, "team") and player.team:
success_embed.add_field( success_embed.add_field(name="Team", value=player.team.abbrev, inline=True)
name="Team",
value=player.team.abbrev,
inline=True
)
success_embed.add_field( success_embed.add_field(name="Image Type", value=display_name, inline=True)
name="Image Type",
value=display_name,
inline=True
)
# Show the new image # Show the new image
success_embed.set_thumbnail(url=image_url) success_embed.set_thumbnail(url=image_url)

View File

@ -3,17 +3,19 @@ Tests for player image management commands.
Covers URL validation, permission checking, and command execution. Covers URL validation, permission checking, and command execution.
""" """
import pytest import pytest
import asyncio import asyncio
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import aiohttp import aiohttp
import discord
from aioresponses import aioresponses from aioresponses import aioresponses
from commands.profile.images import ( from commands.profile.images import (
validate_url_format, validate_url_format,
check_url_accessibility, check_url_accessibility,
can_edit_player_image, can_edit_player_image,
ImageCommands ImageCommands,
) )
from tests.factories import PlayerFactory, TeamFactory from tests.factories import PlayerFactory, TeamFactory
@ -94,7 +96,7 @@ class TestURLAccessibility:
url = "https://example.com/image.jpg" url = "https://example.com/image.jpg"
with aioresponses() as m: with aioresponses() as m:
m.head(url, status=200, headers={'content-type': 'image/jpeg'}) m.head(url, status=200, headers={"content-type": "image/jpeg"})
is_accessible, error = await check_url_accessibility(url) is_accessible, error = await check_url_accessibility(url)
@ -118,7 +120,7 @@ class TestURLAccessibility:
url = "https://example.com/page.html" url = "https://example.com/page.html"
with aioresponses() as m: with aioresponses() as m:
m.head(url, status=200, headers={'content-type': 'text/html'}) m.head(url, status=200, headers={"content-type": "text/html"})
is_accessible, error = await check_url_accessibility(url) is_accessible, error = await check_url_accessibility(url)
@ -157,6 +159,7 @@ class TestPermissionChecking:
async def test_admin_can_edit_any_player(self): async def test_admin_can_edit_any_player(self):
"""Test administrator can edit any player's images.""" """Test administrator can edit any player's images."""
mock_interaction = MagicMock() mock_interaction = MagicMock()
mock_interaction.user = MagicMock(spec=discord.Member)
mock_interaction.user.id = 12345 mock_interaction.user.id = 12345
mock_interaction.user.guild_permissions.administrator = True mock_interaction.user.guild_permissions.administrator = True
@ -186,7 +189,9 @@ class TestPermissionChecking:
mock_logger = MagicMock() mock_logger = MagicMock()
with patch('commands.profile.images.team_service.get_teams_by_owner') as mock_get_teams: with patch(
"commands.profile.images.team_service.get_teams_by_owner"
) as mock_get_teams:
mock_get_teams.return_value = [user_team] mock_get_teams.return_value = [user_team]
has_permission, error = await can_edit_player_image( has_permission, error = await can_edit_player_image(
@ -211,7 +216,9 @@ class TestPermissionChecking:
mock_logger = MagicMock() mock_logger = MagicMock()
with patch('commands.profile.images.team_service.get_teams_by_owner') as mock_get_teams: with patch(
"commands.profile.images.team_service.get_teams_by_owner"
) as mock_get_teams:
mock_get_teams.return_value = [user_team] mock_get_teams.return_value = [user_team]
has_permission, error = await can_edit_player_image( has_permission, error = await can_edit_player_image(
@ -236,7 +243,9 @@ class TestPermissionChecking:
mock_logger = MagicMock() mock_logger = MagicMock()
with patch('commands.profile.images.team_service.get_teams_by_owner') as mock_get_teams: with patch(
"commands.profile.images.team_service.get_teams_by_owner"
) as mock_get_teams:
mock_get_teams.return_value = [user_team] mock_get_teams.return_value = [user_team]
has_permission, error = await can_edit_player_image( has_permission, error = await can_edit_player_image(
@ -258,7 +267,9 @@ class TestPermissionChecking:
mock_logger = MagicMock() mock_logger = MagicMock()
with patch('commands.profile.images.team_service.get_teams_by_owner') as mock_get_teams: with patch(
"commands.profile.images.team_service.get_teams_by_owner"
) as mock_get_teams:
mock_get_teams.return_value = [] mock_get_teams.return_value = []
has_permission, error = await can_edit_player_image( has_permission, error = await can_edit_player_image(
@ -299,7 +310,7 @@ class TestImageCommandsIntegration:
async def test_set_image_command_structure(self, commands_cog): async def test_set_image_command_structure(self, commands_cog):
"""Test that set_image command is properly configured.""" """Test that set_image command is properly configured."""
assert hasattr(commands_cog, 'set_image') assert hasattr(commands_cog, "set_image")
assert commands_cog.set_image.name == "set-image" assert commands_cog.set_image.name == "set-image"
async def test_fancy_card_updates_vanity_card_field(self, commands_cog): async def test_fancy_card_updates_vanity_card_field(self, commands_cog):

View File

@ -7,6 +7,7 @@ servers and user types:
- @league_only: Only available in the league server - @league_only: Only available in the league server
- @requires_team: User must have a team (works with global commands) - @requires_team: User must have a team (works with global commands)
""" """
import logging import logging
from functools import wraps from functools import wraps
from typing import Optional, Callable from typing import Optional, Callable
@ -21,6 +22,7 @@ logger = logging.getLogger(__name__)
class PermissionError(Exception): class PermissionError(Exception):
"""Raised when a user doesn't have permission to use a command.""" """Raised when a user doesn't have permission to use a command."""
pass pass
@ -54,17 +56,16 @@ async def get_user_team(user_id: int) -> Optional[dict]:
# This call is automatically cached by TeamService # This call is automatically cached by TeamService
config = get_config() config = get_config()
team = await team_service.get_team_by_owner( team = await team_service.get_team_by_owner(
owner_id=user_id, owner_id=user_id, season=config.sba_season
season=config.sba_season
) )
if team: if team:
logger.debug(f"User {user_id} has team: {team.lname}") logger.debug(f"User {user_id} has team: {team.lname}")
return { return {
'id': team.id, "id": team.id,
'name': team.lname, "name": team.lname,
'abbrev': team.abbrev, "abbrev": team.abbrev,
'season': team.season "season": team.season,
} }
logger.debug(f"User {user_id} does not have a team") logger.debug(f"User {user_id} does not have a team")
@ -77,6 +78,18 @@ def is_league_server(guild_id: int) -> bool:
return guild_id == config.guild_id return guild_id == config.guild_id
def is_admin(interaction: discord.Interaction) -> bool:
"""Check if the interaction user is a server administrator.
Includes an isinstance guard for discord.Member so this is safe
to call from DM contexts (where guild_permissions is unavailable).
"""
return (
isinstance(interaction.user, discord.Member)
and interaction.user.guild_permissions.administrator
)
def league_only(): def league_only():
""" """
Decorator to restrict a command to the league server only. Decorator to restrict a command to the league server only.
@ -87,14 +100,14 @@ def league_only():
async def team_command(self, interaction: discord.Interaction): async def team_command(self, interaction: discord.Interaction):
# Only executes in league server # Only executes in league server
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
async def wrapper(self, interaction: discord.Interaction, *args, **kwargs): async def wrapper(self, interaction: discord.Interaction, *args, **kwargs):
# Check if in a guild # Check if in a guild
if not interaction.guild: if not interaction.guild:
await interaction.response.send_message( await interaction.response.send_message(
"❌ This command can only be used in a server.", "❌ This command can only be used in a server.", ephemeral=True
ephemeral=True
) )
return return
@ -102,13 +115,14 @@ def league_only():
if not is_league_server(interaction.guild.id): if not is_league_server(interaction.guild.id):
await interaction.response.send_message( await interaction.response.send_message(
"❌ This command is only available in the SBa league server.", "❌ This command is only available in the SBa league server.",
ephemeral=True ephemeral=True,
) )
return return
return await func(self, interaction, *args, **kwargs) return await func(self, interaction, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator
@ -123,6 +137,7 @@ def requires_team():
async def mymoves_command(self, interaction: discord.Interaction): async def mymoves_command(self, interaction: discord.Interaction):
# Only executes if user has a team # Only executes if user has a team
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
async def wrapper(self, interaction: discord.Interaction, *args, **kwargs): async def wrapper(self, interaction: discord.Interaction, *args, **kwargs):
@ -133,29 +148,33 @@ def requires_team():
if team is None: if team is None:
await interaction.response.send_message( await interaction.response.send_message(
"❌ This command requires you to have a team in the SBa league. Contact an admin if you believe this is an error.", "❌ This command requires you to have a team in the SBa league. Contact an admin if you believe this is an error.",
ephemeral=True ephemeral=True,
) )
return return
# Store team info in interaction for command to use # Store team info in interaction for command to use
# This allows commands to access the team without another lookup # This allows commands to access the team without another lookup
interaction.extras['user_team'] = team interaction.extras["user_team"] = team
return await func(self, interaction, *args, **kwargs) return await func(self, interaction, *args, **kwargs)
except Exception as e: except Exception as e:
# Log the error for debugging # Log the error for debugging
logger.error(f"Error checking team ownership for user {interaction.user.id}: {e}", exc_info=True) logger.error(
f"Error checking team ownership for user {interaction.user.id}: {e}",
exc_info=True,
)
# Provide helpful error message to user # Provide helpful error message to user
await interaction.response.send_message( await interaction.response.send_message(
"❌ Unable to verify team ownership due to a temporary error. Please try again in a moment. " "❌ Unable to verify team ownership due to a temporary error. Please try again in a moment. "
"If this persists, contact an admin.", "If this persists, contact an admin.",
ephemeral=True ephemeral=True,
) )
return return
return wrapper return wrapper
return decorator return decorator
@ -170,12 +189,14 @@ def global_command():
async def roll_command(self, interaction: discord.Interaction): async def roll_command(self, interaction: discord.Interaction):
# Available in all servers # Available in all servers
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
async def wrapper(self, interaction: discord.Interaction, *args, **kwargs): async def wrapper(self, interaction: discord.Interaction, *args, **kwargs):
return await func(self, interaction, *args, **kwargs) return await func(self, interaction, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator
@ -190,35 +211,35 @@ def admin_only():
async def sync_command(self, interaction: discord.Interaction): async def sync_command(self, interaction: discord.Interaction):
# Only executes for admins # Only executes for admins
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
async def wrapper(self, interaction: discord.Interaction, *args, **kwargs): async def wrapper(self, interaction: discord.Interaction, *args, **kwargs):
# Check if user is guild admin # Check if user is guild admin
if not interaction.guild: if not interaction.guild:
await interaction.response.send_message( await interaction.response.send_message(
"❌ This command can only be used in a server.", "❌ This command can only be used in a server.", ephemeral=True
ephemeral=True
) )
return return
# Check if user has admin permissions # Check if user has admin permissions
if not isinstance(interaction.user, discord.Member): if not isinstance(interaction.user, discord.Member):
await interaction.response.send_message( await interaction.response.send_message(
"❌ Unable to verify permissions.", "❌ Unable to verify permissions.", ephemeral=True
ephemeral=True
) )
return return
if not interaction.user.guild_permissions.administrator: if not interaction.user.guild_permissions.administrator:
await interaction.response.send_message( await interaction.response.send_message(
"❌ This command requires administrator permissions.", "❌ This command requires administrator permissions.",
ephemeral=True ephemeral=True,
) )
return return
return await func(self, interaction, *args, **kwargs) return await func(self, interaction, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator
@ -241,6 +262,7 @@ def league_admin_only():
async def admin_sync_prefix(self, ctx: commands.Context): async def admin_sync_prefix(self, ctx: commands.Context):
# Only league server admins can use this # Only league server admins can use this
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
async def wrapper(self, ctx_or_interaction, *args, **kwargs): async def wrapper(self, ctx_or_interaction, *args, **kwargs):
@ -254,6 +276,7 @@ def league_admin_only():
async def send_error(msg: str): async def send_error(msg: str):
await ctx.send(msg) await ctx.send(msg)
else: else:
interaction = ctx_or_interaction interaction = ctx_or_interaction
guild = interaction.guild guild = interaction.guild
@ -269,7 +292,9 @@ def league_admin_only():
# Check if league server # Check if league server
if not is_league_server(guild.id): if not is_league_server(guild.id):
await send_error("❌ This command is only available in the SBa league server.") await send_error(
"❌ This command is only available in the SBa league server."
)
return return
# Check admin permissions # Check admin permissions
@ -284,4 +309,5 @@ def league_admin_only():
return await func(self, ctx_or_interaction, *args, **kwargs) return await func(self, ctx_or_interaction, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator