diff --git a/.gitea/workflows/docker-build.yml b/.gitea/workflows/docker-build.yml index 98c497a..0aa562a 100644 --- a/.gitea/workflows/docker-build.yml +++ b/.gitea/workflows/docker-build.yml @@ -1,9 +1,9 @@ # Gitea Actions: Docker Build, Push, and Notify # # CI/CD pipeline for Major Domo Discord Bot: -# - Builds Docker images on every push/PR +# - Builds Docker images on merge to main/next-release # - Auto-generates CalVer version (YYYY.MM.BUILD) on main branch merges -# - Supports multi-channel releases: stable (main), rc (next-release), dev (PRs) +# - Supports multi-channel releases: stable (main), rc (next-release) # - Pushes to Docker Hub and creates git tag on main # - Sends Discord notifications on success/failure @@ -14,9 +14,6 @@ on: branches: - main - next-release - pull_request: - branches: - - main jobs: build: diff --git a/.gitignore b/.gitignore index 9500d65..28d68a6 100644 --- a/.gitignore +++ b/.gitignore @@ -218,5 +218,6 @@ __marimo__/ # Project-specific data/ +storage/ production_logs/ *.json diff --git a/bot.py b/bot.py index 3500209..735e875 100644 --- a/bot.py +++ b/bot.py @@ -42,7 +42,9 @@ def setup_logging(): # JSON file handler - structured logging for monitoring and analysis json_handler = RotatingFileHandler( - "logs/discord_bot_v2.json", maxBytes=5 * 1024 * 1024, backupCount=5 # 5MB + "logs/discord_bot_v2.json", + maxBytes=5 * 1024 * 1024, + backupCount=5, # 5MB ) json_handler.setFormatter(JSONFormatter()) logger.addHandler(json_handler) @@ -120,28 +122,11 @@ class SBABot(commands.Bot): self.maintenance_mode: bool = False self.logger = logging.getLogger("discord_bot_v2") - self.maintenance_mode: bool = False async def setup_hook(self): """Called when the bot is starting up.""" self.logger.info("Setting up bot...") - @self.tree.interaction_check - async def maintenance_check(interaction: discord.Interaction) -> bool: - """Block non-admin users when maintenance mode is enabled.""" - if not self.maintenance_mode: - return True - if ( - isinstance(interaction.user, discord.Member) - and interaction.user.guild_permissions.administrator - ): - return True - await interaction.response.send_message( - "🔧 The bot is currently in maintenance mode. Please try again later.", - ephemeral=True, - ) - return False - # Load command packages await self._load_command_packages() @@ -443,7 +428,9 @@ async def health_command(interaction: discord.Interaction): embed.add_field(name="Bot Status", value="✅ Online", inline=True) embed.add_field(name="API Status", value=api_status, inline=True) embed.add_field(name="Guilds", value=str(guild_count), inline=True) - embed.add_field(name="Latency", value=f"{bot.latency*1000:.1f}ms", inline=True) + embed.add_field( + name="Latency", value=f"{bot.latency * 1000:.1f}ms", inline=True + ) if bot.user: embed.set_footer( diff --git a/commands/admin/management.py b/commands/admin/management.py index b738e2c..2c5d566 100644 --- a/commands/admin/management.py +++ b/commands/admin/management.py @@ -568,14 +568,9 @@ class AdminCommands(commands.Cog): return try: - # Clear all messages from the channel - deleted_count = 0 - async for message in live_scores_channel.history(limit=100): - try: - await message.delete() - deleted_count += 1 - except discord.NotFound: - pass # Message already deleted + # Clear all messages from the channel using bulk delete + deleted_messages = await live_scores_channel.purge(limit=100) + deleted_count = len(deleted_messages) self.logger.info(f"Cleared {deleted_count} messages from #live-sba-scores") diff --git a/commands/gameplay/scorebug.py b/commands/gameplay/scorebug.py index dee4780..114e48f 100644 --- a/commands/gameplay/scorebug.py +++ b/commands/gameplay/scorebug.py @@ -4,6 +4,7 @@ Scorebug Commands Implements commands for publishing and displaying live game scorebugs from Google Sheets scorecards. """ +import asyncio import discord from discord.ext import commands from discord import app_commands @@ -73,12 +74,18 @@ class ScorebugCommands(commands.Cog): return # Get team data for display - away_team = None - home_team = None - if scorebug_data.away_team_id: - away_team = await team_service.get_team(scorebug_data.away_team_id) - if scorebug_data.home_team_id: - home_team = await team_service.get_team(scorebug_data.home_team_id) + away_team, home_team = await asyncio.gather( + ( + team_service.get_team(scorebug_data.away_team_id) + if scorebug_data.away_team_id + else asyncio.sleep(0) + ), + ( + team_service.get_team(scorebug_data.home_team_id) + if scorebug_data.home_team_id + else asyncio.sleep(0) + ), + ) # Format scorecard link away_abbrev = away_team.abbrev if away_team else "AWAY" @@ -86,7 +93,7 @@ class ScorebugCommands(commands.Cog): scorecard_link = f"[{away_abbrev} @ {home_abbrev}]({url})" # Store the scorecard in the tracker - self.scorecard_tracker.publish_scorecard( + await self.scorecard_tracker.publish_scorecard( text_channel_id=interaction.channel_id, # type: ignore sheet_url=url, publisher_id=interaction.user.id, @@ -157,7 +164,7 @@ class ScorebugCommands(commands.Cog): await interaction.response.defer(ephemeral=True) # Check if a scorecard is published in this channel - sheet_url = self.scorecard_tracker.get_scorecard(interaction.channel_id) # type: ignore + sheet_url = await self.scorecard_tracker.get_scorecard(interaction.channel_id) # type: ignore if not sheet_url: embed = EmbedTemplate.error( @@ -179,12 +186,18 @@ class ScorebugCommands(commands.Cog): ) # Get team data - away_team = None - home_team = None - if scorebug_data.away_team_id: - away_team = await team_service.get_team(scorebug_data.away_team_id) - if scorebug_data.home_team_id: - home_team = await team_service.get_team(scorebug_data.home_team_id) + away_team, home_team = await asyncio.gather( + ( + team_service.get_team(scorebug_data.away_team_id) + if scorebug_data.away_team_id + else asyncio.sleep(0) + ), + ( + team_service.get_team(scorebug_data.home_team_id) + if scorebug_data.home_team_id + else asyncio.sleep(0) + ), + ) # Create scorebug embed using shared utility embed = create_scorebug_embed( @@ -194,7 +207,7 @@ class ScorebugCommands(commands.Cog): await interaction.edit_original_response(content=None, embed=embed) # Update timestamp in tracker - self.scorecard_tracker.update_timestamp(interaction.channel_id) # type: ignore + await self.scorecard_tracker.update_timestamp(interaction.channel_id) # type: ignore except SheetsException as e: embed = EmbedTemplate.error( diff --git a/commands/gameplay/scorecard_tracker.py b/commands/gameplay/scorecard_tracker.py index 8b2a674..b5fd6db 100644 --- a/commands/gameplay/scorecard_tracker.py +++ b/commands/gameplay/scorecard_tracker.py @@ -24,7 +24,7 @@ class ScorecardTracker: - Timestamp tracking for monitoring """ - def __init__(self, data_file: str = "data/scorecards.json"): + def __init__(self, data_file: str = "storage/scorecards.json"): """ Initialize the scorecard tracker. diff --git a/commands/injuries/management.py b/commands/injuries/management.py index d43802b..decf251 100644 --- a/commands/injuries/management.py +++ b/commands/injuries/management.py @@ -11,6 +11,7 @@ The injury rating format (#p##) encodes both games played and rating: - Remaining: Injury rating (p70, p65, p60, p50, p40, p30, p20) """ +import asyncio import math import random import discord @@ -114,16 +115,14 @@ class InjuryGroup(app_commands.Group): """Roll for injury using 3d6 dice and injury tables.""" await interaction.response.defer() - # Get current season - current = await league_service.get_current_state() + # Get current season and search for player in parallel + current, players = await asyncio.gather( + league_service.get_current_state(), + player_service.search_players(player_name, limit=10), + ) if not current: raise BotException("Failed to get current season information") - # Search for player using the search endpoint (more reliable than name param) - players = await player_service.search_players( - player_name, limit=10, season=current.season - ) - if not players: embed = EmbedTemplate.error( title="Player Not Found", @@ -530,16 +529,14 @@ class InjuryGroup(app_commands.Group): await interaction.followup.send(embed=embed, ephemeral=True) return - # Get current season - current = await league_service.get_current_state() + # Get current season and search for player in parallel + current, players = await asyncio.gather( + league_service.get_current_state(), + player_service.search_players(player_name, limit=10), + ) if not current: raise BotException("Failed to get current season information") - # Search for player using the search endpoint (more reliable than name param) - players = await player_service.search_players( - player_name, limit=10, season=current.season - ) - if not players: embed = EmbedTemplate.error( title="Player Not Found", @@ -717,16 +714,14 @@ class InjuryGroup(app_commands.Group): await interaction.response.defer() - # Get current season - current = await league_service.get_current_state() + # Get current season and search for player in parallel + current, players = await asyncio.gather( + league_service.get_current_state(), + player_service.search_players(player_name, limit=10), + ) if not current: raise BotException("Failed to get current season information") - # Search for player using the search endpoint (more reliable than name param) - players = await player_service.search_players( - player_name, limit=10, season=current.season - ) - if not players: embed = EmbedTemplate.error( title="Player Not Found", diff --git a/commands/league/schedule.py b/commands/league/schedule.py index 26dd224..7644d84 100644 --- a/commands/league/schedule.py +++ b/commands/league/schedule.py @@ -3,6 +3,7 @@ League Schedule Commands Implements slash commands for displaying game schedules and results. """ + from typing import Optional import asyncio @@ -19,19 +20,16 @@ from views.embeds import EmbedColors, EmbedTemplate class ScheduleCommands(commands.Cog): """League schedule command handlers.""" - + def __init__(self, bot: commands.Bot): self.bot = bot - self.logger = get_contextual_logger(f'{__name__}.ScheduleCommands') - - @discord.app_commands.command( - name="schedule", - description="Display game schedule" - ) + self.logger = get_contextual_logger(f"{__name__}.ScheduleCommands") + + @discord.app_commands.command(name="schedule", description="Display game schedule") @discord.app_commands.describe( season="Season to show schedule for (defaults to current season)", week="Week number to show (optional)", - team="Team abbreviation to filter by (optional)" + team="Team abbreviation to filter by (optional)", ) @requires_team() @logged_command("/schedule") @@ -40,13 +38,13 @@ class ScheduleCommands(commands.Cog): interaction: discord.Interaction, season: Optional[int] = None, week: Optional[int] = None, - team: Optional[str] = None + team: Optional[str] = None, ): """Display game schedule for a week or team.""" await interaction.response.defer() - + search_season = season or get_config().sba_season - + if team: # Show team schedule await self._show_team_schedule(interaction, search_season, team, week) @@ -56,7 +54,7 @@ class ScheduleCommands(commands.Cog): else: # Show recent/upcoming games await self._show_current_schedule(interaction, search_season) - + # @discord.app_commands.command( # name="results", # description="Display recent game results" @@ -74,282 +72,316 @@ class ScheduleCommands(commands.Cog): # ): # """Display recent game results.""" # await interaction.response.defer() - + # search_season = season or get_config().sba_season - + # if week: # # Show specific week results # games = await schedule_service.get_week_schedule(search_season, week) # completed_games = [game for game in games if game.is_completed] - + # if not completed_games: # await interaction.followup.send( # f"❌ No completed games found for season {search_season}, week {week}.", # ephemeral=True # ) # return - + # embed = await self._create_week_results_embed(completed_games, search_season, week) # await interaction.followup.send(embed=embed) # else: # # Show recent results # recent_games = await schedule_service.get_recent_games(search_season) - + # if not recent_games: # await interaction.followup.send( # f"❌ No recent games found for season {search_season}.", # ephemeral=True # ) # return - + # embed = await self._create_recent_results_embed(recent_games, search_season) # await interaction.followup.send(embed=embed) - - async def _show_week_schedule(self, interaction: discord.Interaction, season: int, week: int): + + async def _show_week_schedule( + self, interaction: discord.Interaction, season: int, week: int + ): """Show schedule for a specific week.""" self.logger.debug("Fetching week schedule", season=season, week=week) - + games = await schedule_service.get_week_schedule(season, week) - + if not games: await interaction.followup.send( - f"❌ No games found for season {season}, week {week}.", - ephemeral=True + f"❌ No games found for season {season}, week {week}.", ephemeral=True ) return - + embed = await self._create_week_schedule_embed(games, season, week) await interaction.followup.send(embed=embed) - - async def _show_team_schedule(self, interaction: discord.Interaction, season: int, team: str, week: Optional[int]): + + async def _show_team_schedule( + self, + interaction: discord.Interaction, + season: int, + team: str, + week: Optional[int], + ): """Show schedule for a specific team.""" self.logger.debug("Fetching team schedule", season=season, team=team, week=week) - + if week: # Show team games for specific week week_games = await schedule_service.get_week_schedule(season, week) team_games = [ - game for game in week_games - if game.away_team.abbrev.upper() == team.upper() or game.home_team.abbrev.upper() == team.upper() + game + for game in week_games + if game.away_team.abbrev.upper() == team.upper() + or game.home_team.abbrev.upper() == team.upper() ] else: # Show team's recent/upcoming games (limited weeks) team_games = await schedule_service.get_team_schedule(season, team, weeks=4) - + if not team_games: week_text = f" for week {week}" if week else "" await interaction.followup.send( f"❌ No games found for team '{team}'{week_text} in season {season}.", - ephemeral=True + ephemeral=True, ) return - + embed = await self._create_team_schedule_embed(team_games, season, team, week) await interaction.followup.send(embed=embed) - - async def _show_current_schedule(self, interaction: discord.Interaction, season: int): + + async def _show_current_schedule( + self, interaction: discord.Interaction, season: int + ): """Show current schedule overview with recent and upcoming games.""" self.logger.debug("Fetching current schedule overview", season=season) - + # Get both recent and upcoming games recent_games, upcoming_games = await asyncio.gather( schedule_service.get_recent_games(season, weeks_back=1), - schedule_service.get_upcoming_games(season, weeks_ahead=1) + schedule_service.get_upcoming_games(season), ) - + if not recent_games and not upcoming_games: await interaction.followup.send( f"❌ No recent or upcoming games found for season {season}.", - ephemeral=True + ephemeral=True, ) return - - embed = await self._create_current_schedule_embed(recent_games, upcoming_games, season) + + embed = await self._create_current_schedule_embed( + recent_games, upcoming_games, season + ) await interaction.followup.send(embed=embed) - - async def _create_week_schedule_embed(self, games, season: int, week: int) -> discord.Embed: + + async def _create_week_schedule_embed( + self, games, season: int, week: int + ) -> discord.Embed: """Create an embed for a week's schedule.""" embed = EmbedTemplate.create_base_embed( title=f"📅 Week {week} Schedule - Season {season}", - color=EmbedColors.PRIMARY + color=EmbedColors.PRIMARY, ) - + # Group games by series series_games = schedule_service.group_games_by_series(games) - + schedule_lines = [] for (team1, team2), series in series_games.items(): series_summary = await self._format_series_summary(series) schedule_lines.append(f"**{team1} vs {team2}**\n{series_summary}") - + if schedule_lines: embed.add_field( - name="Games", - value="\n\n".join(schedule_lines), - inline=False + name="Games", value="\n\n".join(schedule_lines), inline=False ) - + # Add week summary completed = len([g for g in games if g.is_completed]) total = len(games) embed.add_field( name="Week Progress", value=f"{completed}/{total} games completed", - inline=True + inline=True, ) - + embed.set_footer(text=f"Season {season} • Week {week}") return embed - - async def _create_team_schedule_embed(self, games, season: int, team: str, week: Optional[int]) -> discord.Embed: + + async def _create_team_schedule_embed( + self, games, season: int, team: str, week: Optional[int] + ) -> discord.Embed: """Create an embed for a team's schedule.""" week_text = f" - Week {week}" if week else "" embed = EmbedTemplate.create_base_embed( title=f"📅 {team.upper()} Schedule{week_text} - Season {season}", - color=EmbedColors.PRIMARY + color=EmbedColors.PRIMARY, ) - + # Separate completed and upcoming games completed_games = [g for g in games if g.is_completed] upcoming_games = [g for g in games if not g.is_completed] - + if completed_games: recent_lines = [] for game in completed_games[-5:]: # Last 5 games - result = "W" if game.winner and game.winner.abbrev.upper() == team.upper() else "L" + result = ( + "W" + if game.winner and game.winner.abbrev.upper() == team.upper() + else "L" + ) if game.home_team.abbrev.upper() == team.upper(): # Team was home - recent_lines.append(f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})") + recent_lines.append( + f"Week {game.week}: {result} vs {game.away_team.abbrev} ({game.score_display})" + ) else: - # Team was away - recent_lines.append(f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})") - + # Team was away + recent_lines.append( + f"Week {game.week}: {result} @ {game.home_team.abbrev} ({game.score_display})" + ) + embed.add_field( name="Recent Results", value="\n".join(recent_lines) if recent_lines else "No recent games", - inline=False + inline=False, ) - + if upcoming_games: upcoming_lines = [] for game in upcoming_games[:5]: # Next 5 games if game.home_team.abbrev.upper() == team.upper(): # Team is home - upcoming_lines.append(f"Week {game.week}: vs {game.away_team.abbrev}") + upcoming_lines.append( + f"Week {game.week}: vs {game.away_team.abbrev}" + ) else: # Team is away - upcoming_lines.append(f"Week {game.week}: @ {game.home_team.abbrev}") - + upcoming_lines.append( + f"Week {game.week}: @ {game.home_team.abbrev}" + ) + embed.add_field( name="Upcoming Games", - value="\n".join(upcoming_lines) if upcoming_lines else "No upcoming games", - inline=False + value=( + "\n".join(upcoming_lines) if upcoming_lines else "No upcoming games" + ), + inline=False, ) - + embed.set_footer(text=f"Season {season} • {team.upper()}") return embed - - async def _create_week_results_embed(self, games, season: int, week: int) -> discord.Embed: + + async def _create_week_results_embed( + self, games, season: int, week: int + ) -> discord.Embed: """Create an embed for week results.""" embed = EmbedTemplate.create_base_embed( - title=f"🏆 Week {week} Results - Season {season}", - color=EmbedColors.SUCCESS + title=f"🏆 Week {week} Results - Season {season}", color=EmbedColors.SUCCESS ) - + # Group by series and show results series_games = schedule_service.group_games_by_series(games) - + results_lines = [] for (team1, team2), series in series_games.items(): # Count wins for each team - team1_wins = len([g for g in series if g.winner and g.winner.abbrev == team1]) - team2_wins = len([g for g in series if g.winner and g.winner.abbrev == team2]) - + team1_wins = len( + [g for g in series if g.winner and g.winner.abbrev == team1] + ) + team2_wins = len( + [g for g in series if g.winner and g.winner.abbrev == team2] + ) + # Series result series_result = f"**{team1} {team1_wins}-{team2_wins} {team2}**" - + # Individual games game_details = [] for game in series: if game.series_game_display: - game_details.append(f"{game.series_game_display}: {game.matchup_display}") - + game_details.append( + f"{game.series_game_display}: {game.matchup_display}" + ) + results_lines.append(f"{series_result}\n" + "\n".join(game_details)) - + if results_lines: embed.add_field( - name="Series Results", - value="\n\n".join(results_lines), - inline=False + name="Series Results", value="\n\n".join(results_lines), inline=False ) - - embed.set_footer(text=f"Season {season} • Week {week} • {len(games)} games completed") + + embed.set_footer( + text=f"Season {season} • Week {week} • {len(games)} games completed" + ) return embed - + async def _create_recent_results_embed(self, games, season: int) -> discord.Embed: """Create an embed for recent results.""" embed = EmbedTemplate.create_base_embed( - title=f"🏆 Recent Results - Season {season}", - color=EmbedColors.SUCCESS + title=f"🏆 Recent Results - Season {season}", color=EmbedColors.SUCCESS ) - + # Show most recent games recent_lines = [] for game in games[:10]: # Show last 10 games recent_lines.append(f"Week {game.week}: {game.matchup_display}") - + if recent_lines: embed.add_field( - name="Latest Games", - value="\n".join(recent_lines), - inline=False + name="Latest Games", value="\n".join(recent_lines), inline=False ) - + embed.set_footer(text=f"Season {season} • Last {len(games)} completed games") return embed - - async def _create_current_schedule_embed(self, recent_games, upcoming_games, season: int) -> discord.Embed: + + async def _create_current_schedule_embed( + self, recent_games, upcoming_games, season: int + ) -> discord.Embed: """Create an embed for current schedule overview.""" embed = EmbedTemplate.create_base_embed( - title=f"📅 Current Schedule - Season {season}", - color=EmbedColors.INFO + title=f"📅 Current Schedule - Season {season}", color=EmbedColors.INFO ) - + if recent_games: recent_lines = [] for game in recent_games[:5]: recent_lines.append(f"Week {game.week}: {game.matchup_display}") - + embed.add_field( - name="Recent Results", - value="\n".join(recent_lines), - inline=False + name="Recent Results", value="\n".join(recent_lines), inline=False ) - + if upcoming_games: upcoming_lines = [] for game in upcoming_games[:5]: upcoming_lines.append(f"Week {game.week}: {game.matchup_display}") - + embed.add_field( - name="Upcoming Games", - value="\n".join(upcoming_lines), - inline=False + name="Upcoming Games", value="\n".join(upcoming_lines), inline=False ) - + embed.set_footer(text=f"Season {season}") return embed - + async def _format_series_summary(self, series) -> str: """Format a series summary.""" lines = [] for game in series: - game_display = f"{game.series_game_display}: {game.matchup_display}" if game.series_game_display else game.matchup_display + game_display = ( + f"{game.series_game_display}: {game.matchup_display}" + if game.series_game_display + else game.matchup_display + ) lines.append(game_display) - + return "\n".join(lines) if lines else "No games" async def setup(bot: commands.Bot): """Load the schedule commands cog.""" - await bot.add_cog(ScheduleCommands(bot)) \ No newline at end of file + await bot.add_cog(ScheduleCommands(bot)) diff --git a/commands/league/submit_scorecard.py b/commands/league/submit_scorecard.py index caa2985..d416a8b 100644 --- a/commands/league/submit_scorecard.py +++ b/commands/league/submit_scorecard.py @@ -5,6 +5,7 @@ Implements the /submit-scorecard command for submitting Google Sheets scorecards with play-by-play data, pitching decisions, and game results. """ +import asyncio from typing import Optional, List import discord @@ -107,11 +108,13 @@ class SubmitScorecardCommands(commands.Cog): content="🔍 Looking up teams and managers..." ) - away_team = await team_service.get_team_by_abbrev( - setup_data["away_team_abbrev"], current.season - ) - home_team = await team_service.get_team_by_abbrev( - setup_data["home_team_abbrev"], current.season + away_team, home_team = await asyncio.gather( + team_service.get_team_by_abbrev( + setup_data["away_team_abbrev"], current.season + ), + team_service.get_team_by_abbrev( + setup_data["home_team_abbrev"], current.season + ), ) if not away_team or not home_team: @@ -235,9 +238,13 @@ class SubmitScorecardCommands(commands.Cog): decision["game_num"] = setup_data["game_num"] # Validate WP and LP exist and fetch Player objects - wp, lp, sv, holders, _blown_saves = ( - await decision_service.find_winning_losing_pitchers(decisions_data) - ) + ( + wp, + lp, + sv, + holders, + _blown_saves, + ) = await decision_service.find_winning_losing_pitchers(decisions_data) if wp is None or lp is None: await interaction.edit_original_response( diff --git a/commands/soak/tracker.py b/commands/soak/tracker.py index f084a6a..4708a5e 100644 --- a/commands/soak/tracker.py +++ b/commands/soak/tracker.py @@ -3,13 +3,14 @@ Soak Tracker Provides persistent tracking of "soak" mentions using JSON file storage. """ + import json import logging from datetime import datetime, timedelta, UTC from pathlib import Path from typing import Dict, List, Optional, Any -logger = logging.getLogger(f'{__name__}.SoakTracker') +logger = logging.getLogger(f"{__name__}.SoakTracker") class SoakTracker: @@ -22,7 +23,7 @@ class SoakTracker: - Time-based calculations for disappointment tiers """ - def __init__(self, data_file: str = "data/soak_data.json"): + def __init__(self, data_file: str = "storage/soak_data.json"): """ Initialize the soak tracker. @@ -38,28 +39,22 @@ class SoakTracker: """Load soak data from JSON file.""" try: if self.data_file.exists(): - with open(self.data_file, 'r') as f: + with open(self.data_file, "r") as f: self._data = json.load(f) - logger.debug(f"Loaded soak data: {self._data.get('total_count', 0)} total soaks") + logger.debug( + f"Loaded soak data: {self._data.get('total_count', 0)} total soaks" + ) else: - self._data = { - "last_soak": None, - "total_count": 0, - "history": [] - } + self._data = {"last_soak": None, "total_count": 0, "history": []} logger.info("No existing soak data found, starting fresh") except Exception as e: logger.error(f"Failed to load soak data: {e}") - self._data = { - "last_soak": None, - "total_count": 0, - "history": [] - } + self._data = {"last_soak": None, "total_count": 0, "history": []} def save_data(self) -> None: """Save soak data to JSON file.""" try: - with open(self.data_file, 'w') as f: + with open(self.data_file, "w") as f: json.dump(self._data, f, indent=2, default=str) logger.debug("Soak data saved successfully") except Exception as e: @@ -71,7 +66,7 @@ class SoakTracker: username: str, display_name: str, channel_id: int, - message_id: int + message_id: int, ) -> None: """ Record a new soak mention. @@ -89,7 +84,7 @@ class SoakTracker: "username": username, "display_name": display_name, "channel_id": str(channel_id), - "message_id": str(message_id) + "message_id": str(message_id), } # Update last_soak @@ -110,7 +105,9 @@ class SoakTracker: self.save_data() - logger.info(f"Recorded soak by {username} (ID: {user_id}) in channel {channel_id}") + logger.info( + f"Recorded soak by {username} (ID: {user_id}) in channel {channel_id}" + ) def get_last_soak(self) -> Optional[Dict[str, Any]]: """ @@ -135,10 +132,12 @@ class SoakTracker: try: # Parse ISO format timestamp last_timestamp_str = last_soak["timestamp"] - if last_timestamp_str.endswith('Z'): - last_timestamp_str = last_timestamp_str[:-1] + '+00:00' + if last_timestamp_str.endswith("Z"): + last_timestamp_str = last_timestamp_str[:-1] + "+00:00" - last_timestamp = datetime.fromisoformat(last_timestamp_str.replace('Z', '+00:00')) + last_timestamp = datetime.fromisoformat( + last_timestamp_str.replace("Z", "+00:00") + ) # Ensure both times are timezone-aware if last_timestamp.tzinfo is None: diff --git a/commands/transactions/management.py b/commands/transactions/management.py index b740b03..45d7e33 100644 --- a/commands/transactions/management.py +++ b/commands/transactions/management.py @@ -3,6 +3,7 @@ Transaction Management Commands Core transaction commands for roster management and transaction tracking. """ + from typing import Optional import asyncio @@ -21,6 +22,7 @@ from views.base import PaginationView from services.transaction_service import transaction_service from services.roster_service import roster_service from services.team_service import team_service + # No longer need TransactionStatus enum @@ -34,25 +36,28 @@ class TransactionPaginationView(PaginationView): all_transactions: list, user_id: int, timeout: float = 300.0, - show_page_numbers: bool = True + show_page_numbers: bool = True, ): super().__init__( pages=pages, user_id=user_id, timeout=timeout, - show_page_numbers=show_page_numbers + show_page_numbers=show_page_numbers, ) self.all_transactions = all_transactions - @discord.ui.button(label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1) - async def show_move_ids(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button( + label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1 + ) + async def show_move_ids( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Show all move IDs in an ephemeral message.""" self.increment_interaction_count() if not self.all_transactions: await interaction.response.send_message( - "No transactions to show.", - ephemeral=True + "No transactions to show.", ephemeral=True ) return @@ -85,8 +90,7 @@ class TransactionPaginationView(PaginationView): # Send the messages if not messages: await interaction.response.send_message( - "No transactions to display.", - ephemeral=True + "No transactions to display.", ephemeral=True ) return @@ -101,14 +105,13 @@ class TransactionPaginationView(PaginationView): class TransactionCommands(commands.Cog): """Transaction command handlers for roster management.""" - + def __init__(self, bot: commands.Bot): self.bot = bot - self.logger = get_contextual_logger(f'{__name__}.TransactionCommands') - + self.logger = get_contextual_logger(f"{__name__}.TransactionCommands") + @app_commands.command( - name="mymoves", - description="View your pending and scheduled transactions" + name="mymoves", description="View your pending and scheduled transactions" ) @app_commands.describe( show_cancelled="Include cancelled transactions in the display (default: False)" @@ -116,39 +119,45 @@ class TransactionCommands(commands.Cog): @requires_team() @logged_command("/mymoves") async def my_moves( - self, - interaction: discord.Interaction, - show_cancelled: bool = False + self, interaction: discord.Interaction, show_cancelled: bool = False ): """Display user's transaction status and history.""" await interaction.response.defer() - + # Get user's team - team = await get_user_major_league_team(interaction.user.id, get_config().sba_season) - + team = await get_user_major_league_team( + interaction.user.id, get_config().sba_season + ) + if not team: await interaction.followup.send( "❌ You don't appear to own a team in the current season.", - ephemeral=True + ephemeral=True, ) return - + # Get transactions in parallel - pending_task = transaction_service.get_pending_transactions(team.abbrev, get_config().sba_season) - frozen_task = transaction_service.get_frozen_transactions(team.abbrev, get_config().sba_season) - processed_task = transaction_service.get_processed_transactions(team.abbrev, get_config().sba_season) - - pending_transactions = await pending_task - frozen_transactions = await frozen_task - processed_transactions = await processed_task - + ( + pending_transactions, + frozen_transactions, + processed_transactions, + ) = await asyncio.gather( + transaction_service.get_pending_transactions( + team.abbrev, get_config().sba_season + ), + transaction_service.get_frozen_transactions( + team.abbrev, get_config().sba_season + ), + transaction_service.get_processed_transactions( + team.abbrev, get_config().sba_season + ), + ) + # Get cancelled if requested cancelled_transactions = [] if show_cancelled: cancelled_transactions = await transaction_service.get_team_transactions( - team.abbrev, - get_config().sba_season, - cancelled=True + team.abbrev, get_config().sba_season, cancelled=True ) pages = self._create_my_moves_pages( @@ -156,15 +165,15 @@ class TransactionCommands(commands.Cog): pending_transactions, frozen_transactions, processed_transactions, - cancelled_transactions + cancelled_transactions, ) # Collect all transactions for the "Show Move IDs" button all_transactions = ( - pending_transactions + - frozen_transactions + - processed_transactions + - cancelled_transactions + pending_transactions + + frozen_transactions + + processed_transactions + + cancelled_transactions ) # If only one page and no transactions, send without any buttons @@ -177,93 +186,90 @@ class TransactionCommands(commands.Cog): all_transactions=all_transactions, user_id=interaction.user.id, timeout=300.0, - show_page_numbers=True + show_page_numbers=True, ) await interaction.followup.send(embed=view.get_current_embed(), view=view) - + @app_commands.command( - name="legal", - description="Check roster legality for current and next week" - ) - @app_commands.describe( - team="Team abbreviation to check (defaults to your team)" + name="legal", description="Check roster legality for current and next week" ) + @app_commands.describe(team="Team abbreviation to check (defaults to your team)") @requires_team() @logged_command("/legal") - async def legal( - self, - interaction: discord.Interaction, - team: Optional[str] = None - ): + async def legal(self, interaction: discord.Interaction, team: Optional[str] = None): """Check roster legality and display detailed validation results.""" await interaction.response.defer() - + # Get target team if team: - target_team = await team_service.get_team_by_abbrev(team.upper(), get_config().sba_season) + target_team = await team_service.get_team_by_abbrev( + team.upper(), get_config().sba_season + ) if not target_team: await interaction.followup.send( f"❌ Could not find team '{team}' in season {get_config().sba_season}.", - ephemeral=True + ephemeral=True, ) return else: # Get user's team - user_teams = await team_service.get_teams_by_owner(interaction.user.id, get_config().sba_season) + user_teams = await team_service.get_teams_by_owner( + interaction.user.id, get_config().sba_season + ) if not user_teams: await interaction.followup.send( "❌ You don't appear to own a team. Please specify a team abbreviation.", - ephemeral=True + ephemeral=True, ) return target_team = user_teams[0] - + # Get rosters in parallel current_roster, next_roster = await asyncio.gather( roster_service.get_current_roster(target_team.id), - roster_service.get_next_roster(target_team.id) + roster_service.get_next_roster(target_team.id), ) - + if not current_roster and not next_roster: await interaction.followup.send( f"❌ Could not retrieve roster data for {target_team.abbrev}.", - ephemeral=True + ephemeral=True, ) return - + # Validate rosters in parallel validation_tasks = [] if current_roster: validation_tasks.append(roster_service.validate_roster(current_roster)) else: validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task - + if next_roster: validation_tasks.append(roster_service.validate_roster(next_roster)) else: validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task - + validation_results = await asyncio.gather(*validation_tasks) current_validation = validation_results[0] if current_roster else None next_validation = validation_results[1] if next_roster else None - + embed = await self._create_legal_embed( target_team, current_roster, - next_roster, + next_roster, current_validation, - next_validation + next_validation, ) - + await interaction.followup.send(embed=embed) - + def _create_my_moves_pages( self, team, pending_transactions, frozen_transactions, processed_transactions, - cancelled_transactions + cancelled_transactions, ) -> list[discord.Embed]: """Create paginated embeds showing user's transaction status.""" @@ -277,7 +283,9 @@ class TransactionCommands(commands.Cog): # Page 1: Summary + Pending Transactions if pending_transactions: total_pending = len(pending_transactions) - total_pages = (total_pending + transactions_per_page - 1) // transactions_per_page + total_pages = ( + total_pending + transactions_per_page - 1 + ) // transactions_per_page for page_num in range(total_pages): start_idx = page_num * transactions_per_page @@ -287,11 +295,11 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) # Add team thumbnail if available - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) # Pending transactions for this page @@ -300,7 +308,7 @@ class TransactionCommands(commands.Cog): embed.add_field( name=f"⏳ Pending Transactions ({total_pending} total)", value="\n".join(pending_lines), - inline=False + inline=False, ) # Add summary only on first page @@ -314,8 +322,12 @@ class TransactionCommands(commands.Cog): embed.add_field( name="Summary", - value=", ".join(status_text) if status_text else "No active transactions", - inline=True + value=( + ", ".join(status_text) + if status_text + else "No active transactions" + ), + inline=True, ) pages.append(embed) @@ -324,16 +336,16 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) embed.add_field( name="⏳ Pending Transactions", value="No pending transactions", - inline=False + inline=False, ) total_frozen = len(frozen_transactions) @@ -343,8 +355,10 @@ class TransactionCommands(commands.Cog): embed.add_field( name="Summary", - value=", ".join(status_text) if status_text else "No active transactions", - inline=True + value=( + ", ".join(status_text) if status_text else "No active transactions" + ), + inline=True, ) pages.append(embed) @@ -354,10 +368,10 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) frozen_lines = [format_transaction(tx) for tx in frozen_transactions] @@ -365,7 +379,7 @@ class TransactionCommands(commands.Cog): embed.add_field( name=f"❄️ Scheduled for Processing ({len(frozen_transactions)} total)", value="\n".join(frozen_lines), - inline=False + inline=False, ) pages.append(embed) @@ -375,18 +389,20 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) - processed_lines = [format_transaction(tx) for tx in processed_transactions[-20:]] # Last 20 + processed_lines = [ + format_transaction(tx) for tx in processed_transactions[-20:] + ] # Last 20 embed.add_field( name=f"✅ Recently Processed ({len(processed_transactions[-20:])} shown)", value="\n".join(processed_lines), - inline=False + inline=False, ) pages.append(embed) @@ -396,18 +412,20 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) - cancelled_lines = [format_transaction(tx) for tx in cancelled_transactions[-20:]] # Last 20 + cancelled_lines = [ + format_transaction(tx) for tx in cancelled_transactions[-20:] + ] # Last 20 embed.add_field( name=f"❌ Cancelled Transactions ({len(cancelled_transactions[-20:])} shown)", value="\n".join(cancelled_lines), - inline=False + inline=False, ) pages.append(embed) @@ -417,111 +435,106 @@ class TransactionCommands(commands.Cog): page.set_footer(text="Use /legal to check roster legality") return pages - + async def _create_legal_embed( - self, - team, - current_roster, - next_roster, - current_validation, - next_validation + self, team, current_roster, next_roster, current_validation, next_validation ) -> discord.Embed: """Create embed showing roster legality check results.""" - + # Determine overall status overall_legal = True if current_validation and not current_validation.is_legal: overall_legal = False if next_validation and not next_validation.is_legal: overall_legal = False - + status_emoji = "✅" if overall_legal else "❌" embed_color = EmbedColors.SUCCESS if overall_legal else EmbedColors.ERROR - + embed = EmbedTemplate.create_base_embed( title=f"{status_emoji} Roster Check - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=embed_color + color=embed_color, ) - + # Add team thumbnail if available - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) - + # Current week roster if current_roster and current_validation: current_lines = [] - current_lines.append(f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL") + current_lines.append( + f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL" + ) current_lines.append(f"**sWAR:** {current_validation.total_sWAR:.2f}") - + if current_validation.errors: current_lines.append(f"**❌ Errors:** {len(current_validation.errors)}") for error in current_validation.errors[:3]: # Show first 3 errors current_lines.append(f"• {error}") - + if current_validation.warnings: - current_lines.append(f"**⚠️ Warnings:** {len(current_validation.warnings)}") + current_lines.append( + f"**⚠️ Warnings:** {len(current_validation.warnings)}" + ) for warning in current_validation.warnings[:2]: # Show first 2 warnings current_lines.append(f"• {warning}") - + embed.add_field( name=f"{current_validation.status_emoji} Current Week", value="\n".join(current_lines), - inline=True + inline=True, ) else: embed.add_field( - name="❓ Current Week", - value="Roster data not available", - inline=True + name="❓ Current Week", value="Roster data not available", inline=True ) - - # Next week roster + + # Next week roster if next_roster and next_validation: next_lines = [] - next_lines.append(f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL") + next_lines.append( + f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL" + ) next_lines.append(f"**sWAR:** {next_validation.total_sWAR:.2f}") - + if next_validation.errors: next_lines.append(f"**❌ Errors:** {len(next_validation.errors)}") for error in next_validation.errors[:3]: # Show first 3 errors next_lines.append(f"• {error}") - + if next_validation.warnings: next_lines.append(f"**⚠️ Warnings:** {len(next_validation.warnings)}") for warning in next_validation.warnings[:2]: # Show first 2 warnings next_lines.append(f"• {warning}") - + embed.add_field( name=f"{next_validation.status_emoji} Next Week", value="\n".join(next_lines), - inline=True + inline=True, ) else: embed.add_field( - name="❓ Next Week", - value="Roster data not available", - inline=True + name="❓ Next Week", value="Roster data not available", inline=True ) - + # Overall status if overall_legal: embed.add_field( - name="Overall Status", - value="✅ All rosters are legal", - inline=False + name="Overall Status", value="✅ All rosters are legal", inline=False ) else: embed.add_field( - name="Overall Status", + name="Overall Status", value="❌ Roster violations found - please review and correct", - inline=False + inline=False, ) - + embed.set_footer(text="Roster validation based on current league rules") return embed async def setup(bot: commands.Bot): """Load the transaction commands cog.""" - await bot.add_cog(TransactionCommands(bot)) \ No newline at end of file + await bot.add_cog(TransactionCommands(bot)) diff --git a/commands/transactions/trade_channel_tracker.py b/commands/transactions/trade_channel_tracker.py index f3d34c3..1399649 100644 --- a/commands/transactions/trade_channel_tracker.py +++ b/commands/transactions/trade_channel_tracker.py @@ -3,6 +3,7 @@ Trade Channel Tracker Provides persistent tracking of bot-created trade discussion channels using JSON file storage. """ + import json from datetime import datetime, UTC from pathlib import Path @@ -12,7 +13,7 @@ import discord from utils.logging import get_contextual_logger -logger = get_contextual_logger(f'{__name__}.TradeChannelTracker') +logger = get_contextual_logger(f"{__name__}.TradeChannelTracker") class TradeChannelTracker: @@ -26,7 +27,7 @@ class TradeChannelTracker: - Automatic stale entry removal """ - def __init__(self, data_file: str = "data/trade_channels.json"): + def __init__(self, data_file: str = "storage/trade_channels.json"): """ Initialize the trade channel tracker. @@ -42,9 +43,11 @@ class TradeChannelTracker: """Load channel data from JSON file.""" try: if self.data_file.exists(): - with open(self.data_file, 'r') as f: + with open(self.data_file, "r") as f: self._data = json.load(f) - logger.debug(f"Loaded {len(self._data.get('trade_channels', {}))} tracked trade channels") + logger.debug( + f"Loaded {len(self._data.get('trade_channels', {}))} tracked trade channels" + ) else: self._data = {"trade_channels": {}} logger.info("No existing trade channel data found, starting fresh") @@ -55,7 +58,7 @@ class TradeChannelTracker: def save_data(self) -> None: """Save channel data to JSON file.""" try: - with open(self.data_file, 'w') as f: + with open(self.data_file, "w") as f: json.dump(self._data, f, indent=2, default=str) logger.debug("Trade channel data saved successfully") except Exception as e: @@ -67,7 +70,7 @@ class TradeChannelTracker: trade_id: str, team1_abbrev: str, team2_abbrev: str, - creator_id: int + creator_id: int, ) -> None: """ Add a new trade channel to tracking. @@ -87,10 +90,12 @@ class TradeChannelTracker: "team1_abbrev": team1_abbrev, "team2_abbrev": team2_abbrev, "created_at": datetime.now(UTC).isoformat(), - "creator_id": str(creator_id) + "creator_id": str(creator_id), } self.save_data() - logger.info(f"Added trade channel to tracking: {channel.name} (ID: {channel.id}, Trade: {trade_id})") + logger.info( + f"Added trade channel to tracking: {channel.name} (ID: {channel.id}, Trade: {trade_id})" + ) def remove_channel(self, channel_id: int) -> None: """ @@ -108,7 +113,9 @@ class TradeChannelTracker: channel_name = channel_data["name"] del channels[channel_key] self.save_data() - logger.info(f"Removed trade channel from tracking: {channel_name} (ID: {channel_id}, Trade: {trade_id})") + logger.info( + f"Removed trade channel from tracking: {channel_name} (ID: {channel_id}, Trade: {trade_id})" + ) def get_channel_by_trade_id(self, trade_id: str) -> Optional[Dict[str, Any]]: """ @@ -175,7 +182,9 @@ class TradeChannelTracker: channel_name = channels[channel_id_str].get("name", "unknown") trade_id = channels[channel_id_str].get("trade_id", "unknown") del channels[channel_id_str] - logger.info(f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str}, Trade: {trade_id})") + logger.info( + f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str}, Trade: {trade_id})" + ) if stale_entries: self.save_data() diff --git a/commands/voice/cleanup_service.py b/commands/voice/cleanup_service.py index ab8792c..c7343d9 100644 --- a/commands/voice/cleanup_service.py +++ b/commands/voice/cleanup_service.py @@ -3,6 +3,7 @@ Voice Channel Cleanup Service Provides automatic cleanup of empty voice channels with restart resilience. """ + import logging import discord @@ -12,7 +13,7 @@ from .tracker import VoiceChannelTracker from commands.gameplay.scorecard_tracker import ScorecardTracker from utils.logging import get_contextual_logger -logger = logging.getLogger(f'{__name__}.VoiceChannelCleanupService') +logger = logging.getLogger(f"{__name__}.VoiceChannelCleanupService") class VoiceChannelCleanupService: @@ -27,7 +28,9 @@ class VoiceChannelCleanupService: - Automatic scorecard unpublishing when voice channel is cleaned up """ - def __init__(self, bot: commands.Bot, data_file: str = "data/voice_channels.json"): + def __init__( + self, bot: commands.Bot, data_file: str = "storage/voice_channels.json" + ): """ Initialize the cleanup service. @@ -36,10 +39,10 @@ class VoiceChannelCleanupService: data_file: Path to the JSON data file for persistence """ self.bot = bot - self.logger = get_contextual_logger(f'{__name__}.VoiceChannelCleanupService') + self.logger = get_contextual_logger(f"{__name__}.VoiceChannelCleanupService") self.tracker = VoiceChannelTracker(data_file) self.scorecard_tracker = ScorecardTracker() - self.empty_threshold = 5 # Delete after 5 minutes empty + self.empty_threshold = 5 # Delete after 5 minutes empty # Start the cleanup task - @before_loop will wait for bot readiness self.cleanup_loop.start() @@ -90,13 +93,17 @@ class VoiceChannelCleanupService: guild = bot.get_guild(guild_id) if not guild: - self.logger.warning(f"Guild {guild_id} not found, removing channel {channel_data['name']}") + self.logger.warning( + f"Guild {guild_id} not found, removing channel {channel_data['name']}" + ) channels_to_remove.append(channel_id) continue channel = guild.get_channel(channel_id) if not channel: - self.logger.warning(f"Channel {channel_data['name']} (ID: {channel_id}) no longer exists") + self.logger.warning( + f"Channel {channel_data['name']} (ID: {channel_id}) no longer exists" + ) channels_to_remove.append(channel_id) continue @@ -121,18 +128,26 @@ class VoiceChannelCleanupService: if channel_data and channel_data.get("text_channel_id"): try: text_channel_id_int = int(channel_data["text_channel_id"]) - was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int) + was_unpublished = self.scorecard_tracker.unpublish_scorecard( + text_channel_id_int + ) if was_unpublished: - self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel)") + self.logger.info( + f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel)" + ) except (ValueError, TypeError) as e: - self.logger.warning(f"Invalid text_channel_id in stale voice channel data: {e}") + self.logger.warning( + f"Invalid text_channel_id in stale voice channel data: {e}" + ) # Also clean up any additional stale entries stale_removed = self.tracker.cleanup_stale_entries(valid_channel_ids) total_removed = len(channels_to_remove) + stale_removed if total_removed > 0: - self.logger.info(f"Cleaned up {total_removed} stale channel tracking entries") + self.logger.info( + f"Cleaned up {total_removed} stale channel tracking entries" + ) self.logger.info(f"Verified {len(valid_channel_ids)} valid tracked channels") @@ -149,10 +164,14 @@ class VoiceChannelCleanupService: await self.update_all_channel_statuses(bot) # Get channels ready for cleanup - channels_for_cleanup = self.tracker.get_channels_for_cleanup(self.empty_threshold) + channels_for_cleanup = self.tracker.get_channels_for_cleanup( + self.empty_threshold + ) if channels_for_cleanup: - self.logger.info(f"Found {len(channels_for_cleanup)} channels ready for cleanup") + self.logger.info( + f"Found {len(channels_for_cleanup)} channels ready for cleanup" + ) # Delete empty channels for channel_data in channels_for_cleanup: @@ -182,12 +201,16 @@ class VoiceChannelCleanupService: guild = bot.get_guild(guild_id) if not guild: - self.logger.debug(f"Guild {guild_id} not found for channel {channel_data['name']}") + self.logger.debug( + f"Guild {guild_id} not found for channel {channel_data['name']}" + ) return channel = guild.get_channel(channel_id) if not channel: - self.logger.debug(f"Channel {channel_data['name']} no longer exists, removing from tracking") + self.logger.debug( + f"Channel {channel_data['name']} no longer exists, removing from tracking" + ) self.tracker.remove_channel(channel_id) # Unpublish associated scorecard if it exists @@ -195,17 +218,25 @@ class VoiceChannelCleanupService: if text_channel_id: try: text_channel_id_int = int(text_channel_id) - was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int) + was_unpublished = self.scorecard_tracker.unpublish_scorecard( + text_channel_id_int + ) if was_unpublished: - self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (manually deleted voice channel)") + self.logger.info( + f"📋 Unpublished scorecard from text channel {text_channel_id_int} (manually deleted voice channel)" + ) except (ValueError, TypeError) as e: - self.logger.warning(f"Invalid text_channel_id in manually deleted voice channel data: {e}") + self.logger.warning( + f"Invalid text_channel_id in manually deleted voice channel data: {e}" + ) return # Ensure it's a voice channel before checking members if not isinstance(channel, discord.VoiceChannel): - self.logger.warning(f"Channel {channel_data['name']} is not a voice channel, removing from tracking") + self.logger.warning( + f"Channel {channel_data['name']} is not a voice channel, removing from tracking" + ) self.tracker.remove_channel(channel_id) # Unpublish associated scorecard if it exists @@ -213,11 +244,17 @@ class VoiceChannelCleanupService: if text_channel_id: try: text_channel_id_int = int(text_channel_id) - was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int) + was_unpublished = self.scorecard_tracker.unpublish_scorecard( + text_channel_id_int + ) if was_unpublished: - self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (wrong channel type)") + self.logger.info( + f"📋 Unpublished scorecard from text channel {text_channel_id_int} (wrong channel type)" + ) except (ValueError, TypeError) as e: - self.logger.warning(f"Invalid text_channel_id in wrong channel type data: {e}") + self.logger.warning( + f"Invalid text_channel_id in wrong channel type data: {e}" + ) return @@ -225,11 +262,15 @@ class VoiceChannelCleanupService: is_empty = len(channel.members) == 0 self.tracker.update_channel_status(channel_id, is_empty) - self.logger.debug(f"Channel {channel_data['name']}: {'empty' if is_empty else 'occupied'} " - f"({len(channel.members)} members)") + self.logger.debug( + f"Channel {channel_data['name']}: {'empty' if is_empty else 'occupied'} " + f"({len(channel.members)} members)" + ) except Exception as e: - self.logger.error(f"Error checking channel status for {channel_data.get('name', 'unknown')}: {e}") + self.logger.error( + f"Error checking channel status for {channel_data.get('name', 'unknown')}: {e}" + ) async def cleanup_channel(self, bot: commands.Bot, channel_data: dict) -> None: """ @@ -246,25 +287,33 @@ class VoiceChannelCleanupService: guild = bot.get_guild(guild_id) if not guild: - self.logger.info(f"Guild {guild_id} not found, removing tracking for {channel_name}") + self.logger.info( + f"Guild {guild_id} not found, removing tracking for {channel_name}" + ) self.tracker.remove_channel(channel_id) return channel = guild.get_channel(channel_id) if not channel: - self.logger.info(f"Channel {channel_name} already deleted, removing from tracking") + self.logger.info( + f"Channel {channel_name} already deleted, removing from tracking" + ) self.tracker.remove_channel(channel_id) return # Ensure it's a voice channel before checking members if not isinstance(channel, discord.VoiceChannel): - self.logger.warning(f"Channel {channel_name} is not a voice channel, removing from tracking") + self.logger.warning( + f"Channel {channel_name} is not a voice channel, removing from tracking" + ) self.tracker.remove_channel(channel_id) return # Final check: make sure channel is still empty before deleting if len(channel.members) > 0: - self.logger.info(f"Channel {channel_name} is no longer empty, skipping cleanup") + self.logger.info( + f"Channel {channel_name} is no longer empty, skipping cleanup" + ) self.tracker.update_channel_status(channel_id, False) return @@ -272,24 +321,36 @@ class VoiceChannelCleanupService: await channel.delete(reason="Automatic cleanup - empty for 5+ minutes") self.tracker.remove_channel(channel_id) - self.logger.info(f"✅ Cleaned up empty voice channel: {channel_name} (ID: {channel_id})") + self.logger.info( + f"✅ Cleaned up empty voice channel: {channel_name} (ID: {channel_id})" + ) # Unpublish associated scorecard if it exists text_channel_id = channel_data.get("text_channel_id") if text_channel_id: try: text_channel_id_int = int(text_channel_id) - was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int) + was_unpublished = self.scorecard_tracker.unpublish_scorecard( + text_channel_id_int + ) if was_unpublished: - self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (voice channel cleanup)") + self.logger.info( + f"📋 Unpublished scorecard from text channel {text_channel_id_int} (voice channel cleanup)" + ) else: - self.logger.debug(f"No scorecard found for text channel {text_channel_id_int}") + self.logger.debug( + f"No scorecard found for text channel {text_channel_id_int}" + ) except (ValueError, TypeError) as e: - self.logger.warning(f"Invalid text_channel_id in voice channel data: {e}") + self.logger.warning( + f"Invalid text_channel_id in voice channel data: {e}" + ) except discord.NotFound: # Channel was already deleted - self.logger.info(f"Channel {channel_data.get('name', 'unknown')} was already deleted") + self.logger.info( + f"Channel {channel_data.get('name', 'unknown')} was already deleted" + ) self.tracker.remove_channel(int(channel_data["channel_id"])) # Still try to unpublish associated scorecard @@ -297,15 +358,25 @@ class VoiceChannelCleanupService: if text_channel_id: try: text_channel_id_int = int(text_channel_id) - was_unpublished = self.scorecard_tracker.unpublish_scorecard(text_channel_id_int) + was_unpublished = self.scorecard_tracker.unpublish_scorecard( + text_channel_id_int + ) if was_unpublished: - self.logger.info(f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel cleanup)") + self.logger.info( + f"📋 Unpublished scorecard from text channel {text_channel_id_int} (stale voice channel cleanup)" + ) except (ValueError, TypeError) as e: - self.logger.warning(f"Invalid text_channel_id in voice channel data: {e}") + self.logger.warning( + f"Invalid text_channel_id in voice channel data: {e}" + ) except discord.Forbidden: - self.logger.error(f"Missing permissions to delete channel {channel_data.get('name', 'unknown')}") + self.logger.error( + f"Missing permissions to delete channel {channel_data.get('name', 'unknown')}" + ) except Exception as e: - self.logger.error(f"Error cleaning up channel {channel_data.get('name', 'unknown')}: {e}") + self.logger.error( + f"Error cleaning up channel {channel_data.get('name', 'unknown')}: {e}" + ) def get_tracker(self) -> VoiceChannelTracker: """ @@ -330,7 +401,7 @@ class VoiceChannelCleanupService: "running": self.cleanup_loop.is_running(), "total_tracked": len(all_channels), "empty_channels": len(empty_channels), - "empty_threshold": self.empty_threshold + "empty_threshold": self.empty_threshold, } @@ -344,4 +415,4 @@ def setup_voice_cleanup(bot: commands.Bot) -> VoiceChannelCleanupService: Returns: VoiceChannelCleanupService instance """ - return VoiceChannelCleanupService(bot) \ No newline at end of file + return VoiceChannelCleanupService(bot) diff --git a/commands/voice/tracker.py b/commands/voice/tracker.py index 4e85080..3002d1d 100644 --- a/commands/voice/tracker.py +++ b/commands/voice/tracker.py @@ -3,6 +3,7 @@ Voice Channel Tracker Provides persistent tracking of bot-created voice channels using JSON file storage. """ + import json import logging from datetime import datetime, timedelta, UTC @@ -11,7 +12,7 @@ from typing import Dict, List, Optional, Any import discord -logger = logging.getLogger(f'{__name__}.VoiceChannelTracker') +logger = logging.getLogger(f"{__name__}.VoiceChannelTracker") class VoiceChannelTracker: @@ -25,7 +26,7 @@ class VoiceChannelTracker: - Automatic stale entry removal """ - def __init__(self, data_file: str = "data/voice_channels.json"): + def __init__(self, data_file: str = "storage/voice_channels.json"): """ Initialize the voice channel tracker. @@ -41,9 +42,11 @@ class VoiceChannelTracker: """Load channel data from JSON file.""" try: if self.data_file.exists(): - with open(self.data_file, 'r') as f: + with open(self.data_file, "r") as f: self._data = json.load(f) - logger.debug(f"Loaded {len(self._data.get('voice_channels', {}))} tracked channels") + logger.debug( + f"Loaded {len(self._data.get('voice_channels', {}))} tracked channels" + ) else: self._data = {"voice_channels": {}} logger.info("No existing voice channel data found, starting fresh") @@ -54,7 +57,7 @@ class VoiceChannelTracker: def save_data(self) -> None: """Save channel data to JSON file.""" try: - with open(self.data_file, 'w') as f: + with open(self.data_file, "w") as f: json.dump(self._data, f, indent=2, default=str) logger.debug("Voice channel data saved successfully") except Exception as e: @@ -65,7 +68,7 @@ class VoiceChannelTracker: channel: discord.VoiceChannel, channel_type: str, creator_id: int, - text_channel_id: Optional[int] = None + text_channel_id: Optional[int] = None, ) -> None: """ Add a new channel to tracking. @@ -85,7 +88,7 @@ class VoiceChannelTracker: "last_checked": datetime.now(UTC).isoformat(), "empty_since": None, "creator_id": str(creator_id), - "text_channel_id": str(text_channel_id) if text_channel_id else None + "text_channel_id": str(text_channel_id) if text_channel_id else None, } self.save_data() logger.info(f"Added channel to tracking: {channel.name} (ID: {channel.id})") @@ -130,9 +133,13 @@ class VoiceChannelTracker: channel_name = channels[channel_key]["name"] del channels[channel_key] self.save_data() - logger.info(f"Removed channel from tracking: {channel_name} (ID: {channel_id})") + logger.info( + f"Removed channel from tracking: {channel_name} (ID: {channel_id})" + ) - def get_channels_for_cleanup(self, empty_threshold_minutes: int = 15) -> List[Dict[str, Any]]: + def get_channels_for_cleanup( + self, empty_threshold_minutes: int = 15 + ) -> List[Dict[str, Any]]: """ Get channels that should be deleted based on empty duration. @@ -153,10 +160,12 @@ class VoiceChannelTracker: # Parse empty_since timestamp empty_since_str = channel_data["empty_since"] # Handle both with and without timezone info - if empty_since_str.endswith('Z'): - empty_since_str = empty_since_str[:-1] + '+00:00' + if empty_since_str.endswith("Z"): + empty_since_str = empty_since_str[:-1] + "+00:00" - empty_since = datetime.fromisoformat(empty_since_str.replace('Z', '+00:00')) + empty_since = datetime.fromisoformat( + empty_since_str.replace("Z", "+00:00") + ) # Remove timezone info for comparison (both times are UTC) if empty_since.tzinfo: @@ -164,10 +173,14 @@ class VoiceChannelTracker: if empty_since <= cutoff_time: cleanup_candidates.append(channel_data) - logger.debug(f"Channel {channel_data['name']} ready for cleanup (empty since {empty_since})") + logger.debug( + f"Channel {channel_data['name']} ready for cleanup (empty since {empty_since})" + ) except (ValueError, TypeError) as e: - logger.warning(f"Invalid timestamp for channel {channel_data.get('name', 'unknown')}: {e}") + logger.warning( + f"Invalid timestamp for channel {channel_data.get('name', 'unknown')}: {e}" + ) return cleanup_candidates @@ -242,9 +255,11 @@ class VoiceChannelTracker: for channel_id_str in stale_entries: channel_name = channels[channel_id_str].get("name", "unknown") del channels[channel_id_str] - logger.info(f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str})") + logger.info( + f"Removed stale tracking entry: {channel_name} (ID: {channel_id_str})" + ) if stale_entries: self.save_data() - return len(stale_entries) \ No newline at end of file + return len(stale_entries) diff --git a/docker-compose.yml b/docker-compose.yml index 7c39ee1..f98d698 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -36,8 +36,11 @@ services: # Volume mounts volumes: - # Google Sheets credentials (required) - - ${SHEETS_CREDENTIALS_HOST_PATH:-./data}:/app/data:ro + # Google Sheets credentials (read-only, file mount) + - ${SHEETS_CREDENTIALS_HOST_PATH:-./data/major-domo-service-creds.json}:/app/data/major-domo-service-creds.json:ro + + # Runtime state files (writable) - scorecards, voice channels, trade channels, soak data + - ${STATE_HOST_PATH:-./storage}:/app/storage:rw # Logs directory (persistent) - mounted to /app/logs where the application expects it - ${LOGS_HOST_PATH:-./logs}:/app/logs:rw diff --git a/services/schedule_service.py b/services/schedule_service.py index 78ee51d..cb3c101 100644 --- a/services/schedule_service.py +++ b/services/schedule_service.py @@ -4,6 +4,7 @@ Schedule service for Discord Bot v2.0 Handles game schedule and results retrieval and processing. """ +import asyncio import logging from typing import Optional, List, Dict, Tuple @@ -102,10 +103,10 @@ class ScheduleService: # If weeks not specified, try a reasonable range (18 weeks typical) week_range = range(1, (weeks + 1) if weeks else 19) - for week in week_range: - week_games = await self.get_week_schedule(season, week) - - # Filter games involving this team + all_week_games = await asyncio.gather( + *[self.get_week_schedule(season, week) for week in week_range] + ) + for week_games in all_week_games: for game in week_games: if ( game.away_team.abbrev.upper() == team_abbrev_upper @@ -135,15 +136,13 @@ class ScheduleService: recent_games = [] # Get games from recent weeks - for week_offset in range(weeks_back): - # This is simplified - in production you'd want to determine current week - week = 10 - week_offset # Assuming we're around week 10 - if week <= 0: - break - - week_games = await self.get_week_schedule(season, week) - - # Only include completed games + weeks_to_fetch = [ + (10 - offset) for offset in range(weeks_back) if (10 - offset) > 0 + ] + all_week_games = await asyncio.gather( + *[self.get_week_schedule(season, week) for week in weeks_to_fetch] + ) + for week_games in all_week_games: completed_games = [game for game in week_games if game.is_completed] recent_games.extend(completed_games) @@ -157,13 +156,12 @@ class ScheduleService: logger.error(f"Error getting recent games: {e}") return [] - async def get_upcoming_games(self, season: int, weeks_ahead: int = 6) -> List[Game]: + async def get_upcoming_games(self, season: int) -> List[Game]: """ - Get upcoming scheduled games by scanning multiple weeks. + Get upcoming scheduled games by scanning all weeks. Args: season: Season number - weeks_ahead: Number of weeks to scan ahead (default 6) Returns: List of upcoming Game instances @@ -171,20 +169,16 @@ class ScheduleService: try: upcoming_games = [] - # Scan through weeks to find games without scores - for week in range(1, 19): # Standard season length - week_games = await self.get_week_schedule(season, week) - - # Find games without scores (not yet played) + # Fetch all weeks in parallel and filter for incomplete games + all_week_games = await asyncio.gather( + *[self.get_week_schedule(season, week) for week in range(1, 19)] + ) + for week_games in all_week_games: upcoming_games_week = [ game for game in week_games if not game.is_completed ] upcoming_games.extend(upcoming_games_week) - # If we found upcoming games, we can limit how many more weeks to check - if upcoming_games and len(upcoming_games) >= 20: # Reasonable limit - break - # Sort by week, then game number upcoming_games.sort(key=lambda x: (x.week, x.game_num or 0)) diff --git a/services/stats_service.py b/services/stats_service.py index a3b3a06..3441288 100644 --- a/services/stats_service.py +++ b/services/stats_service.py @@ -4,6 +4,7 @@ Statistics service for Discord Bot v2.0 Handles batting and pitching statistics retrieval and processing. """ +import asyncio import logging from typing import Optional @@ -144,11 +145,10 @@ class StatsService: """ try: # Get both types of stats concurrently - batting_task = self.get_batting_stats(player_id, season) - pitching_task = self.get_pitching_stats(player_id, season) - - batting_stats = await batting_task - pitching_stats = await pitching_task + batting_stats, pitching_stats = await asyncio.gather( + self.get_batting_stats(player_id, season), + self.get_pitching_stats(player_id, season), + ) logger.debug( f"Retrieved stats for player {player_id}: " diff --git a/services/trade_builder.py b/services/trade_builder.py index aa05b37..26f61aa 100644 --- a/services/trade_builder.py +++ b/services/trade_builder.py @@ -4,6 +4,7 @@ Trade Builder Service Extends the TransactionBuilder to support multi-team trades and player exchanges. """ +import asyncio import logging from typing import Dict, List, Optional, Set from datetime import datetime, timezone @@ -524,14 +525,22 @@ class TradeBuilder: # Validate each team's roster after the trade for participant in self.trade.participants: - team_id = participant.team.id - result.team_abbrevs[team_id] = participant.team.abbrev - if team_id in self._team_builders: - builder = self._team_builders[team_id] - roster_validation = await builder.validate_transaction(next_week) + result.team_abbrevs[participant.team.id] = participant.team.abbrev + team_ids_to_validate = [ + participant.team.id + for participant in self.trade.participants + if participant.team.id in self._team_builders + ] + if team_ids_to_validate: + validations = await asyncio.gather( + *[ + self._team_builders[tid].validate_transaction(next_week) + for tid in team_ids_to_validate + ] + ) + for team_id, roster_validation in zip(team_ids_to_validate, validations): result.participant_validations[team_id] = roster_validation - if not roster_validation.is_legal: result.is_legal = False diff --git a/tasks/live_scorebug_tracker.py b/tasks/live_scorebug_tracker.py index 9013ac2..4a4d355 100644 --- a/tasks/live_scorebug_tracker.py +++ b/tasks/live_scorebug_tracker.py @@ -95,7 +95,7 @@ class LiveScorebugTracker: # Don't return - still update voice channels else: # Get all published scorecards - all_scorecards = self.scorecard_tracker.get_all_scorecards() + all_scorecards = await self.scorecard_tracker.get_all_scorecards() if not all_scorecards: # No active scorebugs - clear the channel and hide it @@ -112,17 +112,16 @@ class LiveScorebugTracker: for text_channel_id, sheet_url in all_scorecards: try: scorebug_data = await self.scorebug_service.read_scorebug_data( - sheet_url, full_length=False # Compact view for live channel + sheet_url, + full_length=False, # Compact view for live channel ) # Only include active (non-final) games if scorebug_data.is_active: # Get team data - away_team = await team_service.get_team( - scorebug_data.away_team_id - ) - home_team = await team_service.get_team( - scorebug_data.home_team_id + away_team, home_team = await asyncio.gather( + team_service.get_team(scorebug_data.away_team_id), + team_service.get_team(scorebug_data.home_team_id), ) if away_team is None or home_team is None: @@ -188,9 +187,8 @@ class LiveScorebugTracker: embeds: List of scorebug embeds """ try: - # Clear old messages - async for message in channel.history(limit=25): - await message.delete() + # Clear old messages using bulk delete + await channel.purge(limit=25) # Post new scorebugs (Discord allows up to 10 embeds per message) if len(embeds) <= 10: @@ -216,9 +214,8 @@ class LiveScorebugTracker: channel: Discord text channel """ try: - # Clear all messages - async for message in channel.history(limit=25): - await message.delete() + # Clear all messages using bulk delete + await channel.purge(limit=25) self.logger.info("Cleared live-sba-scores channel (no active games)") diff --git a/tests/test_services_schedule.py b/tests/test_services_schedule.py new file mode 100644 index 0000000..70e60c7 --- /dev/null +++ b/tests/test_services_schedule.py @@ -0,0 +1,284 @@ +""" +Tests for schedule service functionality. + +Covers get_week_schedule, get_team_schedule, get_recent_games, +get_upcoming_games, and group_games_by_series — verifying the +asyncio.gather parallelization and post-fetch filtering logic. +""" + +import pytest +from unittest.mock import AsyncMock, patch + +from services.schedule_service import ScheduleService +from tests.factories import GameFactory, TeamFactory + + +def _game(game_id, week, away_abbrev, home_abbrev, **kwargs): + """Create a Game with distinct team IDs per matchup.""" + return GameFactory.create( + id=game_id, + week=week, + away_team=TeamFactory.create(id=game_id * 10, abbrev=away_abbrev), + home_team=TeamFactory.create(id=game_id * 10 + 1, abbrev=home_abbrev), + **kwargs, + ) + + +class TestGetWeekSchedule: + """Tests for ScheduleService.get_week_schedule — the HTTP layer.""" + + @pytest.fixture + def service(self): + svc = ScheduleService() + svc.get_client = AsyncMock() + return svc + + @pytest.mark.asyncio + async def test_success(self, service): + """get_week_schedule returns parsed Game objects on a normal response.""" + mock_client = AsyncMock() + mock_client.get.return_value = { + "games": [ + { + "id": 1, + "season": 12, + "week": 5, + "game_num": 1, + "season_type": "regular", + "away_team": { + "id": 10, + "abbrev": "NYY", + "sname": "NYY", + "lname": "New York", + "season": 12, + }, + "home_team": { + "id": 11, + "abbrev": "BOS", + "sname": "BOS", + "lname": "Boston", + "season": 12, + }, + "away_score": 4, + "home_score": 2, + } + ] + } + service.get_client.return_value = mock_client + + games = await service.get_week_schedule(12, 5) + + assert len(games) == 1 + assert games[0].away_team.abbrev == "NYY" + assert games[0].home_team.abbrev == "BOS" + assert games[0].is_completed + + @pytest.mark.asyncio + async def test_empty_response(self, service): + """get_week_schedule returns [] when the API has no games.""" + mock_client = AsyncMock() + mock_client.get.return_value = {"games": []} + service.get_client.return_value = mock_client + + games = await service.get_week_schedule(12, 99) + assert games == [] + + @pytest.mark.asyncio + async def test_api_error_returns_empty(self, service): + """get_week_schedule returns [] on API error (no exception raised).""" + service.get_client.side_effect = Exception("connection refused") + + games = await service.get_week_schedule(12, 1) + assert games == [] + + @pytest.mark.asyncio + async def test_missing_games_key(self, service): + """get_week_schedule returns [] when response lacks 'games' key.""" + mock_client = AsyncMock() + mock_client.get.return_value = {"status": "ok"} + service.get_client.return_value = mock_client + + games = await service.get_week_schedule(12, 1) + assert games == [] + + +class TestGetTeamSchedule: + """Tests for get_team_schedule — gather + team-abbrev filter.""" + + @pytest.fixture + def service(self): + return ScheduleService() + + @pytest.mark.asyncio + async def test_filters_by_team_case_insensitive(self, service): + """get_team_schedule returns only games involving the requested team, + regardless of abbreviation casing.""" + week1 = [ + _game(1, 1, "NYY", "BOS", away_score=3, home_score=1), + _game(2, 1, "LAD", "CHC", away_score=5, home_score=2), + ] + week2 = [ + _game(3, 2, "BOS", "NYY", away_score=2, home_score=4), + ] + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.side_effect = [week1, week2] + result = await service.get_team_schedule(12, "nyy", weeks=2) + + assert len(result) == 2 + assert all( + g.away_team.abbrev == "NYY" or g.home_team.abbrev == "NYY" for g in result + ) + + @pytest.mark.asyncio + async def test_full_season_fetches_18_weeks(self, service): + """When weeks is None, all 18 weeks are fetched via gather.""" + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [] + await service.get_team_schedule(12, "NYY") + + assert mock.call_count == 18 + + @pytest.mark.asyncio + async def test_limited_weeks(self, service): + """When weeks=5, only 5 weeks are fetched.""" + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [] + await service.get_team_schedule(12, "NYY", weeks=5) + + assert mock.call_count == 5 + + +class TestGetRecentGames: + """Tests for get_recent_games — gather + completed-only filter.""" + + @pytest.fixture + def service(self): + return ScheduleService() + + @pytest.mark.asyncio + async def test_returns_only_completed_games(self, service): + """get_recent_games filters out games without scores.""" + completed = GameFactory.completed(id=1, week=10) + incomplete = GameFactory.upcoming(id=2, week=10) + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [completed, incomplete] + result = await service.get_recent_games(12, weeks_back=1) + + assert len(result) == 1 + assert result[0].is_completed + + @pytest.mark.asyncio + async def test_sorted_descending_by_week_and_game_num(self, service): + """Recent games are sorted most-recent first.""" + game_w10 = GameFactory.completed(id=1, week=10, game_num=2) + game_w9 = GameFactory.completed(id=2, week=9, game_num=1) + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.side_effect = [[game_w10], [game_w9]] + result = await service.get_recent_games(12, weeks_back=2) + + assert result[0].week == 10 + assert result[1].week == 9 + + @pytest.mark.asyncio + async def test_skips_negative_weeks(self, service): + """Weeks that would be <= 0 are excluded from fetch.""" + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [] + await service.get_recent_games(12, weeks_back=15) + + # weeks_to_fetch = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] — only 10 valid weeks + assert mock.call_count == 10 + + +class TestGetUpcomingGames: + """Tests for get_upcoming_games — gather all 18 weeks + incomplete filter.""" + + @pytest.fixture + def service(self): + return ScheduleService() + + @pytest.mark.asyncio + async def test_returns_only_incomplete_games(self, service): + """get_upcoming_games filters out completed games.""" + completed = GameFactory.completed(id=1, week=5) + upcoming = GameFactory.upcoming(id=2, week=5) + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [completed, upcoming] + result = await service.get_upcoming_games(12) + + assert len(result) == 18 # 1 incomplete game per week × 18 weeks + assert all(not g.is_completed for g in result) + + @pytest.mark.asyncio + async def test_sorted_ascending_by_week_and_game_num(self, service): + """Upcoming games are sorted earliest first.""" + game_w3 = GameFactory.upcoming(id=1, week=3, game_num=1) + game_w1 = GameFactory.upcoming(id=2, week=1, game_num=2) + + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + + def side_effect(season, week): + if week == 1: + return [game_w1] + if week == 3: + return [game_w3] + return [] + + mock.side_effect = side_effect + result = await service.get_upcoming_games(12) + + assert result[0].week == 1 + assert result[1].week == 3 + + @pytest.mark.asyncio + async def test_fetches_all_18_weeks(self, service): + """All 18 weeks are fetched in parallel (no early exit).""" + with patch.object(service, "get_week_schedule", new_callable=AsyncMock) as mock: + mock.return_value = [] + await service.get_upcoming_games(12) + + assert mock.call_count == 18 + + +class TestGroupGamesBySeries: + """Tests for group_games_by_series — synchronous grouping logic.""" + + @pytest.fixture + def service(self): + return ScheduleService() + + def test_groups_by_alphabetical_pairing(self, service): + """Games between the same two teams are grouped under one key, + with the alphabetically-first team first in the tuple.""" + games = [ + _game(1, 1, "NYY", "BOS", game_num=1), + _game(2, 1, "BOS", "NYY", game_num=2), + _game(3, 1, "LAD", "CHC", game_num=1), + ] + + result = service.group_games_by_series(games) + + assert ("BOS", "NYY") in result + assert len(result[("BOS", "NYY")]) == 2 + assert ("CHC", "LAD") in result + assert len(result[("CHC", "LAD")]) == 1 + + def test_sorted_by_game_num_within_series(self, service): + """Games within each series are sorted by game_num.""" + games = [ + _game(1, 1, "NYY", "BOS", game_num=3), + _game(2, 1, "NYY", "BOS", game_num=1), + _game(3, 1, "NYY", "BOS", game_num=2), + ] + + result = service.group_games_by_series(games) + series = result[("BOS", "NYY")] + assert [g.game_num for g in series] == [1, 2, 3] + + def test_empty_input(self, service): + """Empty games list returns empty dict.""" + assert service.group_games_by_series([]) == {} diff --git a/tests/test_services_stats.py b/tests/test_services_stats.py new file mode 100644 index 0000000..39f89b0 --- /dev/null +++ b/tests/test_services_stats.py @@ -0,0 +1,111 @@ +""" +Tests for StatsService + +Validates stats service functionality including concurrent stat retrieval +and error handling in get_player_stats(). +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from services.stats_service import StatsService + + +class TestStatsServiceGetPlayerStats: + """Test StatsService.get_player_stats() concurrent retrieval.""" + + @pytest.fixture + def service(self): + """Create a fresh StatsService instance for testing.""" + return StatsService() + + @pytest.fixture + def mock_batting_stats(self): + """Create a mock BattingStats object.""" + stats = MagicMock() + stats.avg = 0.300 + return stats + + @pytest.fixture + def mock_pitching_stats(self): + """Create a mock PitchingStats object.""" + stats = MagicMock() + stats.era = 3.50 + return stats + + @pytest.mark.asyncio + async def test_both_stats_returned( + self, service, mock_batting_stats, mock_pitching_stats + ): + """When both batting and pitching stats exist, both are returned. + + Verifies that get_player_stats returns a tuple of (batting, pitching) + when both stat types are available for the player. + """ + service.get_batting_stats = AsyncMock(return_value=mock_batting_stats) + service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats) + + batting, pitching = await service.get_player_stats(player_id=100, season=12) + + assert batting is mock_batting_stats + assert pitching is mock_pitching_stats + service.get_batting_stats.assert_called_once_with(100, 12) + service.get_pitching_stats.assert_called_once_with(100, 12) + + @pytest.mark.asyncio + async def test_batting_only(self, service, mock_batting_stats): + """When only batting stats exist, pitching is None. + + Covers the case of a position player with no pitching record. + """ + service.get_batting_stats = AsyncMock(return_value=mock_batting_stats) + service.get_pitching_stats = AsyncMock(return_value=None) + + batting, pitching = await service.get_player_stats(player_id=200, season=12) + + assert batting is mock_batting_stats + assert pitching is None + + @pytest.mark.asyncio + async def test_pitching_only(self, service, mock_pitching_stats): + """When only pitching stats exist, batting is None. + + Covers the case of a pitcher with no batting record. + """ + service.get_batting_stats = AsyncMock(return_value=None) + service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats) + + batting, pitching = await service.get_player_stats(player_id=300, season=12) + + assert batting is None + assert pitching is mock_pitching_stats + + @pytest.mark.asyncio + async def test_no_stats_found(self, service): + """When no stats exist for the player, both are None. + + Covers the case where a player has no stats for the given season + (e.g., didn't play). + """ + service.get_batting_stats = AsyncMock(return_value=None) + service.get_pitching_stats = AsyncMock(return_value=None) + + batting, pitching = await service.get_player_stats(player_id=400, season=12) + + assert batting is None + assert pitching is None + + @pytest.mark.asyncio + async def test_exception_returns_none_tuple(self, service): + """When an exception occurs, (None, None) is returned. + + The get_player_stats method wraps both calls in a try/except and + returns (None, None) on any error, ensuring callers always get a tuple. + """ + service.get_batting_stats = AsyncMock(side_effect=RuntimeError("API down")) + service.get_pitching_stats = AsyncMock(return_value=None) + + batting, pitching = await service.get_player_stats(player_id=500, season=12) + + assert batting is None + assert pitching is None diff --git a/tests/test_utils_autocomplete.py b/tests/test_utils_autocomplete.py index 50cd651..420a31c 100644 --- a/tests/test_utils_autocomplete.py +++ b/tests/test_utils_autocomplete.py @@ -3,10 +3,16 @@ Tests for shared autocomplete utility functions. Validates the shared autocomplete functions used across multiple command modules. """ + import pytest from unittest.mock import AsyncMock, MagicMock, patch -from utils.autocomplete import player_autocomplete, team_autocomplete, major_league_team_autocomplete +import utils.autocomplete +from utils.autocomplete import ( + player_autocomplete, + team_autocomplete, + major_league_team_autocomplete, +) from tests.factories import PlayerFactory, TeamFactory from models.team import RosterType @@ -14,6 +20,13 @@ from models.team import RosterType class TestPlayerAutocomplete: """Test player autocomplete functionality.""" + @pytest.fixture(autouse=True) + def clear_user_team_cache(self): + """Clear the module-level user team cache before each test to prevent interference.""" + utils.autocomplete._user_team_cache.clear() + yield + utils.autocomplete._user_team_cache.clear() + @pytest.fixture def mock_interaction(self): """Create a mock Discord interaction.""" @@ -26,41 +39,43 @@ class TestPlayerAutocomplete: """Test successful player autocomplete.""" mock_players = [ PlayerFactory.mike_trout(id=1), - PlayerFactory.ronald_acuna(id=2) + PlayerFactory.ronald_acuna(id=2), ] - with patch('utils.autocomplete.player_service') as mock_service: + with patch("utils.autocomplete.player_service") as mock_service: mock_service.search_players = AsyncMock(return_value=mock_players) - choices = await player_autocomplete(mock_interaction, 'Trout') + choices = await player_autocomplete(mock_interaction, "Trout") assert len(choices) == 2 - assert choices[0].name == 'Mike Trout (CF)' - assert choices[0].value == 'Mike Trout' - assert choices[1].name == 'Ronald Acuna Jr. (OF)' - assert choices[1].value == 'Ronald Acuna Jr.' + assert choices[0].name == "Mike Trout (CF)" + assert choices[0].value == "Mike Trout" + assert choices[1].name == "Ronald Acuna Jr. (OF)" + assert choices[1].value == "Ronald Acuna Jr." @pytest.mark.asyncio async def test_player_autocomplete_with_team_info(self, mock_interaction): """Test player autocomplete with team information.""" - mock_team = TeamFactory.create(id=499, abbrev='LAA', sname='Angels', lname='Los Angeles Angels') + mock_team = TeamFactory.create( + id=499, abbrev="LAA", sname="Angels", lname="Los Angeles Angels" + ) mock_player = PlayerFactory.mike_trout(id=1) mock_player.team = mock_team - with patch('utils.autocomplete.player_service') as mock_service: + with patch("utils.autocomplete.player_service") as mock_service: mock_service.search_players = AsyncMock(return_value=[mock_player]) - choices = await player_autocomplete(mock_interaction, 'Trout') + choices = await player_autocomplete(mock_interaction, "Trout") assert len(choices) == 1 - assert choices[0].name == 'Mike Trout (CF - LAA)' - assert choices[0].value == 'Mike Trout' + assert choices[0].name == "Mike Trout (CF - LAA)" + assert choices[0].value == "Mike Trout" @pytest.mark.asyncio async def test_player_autocomplete_prioritizes_user_team(self, mock_interaction): """Test that user's team players are prioritized in autocomplete.""" - user_team = TeamFactory.create(id=1, abbrev='POR', sname='Loggers') - other_team = TeamFactory.create(id=2, abbrev='LAA', sname='Angels') + user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers") + other_team = TeamFactory.create(id=2, abbrev="LAA", sname="Angels") # Create players - one from user's team, one from other team user_player = PlayerFactory.mike_trout(id=1) @@ -71,32 +86,35 @@ class TestPlayerAutocomplete: other_player.team = other_team other_player.team_id = other_team.id - with patch('utils.autocomplete.player_service') as mock_service, \ - patch('utils.autocomplete.get_user_major_league_team') as mock_get_team: - - mock_service.search_players = AsyncMock(return_value=[other_player, user_player]) + with ( + patch("utils.autocomplete.player_service") as mock_service, + patch("utils.autocomplete.get_user_major_league_team") as mock_get_team, + ): + mock_service.search_players = AsyncMock( + return_value=[other_player, user_player] + ) mock_get_team.return_value = user_team - choices = await player_autocomplete(mock_interaction, 'player') + choices = await player_autocomplete(mock_interaction, "player") assert len(choices) == 2 # User's team player should be first - assert choices[0].name == 'Mike Trout (CF - POR)' - assert choices[1].name == 'Ronald Acuna Jr. (OF - LAA)' + assert choices[0].name == "Mike Trout (CF - POR)" + assert choices[1].name == "Ronald Acuna Jr. (OF - LAA)" @pytest.mark.asyncio async def test_player_autocomplete_short_input(self, mock_interaction): """Test player autocomplete with short input returns empty.""" - choices = await player_autocomplete(mock_interaction, 'T') + choices = await player_autocomplete(mock_interaction, "T") assert len(choices) == 0 @pytest.mark.asyncio async def test_player_autocomplete_error_handling(self, mock_interaction): """Test player autocomplete error handling.""" - with patch('utils.autocomplete.player_service') as mock_service: + with patch("utils.autocomplete.player_service") as mock_service: mock_service.search_players.side_effect = Exception("API Error") - choices = await player_autocomplete(mock_interaction, 'Trout') + choices = await player_autocomplete(mock_interaction, "Trout") assert len(choices) == 0 @@ -114,35 +132,35 @@ class TestTeamAutocomplete: async def test_team_autocomplete_success(self, mock_interaction): """Test successful team autocomplete.""" mock_teams = [ - TeamFactory.create(id=1, abbrev='LAA', sname='Angels'), - TeamFactory.create(id=2, abbrev='LAAMIL', sname='Salt Lake Bees'), - TeamFactory.create(id=3, abbrev='LAAAIL', sname='Angels IL'), - TeamFactory.create(id=4, abbrev='POR', sname='Loggers') + TeamFactory.create(id=1, abbrev="LAA", sname="Angels"), + TeamFactory.create(id=2, abbrev="LAAMIL", sname="Salt Lake Bees"), + TeamFactory.create(id=3, abbrev="LAAAIL", sname="Angels IL"), + TeamFactory.create(id=4, abbrev="POR", sname="Loggers"), ] - with patch('utils.autocomplete.team_service') as mock_service: + with patch("utils.autocomplete.team_service") as mock_service: mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams) - choices = await team_autocomplete(mock_interaction, 'la') + choices = await team_autocomplete(mock_interaction, "la") assert len(choices) == 3 # All teams with 'la' in abbrev or sname - assert any('LAA' in choice.name for choice in choices) - assert any('LAAMIL' in choice.name for choice in choices) - assert any('LAAAIL' in choice.name for choice in choices) + assert any("LAA" in choice.name for choice in choices) + assert any("LAAMIL" in choice.name for choice in choices) + assert any("LAAAIL" in choice.name for choice in choices) @pytest.mark.asyncio async def test_team_autocomplete_short_input(self, mock_interaction): """Test team autocomplete with very short input.""" - choices = await team_autocomplete(mock_interaction, '') + choices = await team_autocomplete(mock_interaction, "") assert len(choices) == 0 @pytest.mark.asyncio async def test_team_autocomplete_error_handling(self, mock_interaction): """Test team autocomplete error handling.""" - with patch('utils.autocomplete.team_service') as mock_service: + with patch("utils.autocomplete.team_service") as mock_service: mock_service.get_teams_by_season.side_effect = Exception("API Error") - choices = await team_autocomplete(mock_interaction, 'LAA') + choices = await team_autocomplete(mock_interaction, "LAA") assert len(choices) == 0 @@ -157,101 +175,197 @@ class TestMajorLeagueTeamAutocomplete: return interaction @pytest.mark.asyncio - async def test_major_league_team_autocomplete_filters_correctly(self, mock_interaction): + async def test_major_league_team_autocomplete_filters_correctly( + self, mock_interaction + ): """Test that only major league teams are returned.""" # Create teams with different roster types mock_teams = [ - TeamFactory.create(id=1, abbrev='LAA', sname='Angels'), # ML - TeamFactory.create(id=2, abbrev='LAAMIL', sname='Salt Lake Bees'), # MiL - TeamFactory.create(id=3, abbrev='LAAAIL', sname='Angels IL'), # IL - TeamFactory.create(id=4, abbrev='FA', sname='Free Agents'), # FA - TeamFactory.create(id=5, abbrev='POR', sname='Loggers'), # ML - TeamFactory.create(id=6, abbrev='PORMIL', sname='Portland MiL'), # MiL + TeamFactory.create(id=1, abbrev="LAA", sname="Angels"), # ML + TeamFactory.create(id=2, abbrev="LAAMIL", sname="Salt Lake Bees"), # MiL + TeamFactory.create(id=3, abbrev="LAAAIL", sname="Angels IL"), # IL + TeamFactory.create(id=4, abbrev="FA", sname="Free Agents"), # FA + TeamFactory.create(id=5, abbrev="POR", sname="Loggers"), # ML + TeamFactory.create(id=6, abbrev="PORMIL", sname="Portland MiL"), # MiL ] - with patch('utils.autocomplete.team_service') as mock_service: + with patch("utils.autocomplete.team_service") as mock_service: mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams) - choices = await major_league_team_autocomplete(mock_interaction, 'l') + choices = await major_league_team_autocomplete(mock_interaction, "l") # Should only return major league teams that match 'l' (LAA, POR) choice_values = [choice.value for choice in choices] - assert 'LAA' in choice_values - assert 'POR' in choice_values + assert "LAA" in choice_values + assert "POR" in choice_values assert len(choice_values) == 2 # Should NOT include MiL, IL, or FA teams - assert 'LAAMIL' not in choice_values - assert 'LAAAIL' not in choice_values - assert 'FA' not in choice_values - assert 'PORMIL' not in choice_values + assert "LAAMIL" not in choice_values + assert "LAAAIL" not in choice_values + assert "FA" not in choice_values + assert "PORMIL" not in choice_values @pytest.mark.asyncio async def test_major_league_team_autocomplete_matching(self, mock_interaction): """Test search matching on abbreviation and short name.""" mock_teams = [ - TeamFactory.create(id=1, abbrev='LAA', sname='Angels'), - TeamFactory.create(id=2, abbrev='LAD', sname='Dodgers'), - TeamFactory.create(id=3, abbrev='POR', sname='Loggers'), - TeamFactory.create(id=4, abbrev='BOS', sname='Red Sox'), + TeamFactory.create(id=1, abbrev="LAA", sname="Angels"), + TeamFactory.create(id=2, abbrev="LAD", sname="Dodgers"), + TeamFactory.create(id=3, abbrev="POR", sname="Loggers"), + TeamFactory.create(id=4, abbrev="BOS", sname="Red Sox"), ] - with patch('utils.autocomplete.team_service') as mock_service: + with patch("utils.autocomplete.team_service") as mock_service: mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams) # Test abbreviation matching - choices = await major_league_team_autocomplete(mock_interaction, 'la') + choices = await major_league_team_autocomplete(mock_interaction, "la") assert len(choices) == 2 # LAA and LAD choice_values = [choice.value for choice in choices] - assert 'LAA' in choice_values - assert 'LAD' in choice_values + assert "LAA" in choice_values + assert "LAD" in choice_values # Test short name matching - choices = await major_league_team_autocomplete(mock_interaction, 'red') + choices = await major_league_team_autocomplete(mock_interaction, "red") assert len(choices) == 1 - assert choices[0].value == 'BOS' + assert choices[0].value == "BOS" @pytest.mark.asyncio async def test_major_league_team_autocomplete_short_input(self, mock_interaction): """Test major league team autocomplete with very short input.""" - choices = await major_league_team_autocomplete(mock_interaction, '') + choices = await major_league_team_autocomplete(mock_interaction, "") assert len(choices) == 0 @pytest.mark.asyncio - async def test_major_league_team_autocomplete_error_handling(self, mock_interaction): + async def test_major_league_team_autocomplete_error_handling( + self, mock_interaction + ): """Test major league team autocomplete error handling.""" - with patch('utils.autocomplete.team_service') as mock_service: + with patch("utils.autocomplete.team_service") as mock_service: mock_service.get_teams_by_season.side_effect = Exception("API Error") - choices = await major_league_team_autocomplete(mock_interaction, 'LAA') + choices = await major_league_team_autocomplete(mock_interaction, "LAA") assert len(choices) == 0 @pytest.mark.asyncio - async def test_major_league_team_autocomplete_roster_type_detection(self, mock_interaction): + async def test_major_league_team_autocomplete_roster_type_detection( + self, mock_interaction + ): """Test that roster type detection works correctly for edge cases.""" # Test edge cases like teams whose abbreviation ends in 'M' + 'IL' mock_teams = [ - TeamFactory.create(id=1, abbrev='BHM', sname='Iron'), # ML team ending in 'M' - TeamFactory.create(id=2, abbrev='BHMIL', sname='Iron IL'), # IL team (BHM + IL) - TeamFactory.create(id=3, abbrev='NYYMIL', sname='Staten Island RailRiders'), # MiL team (NYY + MIL) - TeamFactory.create(id=4, abbrev='NYY', sname='Yankees'), # ML team + TeamFactory.create( + id=1, abbrev="BHM", sname="Iron" + ), # ML team ending in 'M' + TeamFactory.create( + id=2, abbrev="BHMIL", sname="Iron IL" + ), # IL team (BHM + IL) + TeamFactory.create( + id=3, abbrev="NYYMIL", sname="Staten Island RailRiders" + ), # MiL team (NYY + MIL) + TeamFactory.create(id=4, abbrev="NYY", sname="Yankees"), # ML team ] - with patch('utils.autocomplete.team_service') as mock_service: + with patch("utils.autocomplete.team_service") as mock_service: mock_service.get_teams_by_season = AsyncMock(return_value=mock_teams) - choices = await major_league_team_autocomplete(mock_interaction, 'b') + choices = await major_league_team_autocomplete(mock_interaction, "b") # Should only return major league teams choice_values = [choice.value for choice in choices] - assert 'BHM' in choice_values # Major league team - assert 'BHMIL' not in choice_values # Should be detected as IL, not MiL - assert 'NYYMIL' not in choice_values # Minor league team + assert "BHM" in choice_values # Major league team + assert "BHMIL" not in choice_values # Should be detected as IL, not MiL + assert "NYYMIL" not in choice_values # Minor league team # Verify the roster type detection is working - bhm_team = next(t for t in mock_teams if t.abbrev == 'BHM') - bhmil_team = next(t for t in mock_teams if t.abbrev == 'BHMIL') - nyymil_team = next(t for t in mock_teams if t.abbrev == 'NYYMIL') + bhm_team = next(t for t in mock_teams if t.abbrev == "BHM") + bhmil_team = next(t for t in mock_teams if t.abbrev == "BHMIL") + nyymil_team = next(t for t in mock_teams if t.abbrev == "NYYMIL") assert bhm_team.roster_type() == RosterType.MAJOR_LEAGUE assert bhmil_team.roster_type() == RosterType.INJURED_LIST - assert nyymil_team.roster_type() == RosterType.MINOR_LEAGUE \ No newline at end of file + assert nyymil_team.roster_type() == RosterType.MINOR_LEAGUE + + +class TestGetCachedUserTeam: + """Test the _get_cached_user_team caching helper. + + Verifies that the cache avoids redundant get_user_major_league_team calls + on repeated invocations within the TTL window, and that expired entries are + re-fetched. + """ + + @pytest.fixture(autouse=True) + def clear_cache(self): + """Isolate each test from cache state left by other tests.""" + utils.autocomplete._user_team_cache.clear() + yield + utils.autocomplete._user_team_cache.clear() + + @pytest.fixture + def mock_interaction(self): + interaction = MagicMock() + interaction.user.id = 99999 + return interaction + + @pytest.mark.asyncio + async def test_caches_result_on_first_call(self, mock_interaction): + """First call populates the cache; API function called exactly once.""" + user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers") + + with patch( + "utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock + ) as mock_get_team: + mock_get_team.return_value = user_team + + from utils.autocomplete import _get_cached_user_team + + result1 = await _get_cached_user_team(mock_interaction) + result2 = await _get_cached_user_team(mock_interaction) + + assert result1 is user_team + assert result2 is user_team + # API called only once despite two invocations + mock_get_team.assert_called_once_with(99999) + + @pytest.mark.asyncio + async def test_re_fetches_after_ttl_expires(self, mock_interaction): + """Expired cache entries cause a fresh API call.""" + import time + + user_team = TeamFactory.create(id=1, abbrev="POR", sname="Loggers") + + with patch( + "utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock + ) as mock_get_team: + mock_get_team.return_value = user_team + + from utils.autocomplete import _get_cached_user_team, _USER_TEAM_CACHE_TTL + + # Seed the cache with a timestamp that is already expired + utils.autocomplete._user_team_cache[99999] = ( + user_team, + time.time() - _USER_TEAM_CACHE_TTL - 1, + ) + + await _get_cached_user_team(mock_interaction) + + # Should have called the API to refresh the stale entry + mock_get_team.assert_called_once_with(99999) + + @pytest.mark.asyncio + async def test_caches_none_result(self, mock_interaction): + """None (user has no team) is cached to avoid repeated API calls.""" + with patch( + "utils.autocomplete.get_user_major_league_team", new_callable=AsyncMock + ) as mock_get_team: + mock_get_team.return_value = None + + from utils.autocomplete import _get_cached_user_team + + result1 = await _get_cached_user_team(mock_interaction) + result2 = await _get_cached_user_team(mock_interaction) + + assert result1 is None + assert result2 is None + mock_get_team.assert_called_once() diff --git a/utils/autocomplete.py b/utils/autocomplete.py index 6980a1e..db3f4f9 100644 --- a/utils/autocomplete.py +++ b/utils/autocomplete.py @@ -4,16 +4,33 @@ Autocomplete Utilities Shared autocomplete functions for Discord slash commands. """ -from typing import List +import time +from typing import Dict, List, Optional, Tuple import discord from discord import app_commands from config import get_config -from models.team import RosterType +from models.team import RosterType, Team from services.player_service import player_service from services.team_service import team_service from utils.team_utils import get_user_major_league_team +# Cache for user team lookups: user_id -> (team, cached_at) +_user_team_cache: Dict[int, Tuple[Optional[Team], float]] = {} +_USER_TEAM_CACHE_TTL = 60 # seconds + + +async def _get_cached_user_team(interaction: discord.Interaction) -> Optional[Team]: + """Return the user's major league team, cached for 60 seconds per user.""" + user_id = interaction.user.id + if user_id in _user_team_cache: + team, cached_at = _user_team_cache[user_id] + if time.time() - cached_at < _USER_TEAM_CACHE_TTL: + return team + team = await get_user_major_league_team(user_id) + _user_team_cache[user_id] = (team, time.time()) + return team + async def player_autocomplete( interaction: discord.Interaction, current: str @@ -34,12 +51,12 @@ async def player_autocomplete( return [] try: - # Get user's team for prioritization - user_team = await get_user_major_league_team(interaction.user.id) + # Get user's team for prioritization (cached per user, 60s TTL) + user_team = await _get_cached_user_team(interaction) # Search for players using the search endpoint players = await player_service.search_players( - current, limit=50, season=get_config().sba_season + current, limit=25, season=get_config().sba_season ) # Separate players by team (user's team vs others) diff --git a/utils/cache.py b/utils/cache.py index 9f8eee6..4baf97b 100644 --- a/utils/cache.py +++ b/utils/cache.py @@ -188,9 +188,11 @@ class CacheManager: try: pattern = f"{prefix}:*" - keys = await client.keys(pattern) - if keys: - deleted = await client.delete(*keys) + keys_to_delete = [] + async for key in client.scan_iter(match=pattern): + keys_to_delete.append(key) + if keys_to_delete: + deleted = await client.delete(*keys_to_delete) logger.info(f"Cleared {deleted} cache keys with prefix '{prefix}'") return deleted except Exception as e: diff --git a/utils/decorators.py b/utils/decorators.py index 1b73305..7bd5a3d 100644 --- a/utils/decorators.py +++ b/utils/decorators.py @@ -11,29 +11,29 @@ from functools import wraps from typing import List, Optional, Callable, Any from utils.logging import set_discord_context, get_contextual_logger -cache_logger = logging.getLogger(f'{__name__}.CacheDecorators') -period_check_logger = logging.getLogger(f'{__name__}.PeriodCheckDecorators') +cache_logger = logging.getLogger(f"{__name__}.CacheDecorators") +period_check_logger = logging.getLogger(f"{__name__}.PeriodCheckDecorators") def logged_command( - command_name: Optional[str] = None, + command_name: Optional[str] = None, log_params: bool = True, - exclude_params: Optional[List[str]] = None + exclude_params: Optional[List[str]] = None, ): """ Decorator for Discord commands that adds comprehensive logging. - + This decorator automatically handles: - Setting Discord context with interaction details - Starting/ending operation timing - Logging command start/completion/failure - Preserving function metadata and signature - + Args: command_name: Override command name (defaults to function name with slashes) log_params: Whether to log command parameters (default: True) exclude_params: List of parameter names to exclude from logging - + Example: @logged_command("/roster", exclude_params=["sensitive_data"]) async def team_roster(self, interaction, team_name: str, season: int = None): @@ -42,57 +42,65 @@ def logged_command( players = await team_service.get_roster(team.id, season) embed = create_roster_embed(team, players) await interaction.followup.send(embed=embed) - + Side Effects: - Automatically sets Discord context for all subsequent log entries - Creates trace_id for request correlation - Logs command execution timing and results - Re-raises all exceptions after logging (preserves original behavior) - + Requirements: - The decorated class must have a 'logger' attribute, or one will be created - Function must be an async method with (self, interaction, ...) signature - Preserves Discord.py command registration compatibility """ + def decorator(func): + sig = inspect.signature(func) + param_names = list(sig.parameters.keys())[2:] # Skip self, interaction + exclude_set = set(exclude_params or []) + @wraps(func) async def wrapper(self, interaction, *args, **kwargs): # Auto-detect command name if not provided cmd_name = command_name or f"/{func.__name__.replace('_', '-')}" - + # Build context with safe parameter logging context = {"command": cmd_name} if log_params: - sig = inspect.signature(func) - param_names = list(sig.parameters.keys())[2:] # Skip self, interaction - exclude_set = set(exclude_params or []) - for i, (name, value) in enumerate(zip(param_names, args)): if name not in exclude_set: context[f"param_{name}"] = value - + set_discord_context(interaction=interaction, **context) - + # Get logger from the class instance or create one - logger = getattr(self, 'logger', get_contextual_logger(f'{self.__class__.__module__}.{self.__class__.__name__}')) + logger = getattr( + self, + "logger", + get_contextual_logger( + f"{self.__class__.__module__}.{self.__class__.__name__}" + ), + ) trace_id = logger.start_operation(f"{func.__name__}_command") - + try: logger.info(f"{cmd_name} command started") result = await func(self, interaction, *args, **kwargs) logger.info(f"{cmd_name} command completed successfully") logger.end_operation(trace_id, "completed") return result - + except Exception as e: logger.error(f"{cmd_name} command failed", error=e) logger.end_operation(trace_id, "failed") # Re-raise to maintain original exception handling behavior raise - + # Preserve signature for Discord.py command registration - wrapper.__signature__ = inspect.signature(func) # type: ignore + wrapper.__signature__ = sig # type: ignore return wrapper + return decorator @@ -122,6 +130,7 @@ def requires_draft_period(func): - Should be placed before @logged_command decorator - league_service must be available via import """ + @wraps(func) async def wrapper(self, interaction, *args, **kwargs): # Import here to avoid circular imports @@ -133,10 +142,12 @@ def requires_draft_period(func): current = await league_service.get_current_state() if not current: - period_check_logger.error("Could not retrieve league state for draft period check") + period_check_logger.error( + "Could not retrieve league state for draft period check" + ) embed = EmbedTemplate.error( "System Error", - "Could not verify draft period status. Please try again later." + "Could not verify draft period status. Please try again later.", ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -148,12 +159,12 @@ def requires_draft_period(func): extra={ "user_id": interaction.user.id, "command": func.__name__, - "current_week": current.week - } + "current_week": current.week, + }, ) embed = EmbedTemplate.error( "Not Available", - "Draft commands are only available in the offseason." + "Draft commands are only available in the offseason.", ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -161,7 +172,7 @@ def requires_draft_period(func): # Week <= 0, allow command to proceed period_check_logger.debug( f"Draft period check passed - week {current.week}", - extra={"user_id": interaction.user.id, "command": func.__name__} + extra={"user_id": interaction.user.id, "command": func.__name__}, ) return await func(self, interaction, *args, **kwargs) @@ -169,7 +180,7 @@ def requires_draft_period(func): period_check_logger.error( f"Error in draft period check: {e}", exc_info=True, - extra={"user_id": interaction.user.id, "command": func.__name__} + extra={"user_id": interaction.user.id, "command": func.__name__}, ) # Re-raise to let error handling in logged_command handle it raise @@ -182,110 +193,115 @@ def requires_draft_period(func): def cached_api_call(ttl: Optional[int] = None, cache_key_suffix: str = ""): """ Decorator to add Redis caching to service methods that return List[T]. - + This decorator will: 1. Check cache for existing data using generated key 2. Return cached data if found 3. Execute original method if cache miss 4. Cache the result for future calls - + Args: ttl: Time-to-live override in seconds (uses service default if None) cache_key_suffix: Additional suffix for cache key differentiation - + Usage: @cached_api_call(ttl=600, cache_key_suffix="by_season") async def get_teams_by_season(self, season: int) -> List[Team]: # Original method implementation - + Requirements: - Method must be async - Method must return List[T] where T is a model - Class must have self.cache (CacheManager instance) - Class must have self._generate_cache_key, self._get_cached_items, self._cache_items methods """ + def decorator(func: Callable) -> Callable: + sig = inspect.signature(func) + @wraps(func) async def wrapper(self, *args, **kwargs) -> List[Any]: # Check if caching is available (service has cache manager) - if not hasattr(self, 'cache') or not hasattr(self, '_generate_cache_key'): + if not hasattr(self, "cache") or not hasattr(self, "_generate_cache_key"): # No caching available, execute original method return await func(self, *args, **kwargs) - + # Generate cache key from method name, args, and kwargs method_name = f"{func.__name__}{cache_key_suffix}" - + # Convert args and kwargs to params list for consistent cache key - sig = inspect.signature(func) bound_args = sig.bind(self, *args, **kwargs) bound_args.apply_defaults() - + # Skip 'self' and convert to params format params = [] for param_name, param_value in bound_args.arguments.items(): - if param_name != 'self' and param_value is not None: + if param_name != "self" and param_value is not None: params.append((param_name, param_value)) - + cache_key = self._generate_cache_key(method_name, params) - + # Try to get from cache - if hasattr(self, '_get_cached_items'): + if hasattr(self, "_get_cached_items"): cached_result = await self._get_cached_items(cache_key) if cached_result is not None: cache_logger.debug(f"Cache hit: {method_name}") return cached_result - + # Cache miss - execute original method cache_logger.debug(f"Cache miss: {method_name}") result = await func(self, *args, **kwargs) - + # Cache the result if we have items and caching methods - if result and hasattr(self, '_cache_items'): + if result and hasattr(self, "_cache_items"): await self._cache_items(cache_key, result, ttl) cache_logger.debug(f"Cached {len(result)} items for {method_name}") - + return result - + return wrapper + return decorator def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""): """ Decorator to add Redis caching to service methods that return Optional[T]. - + Similar to cached_api_call but for methods returning a single model instance. - + Args: ttl: Time-to-live override in seconds cache_key_suffix: Additional suffix for cache key differentiation - + Usage: @cached_single_item(ttl=300, cache_key_suffix="by_id") async def get_player(self, player_id: int) -> Optional[Player]: # Original method implementation """ + def decorator(func: Callable) -> Callable: + sig = inspect.signature(func) + @wraps(func) async def wrapper(self, *args, **kwargs) -> Optional[Any]: # Check if caching is available - if not hasattr(self, 'cache') or not hasattr(self, '_generate_cache_key'): + if not hasattr(self, "cache") or not hasattr(self, "_generate_cache_key"): return await func(self, *args, **kwargs) - + # Generate cache key method_name = f"{func.__name__}{cache_key_suffix}" - - sig = inspect.signature(func) + bound_args = sig.bind(self, *args, **kwargs) bound_args.apply_defaults() - + params = [] for param_name, param_value in bound_args.arguments.items(): - if param_name != 'self' and param_value is not None: + if param_name != "self" and param_value is not None: params.append((param_name, param_value)) - + cache_key = self._generate_cache_key(method_name, params) - + # Try cache first try: cached_data = await self.cache.get(cache_key) @@ -293,12 +309,14 @@ def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""): cache_logger.debug(f"Cache hit: {method_name}") return self.model_class.from_api_data(cached_data) except Exception as e: - cache_logger.warning(f"Error reading single item cache for {cache_key}: {e}") - + cache_logger.warning( + f"Error reading single item cache for {cache_key}: {e}" + ) + # Cache miss - execute original method cache_logger.debug(f"Cache miss: {method_name}") result = await func(self, *args, **kwargs) - + # Cache the single result if result: try: @@ -306,43 +324,54 @@ def cached_single_item(ttl: Optional[int] = None, cache_key_suffix: str = ""): await self.cache.set(cache_key, cache_data, ttl) cache_logger.debug(f"Cached single item for {method_name}") except Exception as e: - cache_logger.warning(f"Error caching single item for {cache_key}: {e}") - + cache_logger.warning( + f"Error caching single item for {cache_key}: {e}" + ) + return result - + return wrapper + return decorator def cache_invalidate(*cache_patterns: str): """ Decorator to invalidate cache entries when data is modified. - + Args: cache_patterns: Cache key patterns to invalidate (supports prefix matching) - + Usage: @cache_invalidate("players_by_team", "teams_by_season") async def update_player(self, player_id: int, updates: dict) -> Optional[Player]: # Original method implementation """ + def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(self, *args, **kwargs): # Execute original method first result = await func(self, *args, **kwargs) - + # Invalidate specified cache patterns - if hasattr(self, 'cache'): + if hasattr(self, "cache"): for pattern in cache_patterns: try: - cleared = await self.cache.clear_prefix(f"sba:{self.endpoint}_{pattern}") + cleared = await self.cache.clear_prefix( + f"sba:{self.endpoint}_{pattern}" + ) if cleared > 0: - cache_logger.info(f"Invalidated {cleared} cache entries for pattern: {pattern}") + cache_logger.info( + f"Invalidated {cleared} cache entries for pattern: {pattern}" + ) except Exception as e: - cache_logger.warning(f"Error invalidating cache pattern {pattern}: {e}") - + cache_logger.warning( + f"Error invalidating cache pattern {pattern}: {e}" + ) + return result - + return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/utils/logging.py b/utils/logging.py index 92c0f05..6c6dfde 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -24,6 +24,8 @@ JSONValue = Union[ str, int, float, bool, None, dict[str, Any], list[Any] # nested object # arrays ] +_SERIALIZABLE_TYPES = (str, int, float, bool, type(None)) + class JSONFormatter(logging.Formatter): """Custom JSON formatter for structured file logging.""" @@ -93,11 +95,11 @@ class JSONFormatter(logging.Formatter): extra_data = {} for key, value in record.__dict__.items(): if key not in excluded_keys: - # Ensure JSON serializable - try: - json.dumps(value) + if isinstance(value, _SERIALIZABLE_TYPES) or isinstance( + value, (list, dict) + ): extra_data[key] = value - except (TypeError, ValueError): + else: extra_data[key] = str(value) if extra_data: