perf: replace sequential awaits with asyncio.gather() for true parallelism

Fixes #87

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-03-20 09:14:14 -05:00
parent 910a27e356
commit 9df8d77fa0
3 changed files with 259 additions and 135 deletions

View File

@ -3,6 +3,7 @@ Transaction Management Commands
Core transaction commands for roster management and transaction tracking. Core transaction commands for roster management and transaction tracking.
""" """
from typing import Optional from typing import Optional
import asyncio import asyncio
@ -21,6 +22,7 @@ from views.base import PaginationView
from services.transaction_service import transaction_service from services.transaction_service import transaction_service
from services.roster_service import roster_service from services.roster_service import roster_service
from services.team_service import team_service from services.team_service import team_service
# No longer need TransactionStatus enum # No longer need TransactionStatus enum
@ -34,25 +36,28 @@ class TransactionPaginationView(PaginationView):
all_transactions: list, all_transactions: list,
user_id: int, user_id: int,
timeout: float = 300.0, timeout: float = 300.0,
show_page_numbers: bool = True show_page_numbers: bool = True,
): ):
super().__init__( super().__init__(
pages=pages, pages=pages,
user_id=user_id, user_id=user_id,
timeout=timeout, timeout=timeout,
show_page_numbers=show_page_numbers show_page_numbers=show_page_numbers,
) )
self.all_transactions = all_transactions self.all_transactions = all_transactions
@discord.ui.button(label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1) @discord.ui.button(
async def show_move_ids(self, interaction: discord.Interaction, button: discord.ui.Button): label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1
)
async def show_move_ids(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Show all move IDs in an ephemeral message.""" """Show all move IDs in an ephemeral message."""
self.increment_interaction_count() self.increment_interaction_count()
if not self.all_transactions: if not self.all_transactions:
await interaction.response.send_message( await interaction.response.send_message(
"No transactions to show.", "No transactions to show.", ephemeral=True
ephemeral=True
) )
return return
@ -85,8 +90,7 @@ class TransactionPaginationView(PaginationView):
# Send the messages # Send the messages
if not messages: if not messages:
await interaction.response.send_message( await interaction.response.send_message(
"No transactions to display.", "No transactions to display.", ephemeral=True
ephemeral=True
) )
return return
@ -101,14 +105,13 @@ class TransactionPaginationView(PaginationView):
class TransactionCommands(commands.Cog): class TransactionCommands(commands.Cog):
"""Transaction command handlers for roster management.""" """Transaction command handlers for roster management."""
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
self.logger = get_contextual_logger(f'{__name__}.TransactionCommands') self.logger = get_contextual_logger(f"{__name__}.TransactionCommands")
@app_commands.command( @app_commands.command(
name="mymoves", name="mymoves", description="View your pending and scheduled transactions"
description="View your pending and scheduled transactions"
) )
@app_commands.describe( @app_commands.describe(
show_cancelled="Include cancelled transactions in the display (default: False)" show_cancelled="Include cancelled transactions in the display (default: False)"
@ -116,39 +119,45 @@ class TransactionCommands(commands.Cog):
@requires_team() @requires_team()
@logged_command("/mymoves") @logged_command("/mymoves")
async def my_moves( async def my_moves(
self, self, interaction: discord.Interaction, show_cancelled: bool = False
interaction: discord.Interaction,
show_cancelled: bool = False
): ):
"""Display user's transaction status and history.""" """Display user's transaction status and history."""
await interaction.response.defer() await interaction.response.defer()
# Get user's team # Get user's team
team = await get_user_major_league_team(interaction.user.id, get_config().sba_season) team = await get_user_major_league_team(
interaction.user.id, get_config().sba_season
)
if not team: if not team:
await interaction.followup.send( await interaction.followup.send(
"❌ You don't appear to own a team in the current season.", "❌ You don't appear to own a team in the current season.",
ephemeral=True ephemeral=True,
) )
return return
# Get transactions in parallel # Get transactions in parallel
pending_task = transaction_service.get_pending_transactions(team.abbrev, get_config().sba_season) (
frozen_task = transaction_service.get_frozen_transactions(team.abbrev, get_config().sba_season) pending_transactions,
processed_task = transaction_service.get_processed_transactions(team.abbrev, get_config().sba_season) frozen_transactions,
processed_transactions,
pending_transactions = await pending_task ) = await asyncio.gather(
frozen_transactions = await frozen_task transaction_service.get_pending_transactions(
processed_transactions = await processed_task team.abbrev, get_config().sba_season
),
transaction_service.get_frozen_transactions(
team.abbrev, get_config().sba_season
),
transaction_service.get_processed_transactions(
team.abbrev, get_config().sba_season
),
)
# Get cancelled if requested # Get cancelled if requested
cancelled_transactions = [] cancelled_transactions = []
if show_cancelled: if show_cancelled:
cancelled_transactions = await transaction_service.get_team_transactions( cancelled_transactions = await transaction_service.get_team_transactions(
team.abbrev, team.abbrev, get_config().sba_season, cancelled=True
get_config().sba_season,
cancelled=True
) )
pages = self._create_my_moves_pages( pages = self._create_my_moves_pages(
@ -156,15 +165,15 @@ class TransactionCommands(commands.Cog):
pending_transactions, pending_transactions,
frozen_transactions, frozen_transactions,
processed_transactions, processed_transactions,
cancelled_transactions cancelled_transactions,
) )
# Collect all transactions for the "Show Move IDs" button # Collect all transactions for the "Show Move IDs" button
all_transactions = ( all_transactions = (
pending_transactions + pending_transactions
frozen_transactions + + frozen_transactions
processed_transactions + + processed_transactions
cancelled_transactions + cancelled_transactions
) )
# If only one page and no transactions, send without any buttons # If only one page and no transactions, send without any buttons
@ -177,93 +186,90 @@ class TransactionCommands(commands.Cog):
all_transactions=all_transactions, all_transactions=all_transactions,
user_id=interaction.user.id, user_id=interaction.user.id,
timeout=300.0, timeout=300.0,
show_page_numbers=True show_page_numbers=True,
) )
await interaction.followup.send(embed=view.get_current_embed(), view=view) await interaction.followup.send(embed=view.get_current_embed(), view=view)
@app_commands.command( @app_commands.command(
name="legal", name="legal", description="Check roster legality for current and next week"
description="Check roster legality for current and next week"
)
@app_commands.describe(
team="Team abbreviation to check (defaults to your team)"
) )
@app_commands.describe(team="Team abbreviation to check (defaults to your team)")
@requires_team() @requires_team()
@logged_command("/legal") @logged_command("/legal")
async def legal( async def legal(self, interaction: discord.Interaction, team: Optional[str] = None):
self,
interaction: discord.Interaction,
team: Optional[str] = None
):
"""Check roster legality and display detailed validation results.""" """Check roster legality and display detailed validation results."""
await interaction.response.defer() await interaction.response.defer()
# Get target team # Get target team
if team: if team:
target_team = await team_service.get_team_by_abbrev(team.upper(), get_config().sba_season) target_team = await team_service.get_team_by_abbrev(
team.upper(), get_config().sba_season
)
if not target_team: if not target_team:
await interaction.followup.send( await interaction.followup.send(
f"❌ Could not find team '{team}' in season {get_config().sba_season}.", f"❌ Could not find team '{team}' in season {get_config().sba_season}.",
ephemeral=True ephemeral=True,
) )
return return
else: else:
# Get user's team # Get user's team
user_teams = await team_service.get_teams_by_owner(interaction.user.id, get_config().sba_season) user_teams = await team_service.get_teams_by_owner(
interaction.user.id, get_config().sba_season
)
if not user_teams: if not user_teams:
await interaction.followup.send( await interaction.followup.send(
"❌ You don't appear to own a team. Please specify a team abbreviation.", "❌ You don't appear to own a team. Please specify a team abbreviation.",
ephemeral=True ephemeral=True,
) )
return return
target_team = user_teams[0] target_team = user_teams[0]
# Get rosters in parallel # Get rosters in parallel
current_roster, next_roster = await asyncio.gather( current_roster, next_roster = await asyncio.gather(
roster_service.get_current_roster(target_team.id), roster_service.get_current_roster(target_team.id),
roster_service.get_next_roster(target_team.id) roster_service.get_next_roster(target_team.id),
) )
if not current_roster and not next_roster: if not current_roster and not next_roster:
await interaction.followup.send( await interaction.followup.send(
f"❌ Could not retrieve roster data for {target_team.abbrev}.", f"❌ Could not retrieve roster data for {target_team.abbrev}.",
ephemeral=True ephemeral=True,
) )
return return
# Validate rosters in parallel # Validate rosters in parallel
validation_tasks = [] validation_tasks = []
if current_roster: if current_roster:
validation_tasks.append(roster_service.validate_roster(current_roster)) validation_tasks.append(roster_service.validate_roster(current_roster))
else: else:
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
if next_roster: if next_roster:
validation_tasks.append(roster_service.validate_roster(next_roster)) validation_tasks.append(roster_service.validate_roster(next_roster))
else: else:
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
validation_results = await asyncio.gather(*validation_tasks) validation_results = await asyncio.gather(*validation_tasks)
current_validation = validation_results[0] if current_roster else None current_validation = validation_results[0] if current_roster else None
next_validation = validation_results[1] if next_roster else None next_validation = validation_results[1] if next_roster else None
embed = await self._create_legal_embed( embed = await self._create_legal_embed(
target_team, target_team,
current_roster, current_roster,
next_roster, next_roster,
current_validation, current_validation,
next_validation next_validation,
) )
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
def _create_my_moves_pages( def _create_my_moves_pages(
self, self,
team, team,
pending_transactions, pending_transactions,
frozen_transactions, frozen_transactions,
processed_transactions, processed_transactions,
cancelled_transactions cancelled_transactions,
) -> list[discord.Embed]: ) -> list[discord.Embed]:
"""Create paginated embeds showing user's transaction status.""" """Create paginated embeds showing user's transaction status."""
@ -277,7 +283,9 @@ class TransactionCommands(commands.Cog):
# Page 1: Summary + Pending Transactions # Page 1: Summary + Pending Transactions
if pending_transactions: if pending_transactions:
total_pending = len(pending_transactions) total_pending = len(pending_transactions)
total_pages = (total_pending + transactions_per_page - 1) // transactions_per_page total_pages = (
total_pending + transactions_per_page - 1
) // transactions_per_page
for page_num in range(total_pages): for page_num in range(total_pages):
start_idx = page_num * transactions_per_page start_idx = page_num * transactions_per_page
@ -287,11 +295,11 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}", title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}", description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO color=EmbedColors.INFO,
) )
# Add team thumbnail if available # Add team thumbnail if available
if hasattr(team, 'thumbnail') and team.thumbnail: if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail) embed.set_thumbnail(url=team.thumbnail)
# Pending transactions for this page # Pending transactions for this page
@ -300,7 +308,7 @@ class TransactionCommands(commands.Cog):
embed.add_field( embed.add_field(
name=f"⏳ Pending Transactions ({total_pending} total)", name=f"⏳ Pending Transactions ({total_pending} total)",
value="\n".join(pending_lines), value="\n".join(pending_lines),
inline=False inline=False,
) )
# Add summary only on first page # Add summary only on first page
@ -314,8 +322,12 @@ class TransactionCommands(commands.Cog):
embed.add_field( embed.add_field(
name="Summary", name="Summary",
value=", ".join(status_text) if status_text else "No active transactions", value=(
inline=True ", ".join(status_text)
if status_text
else "No active transactions"
),
inline=True,
) )
pages.append(embed) pages.append(embed)
@ -324,16 +336,16 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}", title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}", description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO color=EmbedColors.INFO,
) )
if hasattr(team, 'thumbnail') and team.thumbnail: if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail) embed.set_thumbnail(url=team.thumbnail)
embed.add_field( embed.add_field(
name="⏳ Pending Transactions", name="⏳ Pending Transactions",
value="No pending transactions", value="No pending transactions",
inline=False inline=False,
) )
total_frozen = len(frozen_transactions) total_frozen = len(frozen_transactions)
@ -343,8 +355,10 @@ class TransactionCommands(commands.Cog):
embed.add_field( embed.add_field(
name="Summary", name="Summary",
value=", ".join(status_text) if status_text else "No active transactions", value=(
inline=True ", ".join(status_text) if status_text else "No active transactions"
),
inline=True,
) )
pages.append(embed) pages.append(embed)
@ -354,10 +368,10 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}", title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}", description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO color=EmbedColors.INFO,
) )
if hasattr(team, 'thumbnail') and team.thumbnail: if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail) embed.set_thumbnail(url=team.thumbnail)
frozen_lines = [format_transaction(tx) for tx in frozen_transactions] frozen_lines = [format_transaction(tx) for tx in frozen_transactions]
@ -365,7 +379,7 @@ class TransactionCommands(commands.Cog):
embed.add_field( embed.add_field(
name=f"❄️ Scheduled for Processing ({len(frozen_transactions)} total)", name=f"❄️ Scheduled for Processing ({len(frozen_transactions)} total)",
value="\n".join(frozen_lines), value="\n".join(frozen_lines),
inline=False inline=False,
) )
pages.append(embed) pages.append(embed)
@ -375,18 +389,20 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}", title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}", description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO color=EmbedColors.INFO,
) )
if hasattr(team, 'thumbnail') and team.thumbnail: if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail) embed.set_thumbnail(url=team.thumbnail)
processed_lines = [format_transaction(tx) for tx in processed_transactions[-20:]] # Last 20 processed_lines = [
format_transaction(tx) for tx in processed_transactions[-20:]
] # Last 20
embed.add_field( embed.add_field(
name=f"✅ Recently Processed ({len(processed_transactions[-20:])} shown)", name=f"✅ Recently Processed ({len(processed_transactions[-20:])} shown)",
value="\n".join(processed_lines), value="\n".join(processed_lines),
inline=False inline=False,
) )
pages.append(embed) pages.append(embed)
@ -396,18 +412,20 @@ class TransactionCommands(commands.Cog):
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"📋 Transaction Status - {team.abbrev}", title=f"📋 Transaction Status - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}", description=f"{team.lname} • Season {get_config().sba_season}",
color=EmbedColors.INFO color=EmbedColors.INFO,
) )
if hasattr(team, 'thumbnail') and team.thumbnail: if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail) embed.set_thumbnail(url=team.thumbnail)
cancelled_lines = [format_transaction(tx) for tx in cancelled_transactions[-20:]] # Last 20 cancelled_lines = [
format_transaction(tx) for tx in cancelled_transactions[-20:]
] # Last 20
embed.add_field( embed.add_field(
name=f"❌ Cancelled Transactions ({len(cancelled_transactions[-20:])} shown)", name=f"❌ Cancelled Transactions ({len(cancelled_transactions[-20:])} shown)",
value="\n".join(cancelled_lines), value="\n".join(cancelled_lines),
inline=False inline=False,
) )
pages.append(embed) pages.append(embed)
@ -417,111 +435,106 @@ class TransactionCommands(commands.Cog):
page.set_footer(text="Use /legal to check roster legality") page.set_footer(text="Use /legal to check roster legality")
return pages return pages
async def _create_legal_embed( async def _create_legal_embed(
self, self, team, current_roster, next_roster, current_validation, next_validation
team,
current_roster,
next_roster,
current_validation,
next_validation
) -> discord.Embed: ) -> discord.Embed:
"""Create embed showing roster legality check results.""" """Create embed showing roster legality check results."""
# Determine overall status # Determine overall status
overall_legal = True overall_legal = True
if current_validation and not current_validation.is_legal: if current_validation and not current_validation.is_legal:
overall_legal = False overall_legal = False
if next_validation and not next_validation.is_legal: if next_validation and not next_validation.is_legal:
overall_legal = False overall_legal = False
status_emoji = "" if overall_legal else "" status_emoji = "" if overall_legal else ""
embed_color = EmbedColors.SUCCESS if overall_legal else EmbedColors.ERROR embed_color = EmbedColors.SUCCESS if overall_legal else EmbedColors.ERROR
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"{status_emoji} Roster Check - {team.abbrev}", title=f"{status_emoji} Roster Check - {team.abbrev}",
description=f"{team.lname} • Season {get_config().sba_season}", description=f"{team.lname} • Season {get_config().sba_season}",
color=embed_color color=embed_color,
) )
# Add team thumbnail if available # Add team thumbnail if available
if hasattr(team, 'thumbnail') and team.thumbnail: if hasattr(team, "thumbnail") and team.thumbnail:
embed.set_thumbnail(url=team.thumbnail) embed.set_thumbnail(url=team.thumbnail)
# Current week roster # Current week roster
if current_roster and current_validation: if current_roster and current_validation:
current_lines = [] current_lines = []
current_lines.append(f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL") current_lines.append(
f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL"
)
current_lines.append(f"**sWAR:** {current_validation.total_sWAR:.2f}") current_lines.append(f"**sWAR:** {current_validation.total_sWAR:.2f}")
if current_validation.errors: if current_validation.errors:
current_lines.append(f"**❌ Errors:** {len(current_validation.errors)}") current_lines.append(f"**❌ Errors:** {len(current_validation.errors)}")
for error in current_validation.errors[:3]: # Show first 3 errors for error in current_validation.errors[:3]: # Show first 3 errors
current_lines.append(f"{error}") current_lines.append(f"{error}")
if current_validation.warnings: if current_validation.warnings:
current_lines.append(f"**⚠️ Warnings:** {len(current_validation.warnings)}") current_lines.append(
f"**⚠️ Warnings:** {len(current_validation.warnings)}"
)
for warning in current_validation.warnings[:2]: # Show first 2 warnings for warning in current_validation.warnings[:2]: # Show first 2 warnings
current_lines.append(f"{warning}") current_lines.append(f"{warning}")
embed.add_field( embed.add_field(
name=f"{current_validation.status_emoji} Current Week", name=f"{current_validation.status_emoji} Current Week",
value="\n".join(current_lines), value="\n".join(current_lines),
inline=True inline=True,
) )
else: else:
embed.add_field( embed.add_field(
name="❓ Current Week", name="❓ Current Week", value="Roster data not available", inline=True
value="Roster data not available",
inline=True
) )
# Next week roster # Next week roster
if next_roster and next_validation: if next_roster and next_validation:
next_lines = [] next_lines = []
next_lines.append(f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL") next_lines.append(
f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL"
)
next_lines.append(f"**sWAR:** {next_validation.total_sWAR:.2f}") next_lines.append(f"**sWAR:** {next_validation.total_sWAR:.2f}")
if next_validation.errors: if next_validation.errors:
next_lines.append(f"**❌ Errors:** {len(next_validation.errors)}") next_lines.append(f"**❌ Errors:** {len(next_validation.errors)}")
for error in next_validation.errors[:3]: # Show first 3 errors for error in next_validation.errors[:3]: # Show first 3 errors
next_lines.append(f"{error}") next_lines.append(f"{error}")
if next_validation.warnings: if next_validation.warnings:
next_lines.append(f"**⚠️ Warnings:** {len(next_validation.warnings)}") next_lines.append(f"**⚠️ Warnings:** {len(next_validation.warnings)}")
for warning in next_validation.warnings[:2]: # Show first 2 warnings for warning in next_validation.warnings[:2]: # Show first 2 warnings
next_lines.append(f"{warning}") next_lines.append(f"{warning}")
embed.add_field( embed.add_field(
name=f"{next_validation.status_emoji} Next Week", name=f"{next_validation.status_emoji} Next Week",
value="\n".join(next_lines), value="\n".join(next_lines),
inline=True inline=True,
) )
else: else:
embed.add_field( embed.add_field(
name="❓ Next Week", name="❓ Next Week", value="Roster data not available", inline=True
value="Roster data not available",
inline=True
) )
# Overall status # Overall status
if overall_legal: if overall_legal:
embed.add_field( embed.add_field(
name="Overall Status", name="Overall Status", value="✅ All rosters are legal", inline=False
value="✅ All rosters are legal",
inline=False
) )
else: else:
embed.add_field( embed.add_field(
name="Overall Status", name="Overall Status",
value="❌ Roster violations found - please review and correct", value="❌ Roster violations found - please review and correct",
inline=False inline=False,
) )
embed.set_footer(text="Roster validation based on current league rules") embed.set_footer(text="Roster validation based on current league rules")
return embed return embed
async def setup(bot: commands.Bot): async def setup(bot: commands.Bot):
"""Load the transaction commands cog.""" """Load the transaction commands cog."""
await bot.add_cog(TransactionCommands(bot)) await bot.add_cog(TransactionCommands(bot))

View File

@ -4,6 +4,7 @@ Statistics service for Discord Bot v2.0
Handles batting and pitching statistics retrieval and processing. Handles batting and pitching statistics retrieval and processing.
""" """
import asyncio
import logging import logging
from typing import Optional from typing import Optional
@ -144,11 +145,10 @@ class StatsService:
""" """
try: try:
# Get both types of stats concurrently # Get both types of stats concurrently
batting_task = self.get_batting_stats(player_id, season) batting_stats, pitching_stats = await asyncio.gather(
pitching_task = self.get_pitching_stats(player_id, season) self.get_batting_stats(player_id, season),
self.get_pitching_stats(player_id, season),
batting_stats = await batting_task )
pitching_stats = await pitching_task
logger.debug( logger.debug(
f"Retrieved stats for player {player_id}: " f"Retrieved stats for player {player_id}: "

View File

@ -0,0 +1,111 @@
"""
Tests for StatsService
Validates stats service functionality including concurrent stat retrieval
and error handling in get_player_stats().
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from services.stats_service import StatsService
class TestStatsServiceGetPlayerStats:
"""Test StatsService.get_player_stats() concurrent retrieval."""
@pytest.fixture
def service(self):
"""Create a fresh StatsService instance for testing."""
return StatsService()
@pytest.fixture
def mock_batting_stats(self):
"""Create a mock BattingStats object."""
stats = MagicMock()
stats.avg = 0.300
return stats
@pytest.fixture
def mock_pitching_stats(self):
"""Create a mock PitchingStats object."""
stats = MagicMock()
stats.era = 3.50
return stats
@pytest.mark.asyncio
async def test_both_stats_returned(
self, service, mock_batting_stats, mock_pitching_stats
):
"""When both batting and pitching stats exist, both are returned.
Verifies that get_player_stats returns a tuple of (batting, pitching)
when both stat types are available for the player.
"""
service.get_batting_stats = AsyncMock(return_value=mock_batting_stats)
service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats)
batting, pitching = await service.get_player_stats(player_id=100, season=12)
assert batting is mock_batting_stats
assert pitching is mock_pitching_stats
service.get_batting_stats.assert_called_once_with(100, 12)
service.get_pitching_stats.assert_called_once_with(100, 12)
@pytest.mark.asyncio
async def test_batting_only(self, service, mock_batting_stats):
"""When only batting stats exist, pitching is None.
Covers the case of a position player with no pitching record.
"""
service.get_batting_stats = AsyncMock(return_value=mock_batting_stats)
service.get_pitching_stats = AsyncMock(return_value=None)
batting, pitching = await service.get_player_stats(player_id=200, season=12)
assert batting is mock_batting_stats
assert pitching is None
@pytest.mark.asyncio
async def test_pitching_only(self, service, mock_pitching_stats):
"""When only pitching stats exist, batting is None.
Covers the case of a pitcher with no batting record.
"""
service.get_batting_stats = AsyncMock(return_value=None)
service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats)
batting, pitching = await service.get_player_stats(player_id=300, season=12)
assert batting is None
assert pitching is mock_pitching_stats
@pytest.mark.asyncio
async def test_no_stats_found(self, service):
"""When no stats exist for the player, both are None.
Covers the case where a player has no stats for the given season
(e.g., didn't play).
"""
service.get_batting_stats = AsyncMock(return_value=None)
service.get_pitching_stats = AsyncMock(return_value=None)
batting, pitching = await service.get_player_stats(player_id=400, season=12)
assert batting is None
assert pitching is None
@pytest.mark.asyncio
async def test_exception_returns_none_tuple(self, service):
"""When an exception occurs, (None, None) is returned.
The get_player_stats method wraps both calls in a try/except and
returns (None, None) on any error, ensuring callers always get a tuple.
"""
service.get_batting_stats = AsyncMock(side_effect=RuntimeError("API down"))
service.get_pitching_stats = AsyncMock(return_value=None)
batting, pitching = await service.get_player_stats(player_id=500, season=12)
assert batting is None
assert pitching is None