From 9df8d77fa03493b0053d4f89edb5ad90d5526aa6 Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Fri, 20 Mar 2026 09:14:14 -0500 Subject: [PATCH] perf: replace sequential awaits with asyncio.gather() for true parallelism Fixes #87 Co-Authored-By: Claude Opus 4.6 (1M context) --- commands/transactions/management.py | 273 +++++++++++++++------------- services/stats_service.py | 10 +- tests/test_services_stats.py | 111 +++++++++++ 3 files changed, 259 insertions(+), 135 deletions(-) create mode 100644 tests/test_services_stats.py diff --git a/commands/transactions/management.py b/commands/transactions/management.py index b740b03..45d7e33 100644 --- a/commands/transactions/management.py +++ b/commands/transactions/management.py @@ -3,6 +3,7 @@ Transaction Management Commands Core transaction commands for roster management and transaction tracking. """ + from typing import Optional import asyncio @@ -21,6 +22,7 @@ from views.base import PaginationView from services.transaction_service import transaction_service from services.roster_service import roster_service from services.team_service import team_service + # No longer need TransactionStatus enum @@ -34,25 +36,28 @@ class TransactionPaginationView(PaginationView): all_transactions: list, user_id: int, timeout: float = 300.0, - show_page_numbers: bool = True + show_page_numbers: bool = True, ): super().__init__( pages=pages, user_id=user_id, timeout=timeout, - show_page_numbers=show_page_numbers + show_page_numbers=show_page_numbers, ) self.all_transactions = all_transactions - @discord.ui.button(label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1) - async def show_move_ids(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button( + label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1 + ) + async def show_move_ids( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Show all move IDs in an ephemeral message.""" self.increment_interaction_count() if not self.all_transactions: await interaction.response.send_message( - "No transactions to show.", - ephemeral=True + "No transactions to show.", ephemeral=True ) return @@ -85,8 +90,7 @@ class TransactionPaginationView(PaginationView): # Send the messages if not messages: await interaction.response.send_message( - "No transactions to display.", - ephemeral=True + "No transactions to display.", ephemeral=True ) return @@ -101,14 +105,13 @@ class TransactionPaginationView(PaginationView): class TransactionCommands(commands.Cog): """Transaction command handlers for roster management.""" - + def __init__(self, bot: commands.Bot): self.bot = bot - self.logger = get_contextual_logger(f'{__name__}.TransactionCommands') - + self.logger = get_contextual_logger(f"{__name__}.TransactionCommands") + @app_commands.command( - name="mymoves", - description="View your pending and scheduled transactions" + name="mymoves", description="View your pending and scheduled transactions" ) @app_commands.describe( show_cancelled="Include cancelled transactions in the display (default: False)" @@ -116,39 +119,45 @@ class TransactionCommands(commands.Cog): @requires_team() @logged_command("/mymoves") async def my_moves( - self, - interaction: discord.Interaction, - show_cancelled: bool = False + self, interaction: discord.Interaction, show_cancelled: bool = False ): """Display user's transaction status and history.""" await interaction.response.defer() - + # Get user's team - team = await get_user_major_league_team(interaction.user.id, get_config().sba_season) - + team = await get_user_major_league_team( + interaction.user.id, get_config().sba_season + ) + if not team: await interaction.followup.send( "❌ You don't appear to own a team in the current season.", - ephemeral=True + ephemeral=True, ) return - + # Get transactions in parallel - pending_task = transaction_service.get_pending_transactions(team.abbrev, get_config().sba_season) - frozen_task = transaction_service.get_frozen_transactions(team.abbrev, get_config().sba_season) - processed_task = transaction_service.get_processed_transactions(team.abbrev, get_config().sba_season) - - pending_transactions = await pending_task - frozen_transactions = await frozen_task - processed_transactions = await processed_task - + ( + pending_transactions, + frozen_transactions, + processed_transactions, + ) = await asyncio.gather( + transaction_service.get_pending_transactions( + team.abbrev, get_config().sba_season + ), + transaction_service.get_frozen_transactions( + team.abbrev, get_config().sba_season + ), + transaction_service.get_processed_transactions( + team.abbrev, get_config().sba_season + ), + ) + # Get cancelled if requested cancelled_transactions = [] if show_cancelled: cancelled_transactions = await transaction_service.get_team_transactions( - team.abbrev, - get_config().sba_season, - cancelled=True + team.abbrev, get_config().sba_season, cancelled=True ) pages = self._create_my_moves_pages( @@ -156,15 +165,15 @@ class TransactionCommands(commands.Cog): pending_transactions, frozen_transactions, processed_transactions, - cancelled_transactions + cancelled_transactions, ) # Collect all transactions for the "Show Move IDs" button all_transactions = ( - pending_transactions + - frozen_transactions + - processed_transactions + - cancelled_transactions + pending_transactions + + frozen_transactions + + processed_transactions + + cancelled_transactions ) # If only one page and no transactions, send without any buttons @@ -177,93 +186,90 @@ class TransactionCommands(commands.Cog): all_transactions=all_transactions, user_id=interaction.user.id, timeout=300.0, - show_page_numbers=True + show_page_numbers=True, ) await interaction.followup.send(embed=view.get_current_embed(), view=view) - + @app_commands.command( - name="legal", - description="Check roster legality for current and next week" - ) - @app_commands.describe( - team="Team abbreviation to check (defaults to your team)" + name="legal", description="Check roster legality for current and next week" ) + @app_commands.describe(team="Team abbreviation to check (defaults to your team)") @requires_team() @logged_command("/legal") - async def legal( - self, - interaction: discord.Interaction, - team: Optional[str] = None - ): + async def legal(self, interaction: discord.Interaction, team: Optional[str] = None): """Check roster legality and display detailed validation results.""" await interaction.response.defer() - + # Get target team if team: - target_team = await team_service.get_team_by_abbrev(team.upper(), get_config().sba_season) + target_team = await team_service.get_team_by_abbrev( + team.upper(), get_config().sba_season + ) if not target_team: await interaction.followup.send( f"❌ Could not find team '{team}' in season {get_config().sba_season}.", - ephemeral=True + ephemeral=True, ) return else: # Get user's team - user_teams = await team_service.get_teams_by_owner(interaction.user.id, get_config().sba_season) + user_teams = await team_service.get_teams_by_owner( + interaction.user.id, get_config().sba_season + ) if not user_teams: await interaction.followup.send( "❌ You don't appear to own a team. Please specify a team abbreviation.", - ephemeral=True + ephemeral=True, ) return target_team = user_teams[0] - + # Get rosters in parallel current_roster, next_roster = await asyncio.gather( roster_service.get_current_roster(target_team.id), - roster_service.get_next_roster(target_team.id) + roster_service.get_next_roster(target_team.id), ) - + if not current_roster and not next_roster: await interaction.followup.send( f"❌ Could not retrieve roster data for {target_team.abbrev}.", - ephemeral=True + ephemeral=True, ) return - + # Validate rosters in parallel validation_tasks = [] if current_roster: validation_tasks.append(roster_service.validate_roster(current_roster)) else: validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task - + if next_roster: validation_tasks.append(roster_service.validate_roster(next_roster)) else: validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task - + validation_results = await asyncio.gather(*validation_tasks) current_validation = validation_results[0] if current_roster else None next_validation = validation_results[1] if next_roster else None - + embed = await self._create_legal_embed( target_team, current_roster, - next_roster, + next_roster, current_validation, - next_validation + next_validation, ) - + await interaction.followup.send(embed=embed) - + def _create_my_moves_pages( self, team, pending_transactions, frozen_transactions, processed_transactions, - cancelled_transactions + cancelled_transactions, ) -> list[discord.Embed]: """Create paginated embeds showing user's transaction status.""" @@ -277,7 +283,9 @@ class TransactionCommands(commands.Cog): # Page 1: Summary + Pending Transactions if pending_transactions: total_pending = len(pending_transactions) - total_pages = (total_pending + transactions_per_page - 1) // transactions_per_page + total_pages = ( + total_pending + transactions_per_page - 1 + ) // transactions_per_page for page_num in range(total_pages): start_idx = page_num * transactions_per_page @@ -287,11 +295,11 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) # Add team thumbnail if available - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) # Pending transactions for this page @@ -300,7 +308,7 @@ class TransactionCommands(commands.Cog): embed.add_field( name=f"⏳ Pending Transactions ({total_pending} total)", value="\n".join(pending_lines), - inline=False + inline=False, ) # Add summary only on first page @@ -314,8 +322,12 @@ class TransactionCommands(commands.Cog): embed.add_field( name="Summary", - value=", ".join(status_text) if status_text else "No active transactions", - inline=True + value=( + ", ".join(status_text) + if status_text + else "No active transactions" + ), + inline=True, ) pages.append(embed) @@ -324,16 +336,16 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) embed.add_field( name="⏳ Pending Transactions", value="No pending transactions", - inline=False + inline=False, ) total_frozen = len(frozen_transactions) @@ -343,8 +355,10 @@ class TransactionCommands(commands.Cog): embed.add_field( name="Summary", - value=", ".join(status_text) if status_text else "No active transactions", - inline=True + value=( + ", ".join(status_text) if status_text else "No active transactions" + ), + inline=True, ) pages.append(embed) @@ -354,10 +368,10 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) frozen_lines = [format_transaction(tx) for tx in frozen_transactions] @@ -365,7 +379,7 @@ class TransactionCommands(commands.Cog): embed.add_field( name=f"❄️ Scheduled for Processing ({len(frozen_transactions)} total)", value="\n".join(frozen_lines), - inline=False + inline=False, ) pages.append(embed) @@ -375,18 +389,20 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) - processed_lines = [format_transaction(tx) for tx in processed_transactions[-20:]] # Last 20 + processed_lines = [ + format_transaction(tx) for tx in processed_transactions[-20:] + ] # Last 20 embed.add_field( name=f"✅ Recently Processed ({len(processed_transactions[-20:])} shown)", value="\n".join(processed_lines), - inline=False + inline=False, ) pages.append(embed) @@ -396,18 +412,20 @@ class TransactionCommands(commands.Cog): embed = EmbedTemplate.create_base_embed( title=f"📋 Transaction Status - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=EmbedColors.INFO + color=EmbedColors.INFO, ) - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) - cancelled_lines = [format_transaction(tx) for tx in cancelled_transactions[-20:]] # Last 20 + cancelled_lines = [ + format_transaction(tx) for tx in cancelled_transactions[-20:] + ] # Last 20 embed.add_field( name=f"❌ Cancelled Transactions ({len(cancelled_transactions[-20:])} shown)", value="\n".join(cancelled_lines), - inline=False + inline=False, ) pages.append(embed) @@ -417,111 +435,106 @@ class TransactionCommands(commands.Cog): page.set_footer(text="Use /legal to check roster legality") return pages - + async def _create_legal_embed( - self, - team, - current_roster, - next_roster, - current_validation, - next_validation + self, team, current_roster, next_roster, current_validation, next_validation ) -> discord.Embed: """Create embed showing roster legality check results.""" - + # Determine overall status overall_legal = True if current_validation and not current_validation.is_legal: overall_legal = False if next_validation and not next_validation.is_legal: overall_legal = False - + status_emoji = "✅" if overall_legal else "❌" embed_color = EmbedColors.SUCCESS if overall_legal else EmbedColors.ERROR - + embed = EmbedTemplate.create_base_embed( title=f"{status_emoji} Roster Check - {team.abbrev}", description=f"{team.lname} • Season {get_config().sba_season}", - color=embed_color + color=embed_color, ) - + # Add team thumbnail if available - if hasattr(team, 'thumbnail') and team.thumbnail: + if hasattr(team, "thumbnail") and team.thumbnail: embed.set_thumbnail(url=team.thumbnail) - + # Current week roster if current_roster and current_validation: current_lines = [] - current_lines.append(f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL") + current_lines.append( + f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL" + ) current_lines.append(f"**sWAR:** {current_validation.total_sWAR:.2f}") - + if current_validation.errors: current_lines.append(f"**❌ Errors:** {len(current_validation.errors)}") for error in current_validation.errors[:3]: # Show first 3 errors current_lines.append(f"• {error}") - + if current_validation.warnings: - current_lines.append(f"**⚠️ Warnings:** {len(current_validation.warnings)}") + current_lines.append( + f"**⚠️ Warnings:** {len(current_validation.warnings)}" + ) for warning in current_validation.warnings[:2]: # Show first 2 warnings current_lines.append(f"• {warning}") - + embed.add_field( name=f"{current_validation.status_emoji} Current Week", value="\n".join(current_lines), - inline=True + inline=True, ) else: embed.add_field( - name="❓ Current Week", - value="Roster data not available", - inline=True + name="❓ Current Week", value="Roster data not available", inline=True ) - - # Next week roster + + # Next week roster if next_roster and next_validation: next_lines = [] - next_lines.append(f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL") + next_lines.append( + f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL" + ) next_lines.append(f"**sWAR:** {next_validation.total_sWAR:.2f}") - + if next_validation.errors: next_lines.append(f"**❌ Errors:** {len(next_validation.errors)}") for error in next_validation.errors[:3]: # Show first 3 errors next_lines.append(f"• {error}") - + if next_validation.warnings: next_lines.append(f"**⚠️ Warnings:** {len(next_validation.warnings)}") for warning in next_validation.warnings[:2]: # Show first 2 warnings next_lines.append(f"• {warning}") - + embed.add_field( name=f"{next_validation.status_emoji} Next Week", value="\n".join(next_lines), - inline=True + inline=True, ) else: embed.add_field( - name="❓ Next Week", - value="Roster data not available", - inline=True + name="❓ Next Week", value="Roster data not available", inline=True ) - + # Overall status if overall_legal: embed.add_field( - name="Overall Status", - value="✅ All rosters are legal", - inline=False + name="Overall Status", value="✅ All rosters are legal", inline=False ) else: embed.add_field( - name="Overall Status", + name="Overall Status", value="❌ Roster violations found - please review and correct", - inline=False + inline=False, ) - + embed.set_footer(text="Roster validation based on current league rules") return embed async def setup(bot: commands.Bot): """Load the transaction commands cog.""" - await bot.add_cog(TransactionCommands(bot)) \ No newline at end of file + await bot.add_cog(TransactionCommands(bot)) diff --git a/services/stats_service.py b/services/stats_service.py index a3b3a06..3441288 100644 --- a/services/stats_service.py +++ b/services/stats_service.py @@ -4,6 +4,7 @@ Statistics service for Discord Bot v2.0 Handles batting and pitching statistics retrieval and processing. """ +import asyncio import logging from typing import Optional @@ -144,11 +145,10 @@ class StatsService: """ try: # Get both types of stats concurrently - batting_task = self.get_batting_stats(player_id, season) - pitching_task = self.get_pitching_stats(player_id, season) - - batting_stats = await batting_task - pitching_stats = await pitching_task + batting_stats, pitching_stats = await asyncio.gather( + self.get_batting_stats(player_id, season), + self.get_pitching_stats(player_id, season), + ) logger.debug( f"Retrieved stats for player {player_id}: " diff --git a/tests/test_services_stats.py b/tests/test_services_stats.py new file mode 100644 index 0000000..39f89b0 --- /dev/null +++ b/tests/test_services_stats.py @@ -0,0 +1,111 @@ +""" +Tests for StatsService + +Validates stats service functionality including concurrent stat retrieval +and error handling in get_player_stats(). +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from services.stats_service import StatsService + + +class TestStatsServiceGetPlayerStats: + """Test StatsService.get_player_stats() concurrent retrieval.""" + + @pytest.fixture + def service(self): + """Create a fresh StatsService instance for testing.""" + return StatsService() + + @pytest.fixture + def mock_batting_stats(self): + """Create a mock BattingStats object.""" + stats = MagicMock() + stats.avg = 0.300 + return stats + + @pytest.fixture + def mock_pitching_stats(self): + """Create a mock PitchingStats object.""" + stats = MagicMock() + stats.era = 3.50 + return stats + + @pytest.mark.asyncio + async def test_both_stats_returned( + self, service, mock_batting_stats, mock_pitching_stats + ): + """When both batting and pitching stats exist, both are returned. + + Verifies that get_player_stats returns a tuple of (batting, pitching) + when both stat types are available for the player. + """ + service.get_batting_stats = AsyncMock(return_value=mock_batting_stats) + service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats) + + batting, pitching = await service.get_player_stats(player_id=100, season=12) + + assert batting is mock_batting_stats + assert pitching is mock_pitching_stats + service.get_batting_stats.assert_called_once_with(100, 12) + service.get_pitching_stats.assert_called_once_with(100, 12) + + @pytest.mark.asyncio + async def test_batting_only(self, service, mock_batting_stats): + """When only batting stats exist, pitching is None. + + Covers the case of a position player with no pitching record. + """ + service.get_batting_stats = AsyncMock(return_value=mock_batting_stats) + service.get_pitching_stats = AsyncMock(return_value=None) + + batting, pitching = await service.get_player_stats(player_id=200, season=12) + + assert batting is mock_batting_stats + assert pitching is None + + @pytest.mark.asyncio + async def test_pitching_only(self, service, mock_pitching_stats): + """When only pitching stats exist, batting is None. + + Covers the case of a pitcher with no batting record. + """ + service.get_batting_stats = AsyncMock(return_value=None) + service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats) + + batting, pitching = await service.get_player_stats(player_id=300, season=12) + + assert batting is None + assert pitching is mock_pitching_stats + + @pytest.mark.asyncio + async def test_no_stats_found(self, service): + """When no stats exist for the player, both are None. + + Covers the case where a player has no stats for the given season + (e.g., didn't play). + """ + service.get_batting_stats = AsyncMock(return_value=None) + service.get_pitching_stats = AsyncMock(return_value=None) + + batting, pitching = await service.get_player_stats(player_id=400, season=12) + + assert batting is None + assert pitching is None + + @pytest.mark.asyncio + async def test_exception_returns_none_tuple(self, service): + """When an exception occurs, (None, None) is returned. + + The get_player_stats method wraps both calls in a try/except and + returns (None, None) on any error, ensuring callers always get a tuple. + """ + service.get_batting_stats = AsyncMock(side_effect=RuntimeError("API down")) + service.get_pitching_stats = AsyncMock(return_value=None) + + batting, pitching = await service.get_player_stats(player_id=500, season=12) + + assert batting is None + assert pitching is None