Merge pull request 'perf: replace sequential awaits with asyncio.gather()' (#102) from fix/sequential-awaits into next-release
All checks were successful
Build Docker Image / build (push) Successful in 1m25s
All checks were successful
Build Docker Image / build (push) Successful in 1m25s
Reviewed-on: #102
This commit is contained in:
commit
6d3c7305ce
@ -3,6 +3,7 @@ Transaction Management Commands
|
||||
|
||||
Core transaction commands for roster management and transaction tracking.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
@ -21,6 +22,7 @@ from views.base import PaginationView
|
||||
from services.transaction_service import transaction_service
|
||||
from services.roster_service import roster_service
|
||||
from services.team_service import team_service
|
||||
|
||||
# No longer need TransactionStatus enum
|
||||
|
||||
|
||||
@ -34,25 +36,28 @@ class TransactionPaginationView(PaginationView):
|
||||
all_transactions: list,
|
||||
user_id: int,
|
||||
timeout: float = 300.0,
|
||||
show_page_numbers: bool = True
|
||||
show_page_numbers: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
pages=pages,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
show_page_numbers=show_page_numbers
|
||||
show_page_numbers=show_page_numbers,
|
||||
)
|
||||
self.all_transactions = all_transactions
|
||||
|
||||
@discord.ui.button(label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1)
|
||||
async def show_move_ids(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
@discord.ui.button(
|
||||
label="Show Move IDs", style=discord.ButtonStyle.secondary, emoji="🔍", row=1
|
||||
)
|
||||
async def show_move_ids(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
"""Show all move IDs in an ephemeral message."""
|
||||
self.increment_interaction_count()
|
||||
|
||||
if not self.all_transactions:
|
||||
await interaction.response.send_message(
|
||||
"No transactions to show.",
|
||||
ephemeral=True
|
||||
"No transactions to show.", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
@ -85,8 +90,7 @@ class TransactionPaginationView(PaginationView):
|
||||
# Send the messages
|
||||
if not messages:
|
||||
await interaction.response.send_message(
|
||||
"No transactions to display.",
|
||||
ephemeral=True
|
||||
"No transactions to display.", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
@ -101,14 +105,13 @@ class TransactionPaginationView(PaginationView):
|
||||
|
||||
class TransactionCommands(commands.Cog):
|
||||
"""Transaction command handlers for roster management."""
|
||||
|
||||
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.logger = get_contextual_logger(f'{__name__}.TransactionCommands')
|
||||
|
||||
self.logger = get_contextual_logger(f"{__name__}.TransactionCommands")
|
||||
|
||||
@app_commands.command(
|
||||
name="mymoves",
|
||||
description="View your pending and scheduled transactions"
|
||||
name="mymoves", description="View your pending and scheduled transactions"
|
||||
)
|
||||
@app_commands.describe(
|
||||
show_cancelled="Include cancelled transactions in the display (default: False)"
|
||||
@ -116,39 +119,45 @@ class TransactionCommands(commands.Cog):
|
||||
@requires_team()
|
||||
@logged_command("/mymoves")
|
||||
async def my_moves(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
show_cancelled: bool = False
|
||||
self, interaction: discord.Interaction, show_cancelled: bool = False
|
||||
):
|
||||
"""Display user's transaction status and history."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
||||
# Get user's team
|
||||
team = await get_user_major_league_team(interaction.user.id, get_config().sba_season)
|
||||
|
||||
team = await get_user_major_league_team(
|
||||
interaction.user.id, get_config().sba_season
|
||||
)
|
||||
|
||||
if not team:
|
||||
await interaction.followup.send(
|
||||
"❌ You don't appear to own a team in the current season.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Get transactions in parallel
|
||||
pending_task = transaction_service.get_pending_transactions(team.abbrev, get_config().sba_season)
|
||||
frozen_task = transaction_service.get_frozen_transactions(team.abbrev, get_config().sba_season)
|
||||
processed_task = transaction_service.get_processed_transactions(team.abbrev, get_config().sba_season)
|
||||
|
||||
pending_transactions = await pending_task
|
||||
frozen_transactions = await frozen_task
|
||||
processed_transactions = await processed_task
|
||||
|
||||
(
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
) = await asyncio.gather(
|
||||
transaction_service.get_pending_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
transaction_service.get_frozen_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
transaction_service.get_processed_transactions(
|
||||
team.abbrev, get_config().sba_season
|
||||
),
|
||||
)
|
||||
|
||||
# Get cancelled if requested
|
||||
cancelled_transactions = []
|
||||
if show_cancelled:
|
||||
cancelled_transactions = await transaction_service.get_team_transactions(
|
||||
team.abbrev,
|
||||
get_config().sba_season,
|
||||
cancelled=True
|
||||
team.abbrev, get_config().sba_season, cancelled=True
|
||||
)
|
||||
|
||||
pages = self._create_my_moves_pages(
|
||||
@ -156,15 +165,15 @@ class TransactionCommands(commands.Cog):
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
cancelled_transactions
|
||||
cancelled_transactions,
|
||||
)
|
||||
|
||||
# Collect all transactions for the "Show Move IDs" button
|
||||
all_transactions = (
|
||||
pending_transactions +
|
||||
frozen_transactions +
|
||||
processed_transactions +
|
||||
cancelled_transactions
|
||||
pending_transactions
|
||||
+ frozen_transactions
|
||||
+ processed_transactions
|
||||
+ cancelled_transactions
|
||||
)
|
||||
|
||||
# If only one page and no transactions, send without any buttons
|
||||
@ -177,93 +186,90 @@ class TransactionCommands(commands.Cog):
|
||||
all_transactions=all_transactions,
|
||||
user_id=interaction.user.id,
|
||||
timeout=300.0,
|
||||
show_page_numbers=True
|
||||
show_page_numbers=True,
|
||||
)
|
||||
await interaction.followup.send(embed=view.get_current_embed(), view=view)
|
||||
|
||||
|
||||
@app_commands.command(
|
||||
name="legal",
|
||||
description="Check roster legality for current and next week"
|
||||
)
|
||||
@app_commands.describe(
|
||||
team="Team abbreviation to check (defaults to your team)"
|
||||
name="legal", description="Check roster legality for current and next week"
|
||||
)
|
||||
@app_commands.describe(team="Team abbreviation to check (defaults to your team)")
|
||||
@requires_team()
|
||||
@logged_command("/legal")
|
||||
async def legal(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
team: Optional[str] = None
|
||||
):
|
||||
async def legal(self, interaction: discord.Interaction, team: Optional[str] = None):
|
||||
"""Check roster legality and display detailed validation results."""
|
||||
await interaction.response.defer()
|
||||
|
||||
|
||||
# Get target team
|
||||
if team:
|
||||
target_team = await team_service.get_team_by_abbrev(team.upper(), get_config().sba_season)
|
||||
target_team = await team_service.get_team_by_abbrev(
|
||||
team.upper(), get_config().sba_season
|
||||
)
|
||||
if not target_team:
|
||||
await interaction.followup.send(
|
||||
f"❌ Could not find team '{team}' in season {get_config().sba_season}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Get user's team
|
||||
user_teams = await team_service.get_teams_by_owner(interaction.user.id, get_config().sba_season)
|
||||
user_teams = await team_service.get_teams_by_owner(
|
||||
interaction.user.id, get_config().sba_season
|
||||
)
|
||||
if not user_teams:
|
||||
await interaction.followup.send(
|
||||
"❌ You don't appear to own a team. Please specify a team abbreviation.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
target_team = user_teams[0]
|
||||
|
||||
|
||||
# Get rosters in parallel
|
||||
current_roster, next_roster = await asyncio.gather(
|
||||
roster_service.get_current_roster(target_team.id),
|
||||
roster_service.get_next_roster(target_team.id)
|
||||
roster_service.get_next_roster(target_team.id),
|
||||
)
|
||||
|
||||
|
||||
if not current_roster and not next_roster:
|
||||
await interaction.followup.send(
|
||||
f"❌ Could not retrieve roster data for {target_team.abbrev}.",
|
||||
ephemeral=True
|
||||
ephemeral=True,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Validate rosters in parallel
|
||||
validation_tasks = []
|
||||
if current_roster:
|
||||
validation_tasks.append(roster_service.validate_roster(current_roster))
|
||||
else:
|
||||
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
|
||||
|
||||
|
||||
if next_roster:
|
||||
validation_tasks.append(roster_service.validate_roster(next_roster))
|
||||
else:
|
||||
validation_tasks.append(asyncio.create_task(asyncio.sleep(0))) # Dummy task
|
||||
|
||||
|
||||
validation_results = await asyncio.gather(*validation_tasks)
|
||||
current_validation = validation_results[0] if current_roster else None
|
||||
next_validation = validation_results[1] if next_roster else None
|
||||
|
||||
|
||||
embed = await self._create_legal_embed(
|
||||
target_team,
|
||||
current_roster,
|
||||
next_roster,
|
||||
next_roster,
|
||||
current_validation,
|
||||
next_validation
|
||||
next_validation,
|
||||
)
|
||||
|
||||
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
|
||||
def _create_my_moves_pages(
|
||||
self,
|
||||
team,
|
||||
pending_transactions,
|
||||
frozen_transactions,
|
||||
processed_transactions,
|
||||
cancelled_transactions
|
||||
cancelled_transactions,
|
||||
) -> list[discord.Embed]:
|
||||
"""Create paginated embeds showing user's transaction status."""
|
||||
|
||||
@ -277,7 +283,9 @@ class TransactionCommands(commands.Cog):
|
||||
# Page 1: Summary + Pending Transactions
|
||||
if pending_transactions:
|
||||
total_pending = len(pending_transactions)
|
||||
total_pages = (total_pending + transactions_per_page - 1) // transactions_per_page
|
||||
total_pages = (
|
||||
total_pending + transactions_per_page - 1
|
||||
) // transactions_per_page
|
||||
|
||||
for page_num in range(total_pages):
|
||||
start_idx = page_num * transactions_per_page
|
||||
@ -287,11 +295,11 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
# Add team thumbnail if available
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
# Pending transactions for this page
|
||||
@ -300,7 +308,7 @@ class TransactionCommands(commands.Cog):
|
||||
embed.add_field(
|
||||
name=f"⏳ Pending Transactions ({total_pending} total)",
|
||||
value="\n".join(pending_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
# Add summary only on first page
|
||||
@ -314,8 +322,12 @@ class TransactionCommands(commands.Cog):
|
||||
|
||||
embed.add_field(
|
||||
name="Summary",
|
||||
value=", ".join(status_text) if status_text else "No active transactions",
|
||||
inline=True
|
||||
value=(
|
||||
", ".join(status_text)
|
||||
if status_text
|
||||
else "No active transactions"
|
||||
),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -324,16 +336,16 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
embed.add_field(
|
||||
name="⏳ Pending Transactions",
|
||||
value="No pending transactions",
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
total_frozen = len(frozen_transactions)
|
||||
@ -343,8 +355,10 @@ class TransactionCommands(commands.Cog):
|
||||
|
||||
embed.add_field(
|
||||
name="Summary",
|
||||
value=", ".join(status_text) if status_text else "No active transactions",
|
||||
inline=True
|
||||
value=(
|
||||
", ".join(status_text) if status_text else "No active transactions"
|
||||
),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -354,10 +368,10 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
frozen_lines = [format_transaction(tx) for tx in frozen_transactions]
|
||||
@ -365,7 +379,7 @@ class TransactionCommands(commands.Cog):
|
||||
embed.add_field(
|
||||
name=f"❄️ Scheduled for Processing ({len(frozen_transactions)} total)",
|
||||
value="\n".join(frozen_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -375,18 +389,20 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
processed_lines = [format_transaction(tx) for tx in processed_transactions[-20:]] # Last 20
|
||||
processed_lines = [
|
||||
format_transaction(tx) for tx in processed_transactions[-20:]
|
||||
] # Last 20
|
||||
|
||||
embed.add_field(
|
||||
name=f"✅ Recently Processed ({len(processed_transactions[-20:])} shown)",
|
||||
value="\n".join(processed_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -396,18 +412,20 @@ class TransactionCommands(commands.Cog):
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"📋 Transaction Status - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=EmbedColors.INFO
|
||||
color=EmbedColors.INFO,
|
||||
)
|
||||
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
cancelled_lines = [format_transaction(tx) for tx in cancelled_transactions[-20:]] # Last 20
|
||||
cancelled_lines = [
|
||||
format_transaction(tx) for tx in cancelled_transactions[-20:]
|
||||
] # Last 20
|
||||
|
||||
embed.add_field(
|
||||
name=f"❌ Cancelled Transactions ({len(cancelled_transactions[-20:])} shown)",
|
||||
value="\n".join(cancelled_lines),
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
pages.append(embed)
|
||||
@ -417,111 +435,106 @@ class TransactionCommands(commands.Cog):
|
||||
page.set_footer(text="Use /legal to check roster legality")
|
||||
|
||||
return pages
|
||||
|
||||
|
||||
async def _create_legal_embed(
|
||||
self,
|
||||
team,
|
||||
current_roster,
|
||||
next_roster,
|
||||
current_validation,
|
||||
next_validation
|
||||
self, team, current_roster, next_roster, current_validation, next_validation
|
||||
) -> discord.Embed:
|
||||
"""Create embed showing roster legality check results."""
|
||||
|
||||
|
||||
# Determine overall status
|
||||
overall_legal = True
|
||||
if current_validation and not current_validation.is_legal:
|
||||
overall_legal = False
|
||||
if next_validation and not next_validation.is_legal:
|
||||
overall_legal = False
|
||||
|
||||
|
||||
status_emoji = "✅" if overall_legal else "❌"
|
||||
embed_color = EmbedColors.SUCCESS if overall_legal else EmbedColors.ERROR
|
||||
|
||||
|
||||
embed = EmbedTemplate.create_base_embed(
|
||||
title=f"{status_emoji} Roster Check - {team.abbrev}",
|
||||
description=f"{team.lname} • Season {get_config().sba_season}",
|
||||
color=embed_color
|
||||
color=embed_color,
|
||||
)
|
||||
|
||||
|
||||
# Add team thumbnail if available
|
||||
if hasattr(team, 'thumbnail') and team.thumbnail:
|
||||
if hasattr(team, "thumbnail") and team.thumbnail:
|
||||
embed.set_thumbnail(url=team.thumbnail)
|
||||
|
||||
|
||||
# Current week roster
|
||||
if current_roster and current_validation:
|
||||
current_lines = []
|
||||
current_lines.append(f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL")
|
||||
current_lines.append(
|
||||
f"**Players:** {current_validation.active_players} active, {current_validation.il_players} IL"
|
||||
)
|
||||
current_lines.append(f"**sWAR:** {current_validation.total_sWAR:.2f}")
|
||||
|
||||
|
||||
if current_validation.errors:
|
||||
current_lines.append(f"**❌ Errors:** {len(current_validation.errors)}")
|
||||
for error in current_validation.errors[:3]: # Show first 3 errors
|
||||
current_lines.append(f"• {error}")
|
||||
|
||||
|
||||
if current_validation.warnings:
|
||||
current_lines.append(f"**⚠️ Warnings:** {len(current_validation.warnings)}")
|
||||
current_lines.append(
|
||||
f"**⚠️ Warnings:** {len(current_validation.warnings)}"
|
||||
)
|
||||
for warning in current_validation.warnings[:2]: # Show first 2 warnings
|
||||
current_lines.append(f"• {warning}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name=f"{current_validation.status_emoji} Current Week",
|
||||
value="\n".join(current_lines),
|
||||
inline=True
|
||||
inline=True,
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="❓ Current Week",
|
||||
value="Roster data not available",
|
||||
inline=True
|
||||
name="❓ Current Week", value="Roster data not available", inline=True
|
||||
)
|
||||
|
||||
# Next week roster
|
||||
|
||||
# Next week roster
|
||||
if next_roster and next_validation:
|
||||
next_lines = []
|
||||
next_lines.append(f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL")
|
||||
next_lines.append(
|
||||
f"**Players:** {next_validation.active_players} active, {next_validation.il_players} IL"
|
||||
)
|
||||
next_lines.append(f"**sWAR:** {next_validation.total_sWAR:.2f}")
|
||||
|
||||
|
||||
if next_validation.errors:
|
||||
next_lines.append(f"**❌ Errors:** {len(next_validation.errors)}")
|
||||
for error in next_validation.errors[:3]: # Show first 3 errors
|
||||
next_lines.append(f"• {error}")
|
||||
|
||||
|
||||
if next_validation.warnings:
|
||||
next_lines.append(f"**⚠️ Warnings:** {len(next_validation.warnings)}")
|
||||
for warning in next_validation.warnings[:2]: # Show first 2 warnings
|
||||
next_lines.append(f"• {warning}")
|
||||
|
||||
|
||||
embed.add_field(
|
||||
name=f"{next_validation.status_emoji} Next Week",
|
||||
value="\n".join(next_lines),
|
||||
inline=True
|
||||
inline=True,
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="❓ Next Week",
|
||||
value="Roster data not available",
|
||||
inline=True
|
||||
name="❓ Next Week", value="Roster data not available", inline=True
|
||||
)
|
||||
|
||||
|
||||
# Overall status
|
||||
if overall_legal:
|
||||
embed.add_field(
|
||||
name="Overall Status",
|
||||
value="✅ All rosters are legal",
|
||||
inline=False
|
||||
name="Overall Status", value="✅ All rosters are legal", inline=False
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="Overall Status",
|
||||
name="Overall Status",
|
||||
value="❌ Roster violations found - please review and correct",
|
||||
inline=False
|
||||
inline=False,
|
||||
)
|
||||
|
||||
|
||||
embed.set_footer(text="Roster validation based on current league rules")
|
||||
return embed
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot):
|
||||
"""Load the transaction commands cog."""
|
||||
await bot.add_cog(TransactionCommands(bot))
|
||||
await bot.add_cog(TransactionCommands(bot))
|
||||
|
||||
@ -4,6 +4,7 @@ Statistics service for Discord Bot v2.0
|
||||
Handles batting and pitching statistics retrieval and processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@ -144,11 +145,10 @@ class StatsService:
|
||||
"""
|
||||
try:
|
||||
# Get both types of stats concurrently
|
||||
batting_task = self.get_batting_stats(player_id, season)
|
||||
pitching_task = self.get_pitching_stats(player_id, season)
|
||||
|
||||
batting_stats = await batting_task
|
||||
pitching_stats = await pitching_task
|
||||
batting_stats, pitching_stats = await asyncio.gather(
|
||||
self.get_batting_stats(player_id, season),
|
||||
self.get_pitching_stats(player_id, season),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved stats for player {player_id}: "
|
||||
|
||||
111
tests/test_services_stats.py
Normal file
111
tests/test_services_stats.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""
|
||||
Tests for StatsService
|
||||
|
||||
Validates stats service functionality including concurrent stat retrieval
|
||||
and error handling in get_player_stats().
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from services.stats_service import StatsService
|
||||
|
||||
|
||||
class TestStatsServiceGetPlayerStats:
|
||||
"""Test StatsService.get_player_stats() concurrent retrieval."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
"""Create a fresh StatsService instance for testing."""
|
||||
return StatsService()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_batting_stats(self):
|
||||
"""Create a mock BattingStats object."""
|
||||
stats = MagicMock()
|
||||
stats.avg = 0.300
|
||||
return stats
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pitching_stats(self):
|
||||
"""Create a mock PitchingStats object."""
|
||||
stats = MagicMock()
|
||||
stats.era = 3.50
|
||||
return stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_stats_returned(
|
||||
self, service, mock_batting_stats, mock_pitching_stats
|
||||
):
|
||||
"""When both batting and pitching stats exist, both are returned.
|
||||
|
||||
Verifies that get_player_stats returns a tuple of (batting, pitching)
|
||||
when both stat types are available for the player.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=mock_batting_stats)
|
||||
service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=100, season=12)
|
||||
|
||||
assert batting is mock_batting_stats
|
||||
assert pitching is mock_pitching_stats
|
||||
service.get_batting_stats.assert_called_once_with(100, 12)
|
||||
service.get_pitching_stats.assert_called_once_with(100, 12)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batting_only(self, service, mock_batting_stats):
|
||||
"""When only batting stats exist, pitching is None.
|
||||
|
||||
Covers the case of a position player with no pitching record.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=mock_batting_stats)
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=200, season=12)
|
||||
|
||||
assert batting is mock_batting_stats
|
||||
assert pitching is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pitching_only(self, service, mock_pitching_stats):
|
||||
"""When only pitching stats exist, batting is None.
|
||||
|
||||
Covers the case of a pitcher with no batting record.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=None)
|
||||
service.get_pitching_stats = AsyncMock(return_value=mock_pitching_stats)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=300, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is mock_pitching_stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_stats_found(self, service):
|
||||
"""When no stats exist for the player, both are None.
|
||||
|
||||
Covers the case where a player has no stats for the given season
|
||||
(e.g., didn't play).
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(return_value=None)
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=400, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_returns_none_tuple(self, service):
|
||||
"""When an exception occurs, (None, None) is returned.
|
||||
|
||||
The get_player_stats method wraps both calls in a try/except and
|
||||
returns (None, None) on any error, ensuring callers always get a tuple.
|
||||
"""
|
||||
service.get_batting_stats = AsyncMock(side_effect=RuntimeError("API down"))
|
||||
service.get_pitching_stats = AsyncMock(return_value=None)
|
||||
|
||||
batting, pitching = await service.get_player_stats(player_id=500, season=12)
|
||||
|
||||
assert batting is None
|
||||
assert pitching is None
|
||||
Loading…
Reference in New Issue
Block a user