diff --git a/.dockerignore b/.dockerignore index 96f010e..fc9bc1c 100644 --- a/.dockerignore +++ b/.dockerignore @@ -34,6 +34,7 @@ README.md **/storage **/htmlcov *_legacy.py +cogs/gameplay_legacy.py pytest.ini CLAUDE.md **.db diff --git a/cogs/players_new/__init__.py b/cogs/players_new/__init__.py index 736f370..c3fdd76 100644 --- a/cogs/players_new/__init__.py +++ b/cogs/players_new/__init__.py @@ -5,11 +5,7 @@ from .shared_utils import get_ai_records, get_record_embed, get_record_embed_leg import logging from discord.ext import commands -__all__ = [ - 'get_ai_records', - 'get_record_embed', - 'get_record_embed_legacy' -] +__all__ = ["get_ai_records", "get_record_embed", "get_record_embed_legacy"] async def setup(bot): @@ -24,12 +20,14 @@ async def setup(bot): from .standings_records import StandingsRecords from .team_management import TeamManagement from .utility_commands import UtilityCommands - + from .evolution import Evolution + await bot.add_cog(Gauntlet(bot)) await bot.add_cog(Paperdex(bot)) await bot.add_cog(PlayerLookup(bot)) await bot.add_cog(StandingsRecords(bot)) await bot.add_cog(TeamManagement(bot)) await bot.add_cog(UtilityCommands(bot)) - - logging.getLogger('discord_app').info('All player cogs loaded successfully') \ No newline at end of file + await bot.add_cog(Evolution(bot)) + + logging.getLogger("discord_app").info("All player cogs loaded successfully") diff --git a/cogs/players_new/evolution.py b/cogs/players_new/evolution.py new file mode 100644 index 0000000..902fd24 --- /dev/null +++ b/cogs/players_new/evolution.py @@ -0,0 +1,206 @@ +# Evolution Status Module +# Displays evolution tier progress for a team's cards + +from discord.ext import commands +from discord import app_commands +import discord +from typing import Optional +import logging + +from api_calls import db_get +from helpers import get_team_by_owner, is_ephemeral_channel + +logger = logging.getLogger("discord_app") + +# Tier display names +TIER_NAMES = { + 0: "Unranked", + 1: "Initiate", + 2: "Rising", + 3: "Ascendant", + 4: "Evolved", +} + +# Formula shorthands by card_type +FORMULA_SHORTHANDS = { + "batter": "PA+TB×2", + "sp": "IP+K", + "rp": "IP+K", +} + + +def render_progress_bar( + current_value: float, next_threshold: float | None, width: int = 10 +) -> str: + """Render a text progress bar. + + Args: + current_value: Current formula value. + next_threshold: Threshold for the next tier. None if fully evolved. + width: Number of characters in the bar. + + Returns: + A string like '[========--] 120/149' or '[==========] FULLY EVOLVED'. + """ + if next_threshold is None or next_threshold <= 0: + return f"[{'=' * width}] FULLY EVOLVED" + + ratio = min(current_value / next_threshold, 1.0) + filled = round(ratio * width) + empty = width - filled + bar = f"[{'=' * filled}{'-' * empty}]" + return f"{bar} {int(current_value)}/{int(next_threshold)}" + + +def format_evo_entry(state: dict) -> str: + """Format a single evolution card state into a display line. + + Args: + state: Card state dict from the API with nested track info. + + Returns: + Formatted string like 'Mike Trout [========--] 120/149 (PA+TB×2) T1 → T2' + """ + track = state.get("track", {}) + card_type = track.get("card_type", "batter") + formula = FORMULA_SHORTHANDS.get(card_type, "???") + current_tier = state.get("current_tier", 0) + current_value = state.get("current_value", 0.0) + next_threshold = state.get("next_threshold") + fully_evolved = state.get("fully_evolved", False) + + bar = render_progress_bar(current_value, next_threshold) + + if fully_evolved: + tier_label = f"T4 — {TIER_NAMES[4]}" + else: + next_tier = current_tier + 1 + tier_label = ( + f"{TIER_NAMES.get(current_tier, '?')} → {TIER_NAMES.get(next_tier, '?')}" + ) + + return f"{bar} ({formula}) {tier_label}" + + +def is_close_to_tierup(state: dict, threshold_pct: float = 0.80) -> bool: + """Check if a card is close to its next tier-up. + + Args: + state: Card state dict from the API. + threshold_pct: Fraction of next_threshold that counts as "close". + + Returns: + True if current_value >= threshold_pct * next_threshold. + """ + next_threshold = state.get("next_threshold") + if next_threshold is None or next_threshold <= 0: + return False + current_value = state.get("current_value", 0.0) + return current_value >= threshold_pct * next_threshold + + +class Evolution(commands.Cog): + """Evolution tier progress for Paper Dynasty cards.""" + + def __init__(self, bot): + self.bot = bot + + evo_group = app_commands.Group(name="evo", description="Evolution commands") + + @evo_group.command(name="status", description="View your team's evolution progress") + @app_commands.describe( + type="Filter by card type (batter, sp, rp)", + tier="Filter by minimum tier (0-4)", + progress="Show only cards close to tier-up (type 'close')", + page="Page number (default: 1)", + ) + async def evo_status( + self, + interaction: discord.Interaction, + type: Optional[str] = None, + tier: Optional[int] = None, + progress: Optional[str] = None, + page: int = 1, + ): + await interaction.response.defer( + ephemeral=is_ephemeral_channel(interaction.channel) + ) + + # Look up the user's team + team = await get_team_by_owner(interaction.user.id) + if not team: + await interaction.followup.send( + "You don't have a team registered. Use `/register` first.", + ephemeral=True, + ) + return + + team_id = team.get("team_id") or team.get("id") + + # Build query params + params = [("page", page), ("per_page", 10)] + if type: + params.append(("card_type", type)) + if tier is not None: + params.append(("tier", tier)) + + try: + result = await db_get( + f"teams/{team_id}/evolutions", + params=params, + none_okay=True, + ) + except Exception: + logger.warning( + f"Failed to fetch evolution data for team {team_id}", + exc_info=True, + ) + await interaction.followup.send( + "Could not fetch evolution data. Please try again later.", + ephemeral=True, + ) + return + + if not result or not result.get("items"): + await interaction.followup.send( + "No evolution cards found for your team.", + ephemeral=True, + ) + return + + items = result["items"] + total_count = result.get("count", len(items)) + + # Apply "close" filter client-side + if progress and progress.lower() == "close": + items = [s for s in items if is_close_to_tierup(s)] + if not items: + await interaction.followup.send( + "No cards are close to a tier-up right now.", + ephemeral=True, + ) + return + + # Build embed + embed = discord.Embed( + title=f"Evolution Progress — {team.get('lname', 'Your Team')}", + color=discord.Color.purple(), + ) + + lines = [] + for state in items: + # Try to get player name from the state + player_name = state.get( + "player_name", f"Player #{state.get('player_id', '?')}" + ) + entry = format_evo_entry(state) + lines.append(f"**{player_name}**\n{entry}") + + embed.description = "\n\n".join(lines) if lines else "No evolution data." + + # Pagination footer + per_page = 10 + total_pages = max(1, (total_count + per_page - 1) // per_page) + embed.set_footer(text=f"Page {page}/{total_pages} • {total_count} total cards") + + await interaction.followup.send(embed=embed) diff --git a/command_logic/logic_gameplay.py b/command_logic/logic_gameplay.py index 55f2532..4b782e9 100644 --- a/command_logic/logic_gameplay.py +++ b/command_logic/logic_gameplay.py @@ -4242,6 +4242,24 @@ async def get_game_summary_embed( return game_embed +async def notify_tier_completion(channel: discord.TextChannel, tier_up: dict) -> None: + """Stub for WP-14: log evolution tier-up events. + + WP-14 will replace this with a full Discord embed notification. For now we + only log the event so that the WP-13 hook has a callable target and the + tier-up data is visible in the application log. + + Args: + channel: The Discord channel where the game was played. + tier_up: Dict from the evolution API, expected to contain at minimum + 'player_id', 'old_tier', and 'new_tier' keys. + """ + logger.info( + f"[WP-14 stub] notify_tier_completion called for channel={channel.id if channel else 'N/A'} " + f"tier_up={tier_up}" + ) + + async def complete_game( session: Session, interaction: discord.Interaction, @@ -4345,6 +4363,26 @@ async def complete_game( await roll_back(db_game["id"], plays=True, decisions=True) log_exception(e, msg="Error while posting game rewards") + # Post-game evolution processing (non-blocking) + # WP-13: update season stats then evaluate evolution milestones for all + # participating players. Wrapped in try/except so any failure here is + # non-fatal — the game is already saved and evolution will catch up on the + # next evaluate call. + try: + await db_post(f"season-stats/update-game/{db_game['id']}") + evo_result = await db_post(f"evolution/evaluate-game/{db_game['id']}") + if evo_result and evo_result.get("tier_ups"): + for tier_up in evo_result["tier_ups"]: + # WP-14 will implement full Discord notification; stub for now + logger.info( + f"Evolution tier-up for player {tier_up.get('player_id')}: " + f"{tier_up.get('old_tier')} -> {tier_up.get('new_tier')} " + f"(game {db_game['id']})" + ) + await notify_tier_completion(interaction.channel, tier_up) + except Exception as e: + logger.warning(f"Post-game evolution processing failed (non-fatal): {e}") + session.delete(this_play) session.commit() diff --git a/discord_ui/selectors.py b/discord_ui/selectors.py index 0278945..0bd3243 100644 --- a/discord_ui/selectors.py +++ b/discord_ui/selectors.py @@ -3,126 +3,148 @@ Discord Select UI components. Contains all Select classes for various team, cardset, and pack selections. """ + import logging import discord from typing import Literal, Optional from helpers.constants import ALL_MLB_TEAMS, IMAGES, normalize_franchise -logger = logging.getLogger('discord_app') +logger = logging.getLogger("discord_app") # Team name to ID mappings AL_TEAM_IDS = { - 'Baltimore Orioles': 3, - 'Boston Red Sox': 4, - 'Chicago White Sox': 6, - 'Cleveland Guardians': 8, - 'Detroit Tigers': 10, - 'Houston Astros': 11, - 'Kansas City Royals': 12, - 'Los Angeles Angels': 13, - 'Minnesota Twins': 17, - 'New York Yankees': 19, - 'Oakland Athletics': 20, - 'Athletics': 20, # Alias for post-Oakland move - 'Seattle Mariners': 24, - 'Tampa Bay Rays': 27, - 'Texas Rangers': 28, - 'Toronto Blue Jays': 29 + "Baltimore Orioles": 3, + "Boston Red Sox": 4, + "Chicago White Sox": 6, + "Cleveland Guardians": 8, + "Detroit Tigers": 10, + "Houston Astros": 11, + "Kansas City Royals": 12, + "Los Angeles Angels": 13, + "Minnesota Twins": 17, + "New York Yankees": 19, + "Oakland Athletics": 20, + "Athletics": 20, # Alias for post-Oakland move + "Seattle Mariners": 24, + "Tampa Bay Rays": 27, + "Texas Rangers": 28, + "Toronto Blue Jays": 29, } NL_TEAM_IDS = { - 'Arizona Diamondbacks': 1, - 'Atlanta Braves': 2, - 'Chicago Cubs': 5, - 'Cincinnati Reds': 7, - 'Colorado Rockies': 9, - 'Los Angeles Dodgers': 14, - 'Miami Marlins': 15, - 'Milwaukee Brewers': 16, - 'New York Mets': 18, - 'Philadelphia Phillies': 21, - 'Pittsburgh Pirates': 22, - 'San Diego Padres': 23, - 'San Francisco Giants': 25, - 'St Louis Cardinals': 26, # Note: constants has 'St Louis Cardinals' not 'St. Louis Cardinals' - 'Washington Nationals': 30 + "Arizona Diamondbacks": 1, + "Atlanta Braves": 2, + "Chicago Cubs": 5, + "Cincinnati Reds": 7, + "Colorado Rockies": 9, + "Los Angeles Dodgers": 14, + "Miami Marlins": 15, + "Milwaukee Brewers": 16, + "New York Mets": 18, + "Philadelphia Phillies": 21, + "Pittsburgh Pirates": 22, + "San Diego Padres": 23, + "San Francisco Giants": 25, + "St Louis Cardinals": 26, # Note: constants has 'St Louis Cardinals' not 'St. Louis Cardinals' + "Washington Nationals": 30, } # Get AL teams from constants AL_TEAMS = [team for team in ALL_MLB_TEAMS.keys() if team in AL_TEAM_IDS] -NL_TEAMS = [team for team in ALL_MLB_TEAMS.keys() if team in NL_TEAM_IDS or team == 'St Louis Cardinals'] +NL_TEAMS = [ + team + for team in ALL_MLB_TEAMS.keys() + if team in NL_TEAM_IDS or team == "St Louis Cardinals" +] # Cardset mappings CARDSET_LABELS_TO_IDS = { - '2022 Season': 3, - '2022 Promos': 4, - '2021 Season': 1, - '2019 Season': 5, - '2013 Season': 6, - '2012 Season': 7, - 'Mario Super Sluggers': 8, - '2023 Season': 9, - '2016 Season': 11, - '2008 Season': 12, - '2018 Season': 13, - '2024 Season': 17, - '2024 Promos': 18, - '1998 Season': 20, - '2025 Season': 24, - '2005 Live': 27, - 'Pokemon - Brilliant Stars': 23 + "2022 Season": 3, + "2022 Promos": 4, + "2021 Season": 1, + "2019 Season": 5, + "2013 Season": 6, + "2012 Season": 7, + "Mario Super Sluggers": 8, + "2023 Season": 9, + "2016 Season": 11, + "2008 Season": 12, + "2018 Season": 13, + "2024 Season": 17, + "2024 Promos": 18, + "1998 Season": 20, + "2025 Season": 24, + "2005 Live": 27, + "Pokemon - Brilliant Stars": 23, } -def _get_team_id(team_name: str, league: Literal['AL', 'NL']) -> int: +def _get_team_id(team_name: str, league: Literal["AL", "NL"]) -> int: """Get team ID from team name and league.""" - if league == 'AL': + if league == "AL": return AL_TEAM_IDS.get(team_name) else: # Handle the St. Louis Cardinals special case - if team_name == 'St. Louis Cardinals': - return NL_TEAM_IDS.get('St Louis Cardinals') + if team_name == "St. Louis Cardinals": + return NL_TEAM_IDS.get("St Louis Cardinals") return NL_TEAM_IDS.get(team_name) class SelectChoicePackTeam(discord.ui.Select): - def __init__(self, which: Literal['AL', 'NL'], team, cardset_id: Optional[int] = None): + def __init__( + self, which: Literal["AL", "NL"], team, cardset_id: Optional[int] = None + ): self.which = which self.owner_team = team self.cardset_id = cardset_id - - if which == 'AL': + + if which == "AL": options = [discord.SelectOption(label=team) for team in AL_TEAMS] else: # Handle St. Louis Cardinals display name - options = [discord.SelectOption(label='St. Louis Cardinals' if team == 'St Louis Cardinals' else team) - for team in NL_TEAMS] - - super().__init__(placeholder=f'Select an {which} team', options=options) + options = [ + discord.SelectOption( + label=( + "St. Louis Cardinals" if team == "St Louis Cardinals" else team + ) + ) + for team in NL_TEAMS + ] + + super().__init__(placeholder=f"Select an {which} team", options=options) async def callback(self, interaction: discord.Interaction): # Import here to avoid circular imports from api_calls import db_get, db_patch from helpers import open_choice_pack - + team_id = _get_team_id(self.values[0], self.which) if team_id is None: - raise ValueError(f'Unknown team: {self.values[0]}') + raise ValueError(f"Unknown team: {self.values[0]}") - await interaction.response.edit_message(content=f'You selected the **{self.values[0]}**', view=None) + await interaction.response.edit_message( + content=f"You selected the **{self.values[0]}**", view=None + ) # Get the selected packs params = [ - ('pack_type_id', 8), ('team_id', self.owner_team['id']), ('opened', False), ('limit', 1), - ('exact_match', True) + ("pack_type_id", 8), + ("team_id", self.owner_team["id"]), + ("opened", False), + ("limit", 1), + ("exact_match", True), ] if self.cardset_id is not None: - params.append(('pack_cardset_id', self.cardset_id)) - p_query = await db_get('packs', params=params) - if p_query['count'] == 0: - logger.error(f'open-packs - no packs found with params: {params}') - raise ValueError(f'Unable to open packs') + params.append(("pack_cardset_id", self.cardset_id)) + p_query = await db_get("packs", params=params) + if p_query["count"] == 0: + logger.error(f"open-packs - no packs found with params: {params}") + raise ValueError("Unable to open packs") - this_pack = await db_patch('packs', object_id=p_query['packs'][0]['id'], params=[('pack_team_id', team_id)]) + this_pack = await db_patch( + "packs", + object_id=p_query["packs"][0]["id"], + params=[("pack_team_id", team_id)], + ) await open_choice_pack(this_pack, self.owner_team, interaction, self.cardset_id) @@ -130,104 +152,116 @@ class SelectChoicePackTeam(discord.ui.Select): class SelectOpenPack(discord.ui.Select): def __init__(self, options: list, team: dict): self.owner_team = team - super().__init__(placeholder='Select a Pack Type', options=options) + super().__init__(placeholder="Select a Pack Type", options=options) async def callback(self, interaction: discord.Interaction): # Import here to avoid circular imports from api_calls import db_get from helpers import open_st_pr_packs, open_choice_pack - - logger.info(f'SelectPackChoice - selection: {self.values[0]}') - pack_vals = self.values[0].split('-') - logger.info(f'pack_vals: {pack_vals}') + + logger.info(f"SelectPackChoice - selection: {self.values[0]}") + pack_vals = self.values[0].split("-") + logger.info(f"pack_vals: {pack_vals}") # Get the selected packs - params = [('team_id', self.owner_team['id']), ('opened', False), ('limit', 5), ('exact_match', True)] + params = [ + ("team_id", self.owner_team["id"]), + ("opened", False), + ("limit", 20), + ("exact_match", True), + ] - open_type = 'standard' - if 'Standard' in pack_vals: - open_type = 'standard' - params.append(('pack_type_id', 1)) - elif 'Premium' in pack_vals: - open_type = 'standard' - params.append(('pack_type_id', 3)) - elif 'Daily' in pack_vals: - params.append(('pack_type_id', 4)) - elif 'Promo Choice' in pack_vals: - open_type = 'choice' - params.append(('pack_type_id', 9)) - elif 'MVP' in pack_vals: - open_type = 'choice' - params.append(('pack_type_id', 5)) - elif 'All Star' in pack_vals: - open_type = 'choice' - params.append(('pack_type_id', 6)) - elif 'Mario' in pack_vals: - open_type = 'choice' - params.append(('pack_type_id', 7)) - elif 'Team Choice' in pack_vals: - open_type = 'choice' - params.append(('pack_type_id', 8)) + open_type = "standard" + if "Standard" in pack_vals: + open_type = "standard" + params.append(("pack_type_id", 1)) + elif "Premium" in pack_vals: + open_type = "standard" + params.append(("pack_type_id", 3)) + elif "Daily" in pack_vals: + params.append(("pack_type_id", 4)) + elif "Promo Choice" in pack_vals: + open_type = "choice" + params.append(("pack_type_id", 9)) + elif "MVP" in pack_vals: + open_type = "choice" + params.append(("pack_type_id", 5)) + elif "All Star" in pack_vals: + open_type = "choice" + params.append(("pack_type_id", 6)) + elif "Mario" in pack_vals: + open_type = "choice" + params.append(("pack_type_id", 7)) + elif "Team Choice" in pack_vals: + open_type = "choice" + params.append(("pack_type_id", 8)) else: - raise KeyError(f'Cannot identify pack details: {pack_vals}') + raise KeyError(f"Cannot identify pack details: {pack_vals}") # If team isn't already set on team choice pack, make team pack selection now await interaction.response.edit_message(view=None) cardset_id = None # Handle Team Choice packs with no team/cardset assigned - if 'Team Choice' in pack_vals and 'Team' not in pack_vals and 'Cardset' not in pack_vals: + if ( + "Team Choice" in pack_vals + and "Team" not in pack_vals + and "Cardset" not in pack_vals + ): await interaction.followup.send( - content='This Team Choice pack needs to be assigned a team and cardset. ' - 'Please contact an admin to configure this pack.', - ephemeral=True + content="This Team Choice pack needs to be assigned a team and cardset. " + "Please contact an admin to configure this pack.", + ephemeral=True, ) return - elif 'Team Choice' in pack_vals and 'Cardset' in pack_vals: + elif "Team Choice" in pack_vals and "Cardset" in pack_vals: # cardset_id = pack_vals[2] - cardset_index = pack_vals.index('Cardset') + cardset_index = pack_vals.index("Cardset") cardset_id = pack_vals[cardset_index + 1] - params.append(('pack_cardset_id', cardset_id)) - if 'Team' not in pack_vals: + params.append(("pack_cardset_id", cardset_id)) + if "Team" not in pack_vals: view = SelectView( - [SelectChoicePackTeam('AL', self.owner_team, cardset_id), - SelectChoicePackTeam('NL', self.owner_team, cardset_id)], - timeout=30 + [ + SelectChoicePackTeam("AL", self.owner_team, cardset_id), + SelectChoicePackTeam("NL", self.owner_team, cardset_id), + ], + timeout=30, ) await interaction.followup.send( - content='Please select a team for your Team Choice pack:', - view=view + content="Please select a team for your Team Choice pack:", view=view ) return - - params.append(('pack_team_id', pack_vals[pack_vals.index('Team') + 1])) - else: - if 'Team' in pack_vals: - params.append(('pack_team_id', pack_vals[pack_vals.index('Team') + 1])) - if 'Cardset' in pack_vals: - cardset_id = pack_vals[pack_vals.index('Cardset') + 1] - params.append(('pack_cardset_id', cardset_id)) - p_query = await db_get('packs', params=params) - if p_query['count'] == 0: - logger.error(f'open-packs - no packs found with params: {params}') + params.append(("pack_team_id", pack_vals[pack_vals.index("Team") + 1])) + else: + if "Team" in pack_vals: + params.append(("pack_team_id", pack_vals[pack_vals.index("Team") + 1])) + if "Cardset" in pack_vals: + cardset_id = pack_vals[pack_vals.index("Cardset") + 1] + params.append(("pack_cardset_id", cardset_id)) + + p_query = await db_get("packs", params=params) + if p_query["count"] == 0: + logger.error(f"open-packs - no packs found with params: {params}") await interaction.followup.send( - content='Unable to find the selected pack. Please contact an admin.', - ephemeral=True + content="Unable to find the selected pack. Please contact an admin.", + ephemeral=True, ) return # Open the packs try: - if open_type == 'standard': - await open_st_pr_packs(p_query['packs'], self.owner_team, interaction) - elif open_type == 'choice': - await open_choice_pack(p_query['packs'][0], self.owner_team, interaction, cardset_id) + if open_type == "standard": + await open_st_pr_packs(p_query["packs"], self.owner_team, interaction) + elif open_type == "choice": + await open_choice_pack( + p_query["packs"][0], self.owner_team, interaction, cardset_id + ) except Exception as e: - logger.error(f'Failed to open pack: {e}') + logger.error(f"Failed to open pack: {e}") await interaction.followup.send( - content=f'Failed to open pack. Please contact an admin. Error: {str(e)}', - ephemeral=True + content=f"Failed to open pack. Please contact an admin. Error: {str(e)}", + ephemeral=True, ) return @@ -235,275 +269,317 @@ class SelectOpenPack(discord.ui.Select): class SelectPaperdexCardset(discord.ui.Select): def __init__(self): options = [ - discord.SelectOption(label='2005 Live'), - discord.SelectOption(label='2025 Season'), - discord.SelectOption(label='1998 Season'), - discord.SelectOption(label='2024 Season'), - discord.SelectOption(label='2023 Season'), - discord.SelectOption(label='2022 Season'), - discord.SelectOption(label='2022 Promos'), - discord.SelectOption(label='2021 Season'), - discord.SelectOption(label='2019 Season'), - discord.SelectOption(label='2018 Season'), - discord.SelectOption(label='2016 Season'), - discord.SelectOption(label='2013 Season'), - discord.SelectOption(label='2012 Season'), - discord.SelectOption(label='2008 Season'), - discord.SelectOption(label='Mario Super Sluggers') + discord.SelectOption(label="2005 Live"), + discord.SelectOption(label="2025 Season"), + discord.SelectOption(label="1998 Season"), + discord.SelectOption(label="2024 Season"), + discord.SelectOption(label="2023 Season"), + discord.SelectOption(label="2022 Season"), + discord.SelectOption(label="2022 Promos"), + discord.SelectOption(label="2021 Season"), + discord.SelectOption(label="2019 Season"), + discord.SelectOption(label="2018 Season"), + discord.SelectOption(label="2016 Season"), + discord.SelectOption(label="2013 Season"), + discord.SelectOption(label="2012 Season"), + discord.SelectOption(label="2008 Season"), + discord.SelectOption(label="Mario Super Sluggers"), ] - super().__init__(placeholder='Select a Cardset', options=options) + super().__init__(placeholder="Select a Cardset", options=options) async def callback(self, interaction: discord.Interaction): # Import here to avoid circular imports from api_calls import db_get from helpers import get_team_by_owner, paperdex_cardset_embed, embed_pagination - - logger.info(f'SelectPaperdexCardset - selection: {self.values[0]}') + + logger.info(f"SelectPaperdexCardset - selection: {self.values[0]}") cardset_id = CARDSET_LABELS_TO_IDS.get(self.values[0]) if cardset_id is None: - raise ValueError(f'Unknown cardset: {self.values[0]}') + raise ValueError(f"Unknown cardset: {self.values[0]}") - c_query = await db_get('cardsets', object_id=cardset_id, none_okay=False) - await interaction.response.edit_message(content=f'Okay, sifting through your cards...', view=None) + c_query = await db_get("cardsets", object_id=cardset_id, none_okay=False) + await interaction.response.edit_message( + content="Okay, sifting through your cards...", view=None + ) cardset_embeds = await paperdex_cardset_embed( - team=await get_team_by_owner(interaction.user.id), - this_cardset=c_query + team=await get_team_by_owner(interaction.user.id), this_cardset=c_query ) await embed_pagination(cardset_embeds, interaction.channel, interaction.user) class SelectPaperdexTeam(discord.ui.Select): - def __init__(self, which: Literal['AL', 'NL']): + def __init__(self, which: Literal["AL", "NL"]): self.which = which - - if which == 'AL': + + if which == "AL": options = [discord.SelectOption(label=team) for team in AL_TEAMS] else: # Handle St. Louis Cardinals display name - options = [discord.SelectOption(label='St. Louis Cardinals' if team == 'St Louis Cardinals' else team) - for team in NL_TEAMS] - - super().__init__(placeholder=f'Select an {which} team', options=options) + options = [ + discord.SelectOption( + label=( + "St. Louis Cardinals" if team == "St Louis Cardinals" else team + ) + ) + for team in NL_TEAMS + ] + + super().__init__(placeholder=f"Select an {which} team", options=options) async def callback(self, interaction: discord.Interaction): # Import here to avoid circular imports from api_calls import db_get from helpers import get_team_by_owner, paperdex_team_embed, embed_pagination - + team_id = _get_team_id(self.values[0], self.which) if team_id is None: - raise ValueError(f'Unknown team: {self.values[0]}') + raise ValueError(f"Unknown team: {self.values[0]}") - t_query = await db_get('teams', object_id=team_id, none_okay=False) - await interaction.response.edit_message(content=f'Okay, sifting through your cards...', view=None) + t_query = await db_get("teams", object_id=team_id, none_okay=False) + await interaction.response.edit_message( + content="Okay, sifting through your cards...", view=None + ) - team_embeds = await paperdex_team_embed(team=await get_team_by_owner(interaction.user.id), mlb_team=t_query) + team_embeds = await paperdex_team_embed( + team=await get_team_by_owner(interaction.user.id), mlb_team=t_query + ) await embed_pagination(team_embeds, interaction.channel, interaction.user) class SelectBuyPacksCardset(discord.ui.Select): - def __init__(self, team: dict, quantity: int, pack_type_id: int, pack_embed: discord.Embed, cost: int): + def __init__( + self, + team: dict, + quantity: int, + pack_type_id: int, + pack_embed: discord.Embed, + cost: int, + ): options = [ - discord.SelectOption(label='2005 Live'), - discord.SelectOption(label='2025 Season'), - discord.SelectOption(label='1998 Season'), - discord.SelectOption(label='Pokemon - Brilliant Stars'), - discord.SelectOption(label='2024 Season'), - discord.SelectOption(label='2023 Season'), - discord.SelectOption(label='2022 Season'), - discord.SelectOption(label='2021 Season'), - discord.SelectOption(label='2019 Season'), - discord.SelectOption(label='2018 Season'), - discord.SelectOption(label='2016 Season'), - discord.SelectOption(label='2013 Season'), - discord.SelectOption(label='2012 Season'), - discord.SelectOption(label='2008 Season') + discord.SelectOption(label="2005 Live"), + discord.SelectOption(label="2025 Season"), + discord.SelectOption(label="1998 Season"), + discord.SelectOption(label="Pokemon - Brilliant Stars"), + discord.SelectOption(label="2024 Season"), + discord.SelectOption(label="2023 Season"), + discord.SelectOption(label="2022 Season"), + discord.SelectOption(label="2021 Season"), + discord.SelectOption(label="2019 Season"), + discord.SelectOption(label="2018 Season"), + discord.SelectOption(label="2016 Season"), + discord.SelectOption(label="2013 Season"), + discord.SelectOption(label="2012 Season"), + discord.SelectOption(label="2008 Season"), ] self.team = team self.quantity = quantity self.pack_type_id = pack_type_id self.pack_embed = pack_embed self.cost = cost - super().__init__(placeholder='Select a Cardset', options=options) + super().__init__(placeholder="Select a Cardset", options=options) async def callback(self, interaction: discord.Interaction): # Import here to avoid circular imports from api_calls import db_post from discord_ui.confirmations import Confirm - - logger.info(f'SelectBuyPacksCardset - selection: {self.values[0]}') + + logger.info(f"SelectBuyPacksCardset - selection: {self.values[0]}") cardset_id = CARDSET_LABELS_TO_IDS.get(self.values[0]) if cardset_id is None: - raise ValueError(f'Unknown cardset: {self.values[0]}') - - if self.values[0] == 'Pokemon - Brilliant Stars': - self.pack_embed.set_image(url=IMAGES['pack-pkmnbs']) + raise ValueError(f"Unknown cardset: {self.values[0]}") - self.pack_embed.description = f'{self.pack_embed.description} - {self.values[0]}' + if self.values[0] == "Pokemon - Brilliant Stars": + self.pack_embed.set_image(url=IMAGES["pack-pkmnbs"]) + + self.pack_embed.description = ( + f"{self.pack_embed.description} - {self.values[0]}" + ) view = Confirm(responders=[interaction.user], timeout=30) await interaction.response.edit_message( - content=None, - embed=self.pack_embed, - view=None + content=None, embed=self.pack_embed, view=None ) question = await interaction.channel.send( - content=f'Your Wallet: {self.team["wallet"]}₼\n' - f'Pack{"s" if self.quantity > 1 else ""} Price: {self.cost}₼\n' - f'After Purchase: {self.team["wallet"] - self.cost}₼\n\n' - f'Would you like to make this purchase?', - view=view + content=f"Your Wallet: {self.team['wallet']}₼\n" + f"Pack{'s' if self.quantity > 1 else ''} Price: {self.cost}₼\n" + f"After Purchase: {self.team['wallet'] - self.cost}₼\n\n" + f"Would you like to make this purchase?", + view=view, ) await view.wait() if not view.value: - await question.edit( - content='Saving that money. Smart.', - view=None - ) + await question.edit(content="Saving that money. Smart.", view=None) return p_model = { - 'team_id': self.team['id'], - 'pack_type_id': self.pack_type_id, - 'pack_cardset_id': cardset_id + "team_id": self.team["id"], + "pack_type_id": self.pack_type_id, + "pack_cardset_id": cardset_id, } - await db_post('packs', payload={'packs': [p_model for x in range(self.quantity)]}) - await db_post(f'teams/{self.team["id"]}/money/-{self.cost}') + await db_post( + "packs", payload={"packs": [p_model for x in range(self.quantity)]} + ) + await db_post(f"teams/{self.team['id']}/money/-{self.cost}") await question.edit( - content=f'{"They are" if self.quantity > 1 else "It is"} all yours! Go rip \'em with `/open-packs`', - view=None + content=f"{'They are' if self.quantity > 1 else 'It is'} all yours! Go rip 'em with `/open-packs`", + view=None, ) class SelectBuyPacksTeam(discord.ui.Select): def __init__( - self, which: Literal['AL', 'NL'], team: dict, quantity: int, pack_type_id: int, pack_embed: discord.Embed, - cost: int): + self, + which: Literal["AL", "NL"], + team: dict, + quantity: int, + pack_type_id: int, + pack_embed: discord.Embed, + cost: int, + ): self.which = which self.team = team self.quantity = quantity self.pack_type_id = pack_type_id self.pack_embed = pack_embed self.cost = cost - - if which == 'AL': + + if which == "AL": options = [discord.SelectOption(label=team) for team in AL_TEAMS] else: # Handle St. Louis Cardinals display name - options = [discord.SelectOption(label='St. Louis Cardinals' if team == 'St Louis Cardinals' else team) - for team in NL_TEAMS] - - super().__init__(placeholder=f'Select an {which} team', options=options) + options = [ + discord.SelectOption( + label=( + "St. Louis Cardinals" if team == "St Louis Cardinals" else team + ) + ) + for team in NL_TEAMS + ] + + super().__init__(placeholder=f"Select an {which} team", options=options) async def callback(self, interaction: discord.Interaction): # Import here to avoid circular imports from api_calls import db_post from discord_ui.confirmations import Confirm - + team_id = _get_team_id(self.values[0], self.which) if team_id is None: - raise ValueError(f'Unknown team: {self.values[0]}') + raise ValueError(f"Unknown team: {self.values[0]}") - self.pack_embed.description = f'{self.pack_embed.description} - {self.values[0]}' + self.pack_embed.description = ( + f"{self.pack_embed.description} - {self.values[0]}" + ) view = Confirm(responders=[interaction.user], timeout=30) await interaction.response.edit_message( - content=None, - embed=self.pack_embed, - view=None + content=None, embed=self.pack_embed, view=None ) question = await interaction.channel.send( - content=f'Your Wallet: {self.team["wallet"]}₼\n' - f'Pack{"s" if self.quantity > 1 else ""} Price: {self.cost}₼\n' - f'After Purchase: {self.team["wallet"] - self.cost}₼\n\n' - f'Would you like to make this purchase?', - view=view + content=f"Your Wallet: {self.team['wallet']}₼\n" + f"Pack{'s' if self.quantity > 1 else ''} Price: {self.cost}₼\n" + f"After Purchase: {self.team['wallet'] - self.cost}₼\n\n" + f"Would you like to make this purchase?", + view=view, ) await view.wait() if not view.value: - await question.edit( - content='Saving that money. Smart.', - view=None - ) + await question.edit(content="Saving that money. Smart.", view=None) return p_model = { - 'team_id': self.team['id'], - 'pack_type_id': self.pack_type_id, - 'pack_team_id': team_id + "team_id": self.team["id"], + "pack_type_id": self.pack_type_id, + "pack_team_id": team_id, } - await db_post('packs', payload={'packs': [p_model for x in range(self.quantity)]}) - await db_post(f'teams/{self.team["id"]}/money/-{self.cost}') + await db_post( + "packs", payload={"packs": [p_model for x in range(self.quantity)]} + ) + await db_post(f"teams/{self.team['id']}/money/-{self.cost}") await question.edit( - content=f'{"They are" if self.quantity > 1 else "It is"} all yours! Go rip \'em with `/open-packs`', - view=None + content=f"{'They are' if self.quantity > 1 else 'It is'} all yours! Go rip 'em with `/open-packs`", + view=None, ) class SelectUpdatePlayerTeam(discord.ui.Select): - def __init__(self, which: Literal['AL', 'NL'], player: dict, reporting_team: dict, bot): + def __init__( + self, which: Literal["AL", "NL"], player: dict, reporting_team: dict, bot + ): self.bot = bot self.which = which self.player = player self.reporting_team = reporting_team - - if which == 'AL': + + if which == "AL": options = [discord.SelectOption(label=team) for team in AL_TEAMS] else: # Handle St. Louis Cardinals display name - options = [discord.SelectOption(label='St. Louis Cardinals' if team == 'St Louis Cardinals' else team) - for team in NL_TEAMS] - - super().__init__(placeholder=f'Select an {which} team', options=options) + options = [ + discord.SelectOption( + label=( + "St. Louis Cardinals" if team == "St Louis Cardinals" else team + ) + ) + for team in NL_TEAMS + ] + + super().__init__(placeholder=f"Select an {which} team", options=options) async def callback(self, interaction: discord.Interaction): # Import here to avoid circular imports from api_calls import db_patch, db_post from discord_ui.confirmations import Confirm from helpers import player_desc, send_to_channel - + # Check if already assigned - compare against both normalized franchise and full mlbclub normalized_selection = normalize_franchise(self.values[0]) - if normalized_selection == self.player['franchise'] or self.values[0] == self.player['mlbclub']: + if ( + normalized_selection == self.player["franchise"] + or self.values[0] == self.player["mlbclub"] + ): await interaction.response.send_message( - content=f'Thank you for the help, but it looks like somebody beat you to it! ' - f'**{player_desc(self.player)}** is already assigned to the **{self.player["mlbclub"]}**.' + content=f"Thank you for the help, but it looks like somebody beat you to it! " + f"**{player_desc(self.player)}** is already assigned to the **{self.player['mlbclub']}**." ) return view = Confirm(responders=[interaction.user], timeout=15) await interaction.response.edit_message( - content=f'Should I update **{player_desc(self.player)}**\'s team to the **{self.values[0]}**?', - view=None - ) - question = await interaction.channel.send( - content=None, - view=view + content=f"Should I update **{player_desc(self.player)}**'s team to the **{self.values[0]}**?", + view=None, ) + question = await interaction.channel.send(content=None, view=view) await view.wait() if not view.value: await question.edit( - content='That didnt\'t sound right to me, either. Let\'s not touch that.', - view=None + content="That didnt't sound right to me, either. Let's not touch that.", + view=None, ) return else: await question.delete() - await db_patch('players', object_id=self.player['player_id'], params=[ - ('mlbclub', self.values[0]), ('franchise', normalize_franchise(self.values[0])) - ]) - await db_post(f'teams/{self.reporting_team["id"]}/money/25') - await send_to_channel( - self.bot, 'pd-news-ticker', - content=f'{interaction.user.name} just updated **{player_desc(self.player)}**\'s team to the ' - f'**{self.values[0]}**' + await db_patch( + "players", + object_id=self.player["player_id"], + params=[ + ("mlbclub", self.values[0]), + ("franchise", normalize_franchise(self.values[0])), + ], ) - await interaction.channel.send(f'All done!') + await db_post(f"teams/{self.reporting_team['id']}/money/25") + await send_to_channel( + self.bot, + "pd-news-ticker", + content=f"{interaction.user.name} just updated **{player_desc(self.player)}**'s team to the " + f"**{self.values[0]}**", + ) + await interaction.channel.send("All done!") class SelectView(discord.ui.View): @@ -511,4 +587,4 @@ class SelectView(discord.ui.View): super().__init__(timeout=timeout) for x in select_objects: - self.add_item(x) \ No newline at end of file + self.add_item(x) diff --git a/helpers/evolution_notifs.py b/helpers/evolution_notifs.py new file mode 100644 index 0000000..a86c5b9 --- /dev/null +++ b/helpers/evolution_notifs.py @@ -0,0 +1,106 @@ +""" +Evolution Tier Completion Notifications + +Builds and sends Discord embeds when a player completes an evolution tier +during post-game evaluation. Each tier-up event gets its own embed. + +Notification failures are non-fatal: the send is wrapped in try/except so +a Discord API hiccup never disrupts game flow. +""" + +import logging + +import discord + +logger = logging.getLogger("discord_app") + +# Human-readable display names for each tier number. +TIER_NAMES = { + 0: "Unranked", + 1: "Initiate", + 2: "Rising", + 3: "Ascendant", + 4: "Evolved", +} + +# Tier-specific embed colors. +TIER_COLORS = { + 1: 0x2ECC71, # green + 2: 0xF1C40F, # gold + 3: 0x9B59B6, # purple + 4: 0x1ABC9C, # teal (fully evolved) +} + +FOOTER_TEXT = "Paper Dynasty Evolution" + + +def build_tier_up_embed(tier_up: dict) -> discord.Embed: + """Build a Discord embed for a tier-up event. + + Parameters + ---------- + tier_up: + Dict with keys: player_name, old_tier, new_tier, current_value, track_name. + + Returns + ------- + discord.Embed + A fully configured embed ready to send to a channel. + """ + player_name: str = tier_up["player_name"] + new_tier: int = tier_up["new_tier"] + track_name: str = tier_up["track_name"] + + tier_name = TIER_NAMES.get(new_tier, f"Tier {new_tier}") + color = TIER_COLORS.get(new_tier, 0x2ECC71) + + if new_tier >= 4: + # Fully evolved — special title and description. + embed = discord.Embed( + title="FULLY EVOLVED!", + description=( + f"**{player_name}** has reached maximum evolution on the **{track_name}** track" + ), + color=color, + ) + embed.add_field( + name="Rating Boosts", + value="Rating boosts coming in a future update!", + inline=False, + ) + else: + embed = discord.Embed( + title="Evolution Tier Up!", + description=( + f"**{player_name}** reached **Tier {new_tier} ({tier_name})** on the **{track_name}** track" + ), + color=color, + ) + + embed.set_footer(text=FOOTER_TEXT) + return embed + + +async def notify_tier_completion(channel, tier_up: dict) -> None: + """Send a tier-up notification embed to the given channel. + + Non-fatal: any exception during send is caught and logged so that a + Discord API failure never interrupts game evaluation. + + Parameters + ---------- + channel: + A discord.TextChannel (or any object with an async ``send`` method). + tier_up: + Dict with keys: player_name, old_tier, new_tier, current_value, track_name. + """ + try: + embed = build_tier_up_embed(tier_up) + await channel.send(embed=embed) + except Exception as exc: + logger.error( + "Failed to send tier-up notification for %s (tier %s): %s", + tier_up.get("player_name", "unknown"), + tier_up.get("new_tier"), + exc, + ) diff --git a/helpers/main.py b/helpers/main.py index 0b989a5..41347f8 100644 --- a/helpers/main.py +++ b/helpers/main.py @@ -2,35 +2,23 @@ import asyncio import datetime import logging import math -import os import random -import traceback import discord -import pygsheets import aiohttp from discord.ext import commands from api_calls import * from bs4 import BeautifulSoup -from difflib import get_close_matches -from dataclasses import dataclass -from typing import Optional, Literal, Union, List +from typing import Optional, Union, List -from exceptions import log_exception from in_game.gameplay_models import Team from constants import * from discord_ui import * from random_content import * from utils import ( - position_name_to_abbrev, - user_has_role, - get_roster_sheet_legacy, get_roster_sheet, - get_player_url, - owner_only, get_cal_user, - get_context_user, ) from search_utils import * from discord_utils import * @@ -122,8 +110,24 @@ async def share_channel(channel, user, read_only=False): async def get_card_embeds(card, include_stats=False) -> list: + # WP-12: fetch evolution state and build tier badge prefix. + # Non-blocking — any failure falls back to no badge so card display is + # never broken by an unavailable or slow evolution API. + tier_badge = "" + try: + evo_state = await db_get(f"evolution/cards/{card['id']}", none_okay=True) + if evo_state and evo_state.get("current_tier", 0) > 0: + tier = evo_state["current_tier"] + tier_badge = f"[{'EVO' if tier >= 4 else f'T{tier}'}] " + except Exception: + logging.warning( + f"Could not fetch evolution state for card {card.get('id')}; " + "displaying without tier badge.", + exc_info=True, + ) + embed = discord.Embed( - title=f"{card['player']['p_name']}", + title=f"{tier_badge}{card['player']['p_name']}", color=int(card["player"]["rarity"]["color"], 16), ) # embed.description = card['team']['lname'] @@ -166,7 +170,7 @@ async def get_card_embeds(card, include_stats=False) -> list: ] if any(bool_list): if count == 1: - coll_string = f"Only you" + coll_string = "Only you" else: coll_string = ( f"You and {count - 1} other{'s' if count - 1 != 1 else ''}" @@ -174,7 +178,7 @@ async def get_card_embeds(card, include_stats=False) -> list: elif count: coll_string = f"{count} other team{'s' if count != 1 else ''}" else: - coll_string = f"0 teams" + coll_string = "0 teams" embed.add_field(name="Collected By", value=coll_string) else: embed.add_field( @@ -213,7 +217,7 @@ async def get_card_embeds(card, include_stats=False) -> list: ) if evo_mon is not None: embed.add_field(name="Evolves Into", value=f"{evo_mon['p_name']}") - except Exception as e: + except Exception: logging.error( "could not pull evolution: {e}", exc_info=True, stack_info=True ) @@ -224,7 +228,7 @@ async def get_card_embeds(card, include_stats=False) -> list: ) if evo_mon is not None: embed.add_field(name="Evolves From", value=f"{evo_mon['p_name']}") - except Exception as e: + except Exception: logging.error( "could not pull evolution: {e}", exc_info=True, stack_info=True ) @@ -326,7 +330,7 @@ async def display_cards( ) try: cards.sort(key=lambda x: x["player"]["rarity"]["value"]) - logger.debug(f"Cards sorted successfully") + logger.debug("Cards sorted successfully") card_embeds = [await get_card_embeds(x) for x in cards] logger.debug(f"Created {len(card_embeds)} card embeds") @@ -347,15 +351,15 @@ async def display_cards( r_emoji = "→" view.left_button.disabled = True view.left_button.label = f"{l_emoji}Prev: -/{len(card_embeds)}" - view.cancel_button.label = f"Close Pack" + view.cancel_button.label = "Close Pack" view.right_button.label = f"Next: {page_num + 2}/{len(card_embeds)}{r_emoji}" if len(cards) == 1: view.right_button.disabled = True - logger.debug(f"Pagination view created successfully") + logger.debug("Pagination view created successfully") if pack_cover: - logger.debug(f"Sending pack cover message") + logger.debug("Sending pack cover message") msg = await channel.send( content=None, embed=image_embed(pack_cover, title=f"{team['lname']}", desc=pack_name), @@ -367,7 +371,7 @@ async def display_cards( content=None, embeds=card_embeds[page_num], view=view ) - logger.debug(f"Initial message sent successfully") + logger.debug("Initial message sent successfully") except Exception as e: logger.error( f"Error creating view or sending initial message: {e}", exc_info=True @@ -384,12 +388,12 @@ async def display_cards( f"{user.mention} you've got {len(cards)} cards here" ) - logger.debug(f"Follow-up message sent successfully") + logger.debug("Follow-up message sent successfully") except Exception as e: logger.error(f"Error sending follow-up message: {e}", exc_info=True) return False - logger.debug(f"Starting main interaction loop") + logger.debug("Starting main interaction loop") while True: try: logger.debug(f"Waiting for user interaction on page {page_num}") @@ -455,7 +459,7 @@ async def display_cards( ), view=view, ) - logger.debug(f"MVP display updated successfully") + logger.debug("MVP display updated successfully") except Exception as e: logger.error( f"Error processing shiny card on page {page_num}: {e}", exc_info=True @@ -463,19 +467,19 @@ async def display_cards( # Continue with regular flow instead of crashing try: tmp_msg = await channel.send( - content=f"<@&1163537676885033010> we've got an MVP!" + content="<@&1163537676885033010> we've got an MVP!" ) await follow_up.edit( - content=f"<@&1163537676885033010> we've got an MVP!" + content="<@&1163537676885033010> we've got an MVP!" ) await tmp_msg.delete() except discord.errors.NotFound: # Role might not exist or message was already deleted - await follow_up.edit(content=f"We've got an MVP!") + await follow_up.edit(content="We've got an MVP!") except Exception as e: # Log error but don't crash the function logger.error(f"Error handling MVP notification: {e}") - await follow_up.edit(content=f"We've got an MVP!") + await follow_up.edit(content="We've got an MVP!") await view.wait() view = Pagination([user], timeout=10) @@ -483,7 +487,7 @@ async def display_cards( view.right_button.label = ( f"Next: {page_num + 2}/{len(card_embeds)}{r_emoji}" ) - view.cancel_button.label = f"Close Pack" + view.cancel_button.label = "Close Pack" view.left_button.label = f"{l_emoji}Prev: {page_num}/{len(card_embeds)}" if page_num == 0: view.left_button.label = f"{l_emoji}Prev: -/{len(card_embeds)}" @@ -531,7 +535,7 @@ async def embed_pagination( l_emoji = "" r_emoji = "" view.right_button.label = f"Next: {page_num + 2}/{len(all_embeds)}{r_emoji}" - view.cancel_button.label = f"Cancel" + view.cancel_button.label = "Cancel" view.left_button.label = f"{l_emoji}Prev: {page_num}/{len(all_embeds)}" if page_num == 0: view.left_button.label = f"{l_emoji}Prev: -/{len(all_embeds)}" @@ -566,7 +570,7 @@ async def embed_pagination( view = Pagination([user], timeout=timeout) view.right_button.label = f"Next: {page_num + 2}/{len(all_embeds)}{r_emoji}" - view.cancel_button.label = f"Cancel" + view.cancel_button.label = "Cancel" view.left_button.label = f"{l_emoji}Prev: {page_num}/{len(all_embeds)}" if page_num == 0: view.left_button.label = f"{l_emoji}Prev: -/{len(all_embeds)}" @@ -645,21 +649,15 @@ async def get_test_pack(ctx, team): async def roll_for_cards(all_packs: list, extra_val=None) -> list: + """Open packs by rolling dice, fetching random players, and creating cards. + + Parallelizes DB calls: one fetch per rarity tier across all packs, + then gathers all card creates and pack patches concurrently. """ - Pack odds are calculated based on the pack type - - Parameters - ---------- - extra_val - all_packs - - Returns - ------- - - """ - all_players = [] team = all_packs[0]["team"] - pack_ids = [] + + # --- Phase A: Roll dice for every pack (CPU-only, no I/O) --- + pack_counts = [] for pack in all_packs: counts = { "Rep": {"count": 0, "rarity": 0}, @@ -669,10 +667,9 @@ async def roll_for_cards(all_packs: list, extra_val=None) -> list: "MVP": {"count": 0, "rarity": 5}, "HoF": {"count": 0, "rarity": 8}, } - this_pack_players = [] if pack["pack_type"]["name"] == "Standard": # Cards 1 - 2 - for x in range(2): + for _ in range(2): d_1000 = random.randint(1, 1000) if d_1000 <= 450: counts["Rep"]["count"] += 1 @@ -792,7 +789,6 @@ async def roll_for_cards(all_packs: list, extra_val=None) -> list: logger.info( f"Building Check-In Pack // extra_val (type): {extra_val} {type(extra_val)}" ) - # Single Card mod = 0 if isinstance(extra_val, int): mod = extra_val @@ -810,106 +806,195 @@ async def roll_for_cards(all_packs: list, extra_val=None) -> list: else: raise TypeError(f"Pack type not recognized: {pack['pack_type']['name']}") - pull_notifs = [] - for key in counts: - mvp_flag = None + pack_counts.append(counts) - if counts[key]["count"] > 0: - params = [ - ("min_rarity", counts[key]["rarity"]), - ("max_rarity", counts[key]["rarity"]), - ("limit", counts[key]["count"]), - ] - if all_packs[0]["pack_team"] is not None: - params.extend( - [ - ("franchise", all_packs[0]["pack_team"]["sname"]), - ("in_packs", True), - ] - ) - elif all_packs[0]["pack_cardset"] is not None: - params.append(("cardset_id", all_packs[0]["pack_cardset"]["id"])) - else: - params.append(("in_packs", True)) + # --- Phase B: Fetch players — one call per rarity tier, all gathered --- + # Sum counts across all packs per rarity tier + rarity_keys = ["Rep", "Res", "Sta", "All", "MVP", "HoF"] + summed = {key: 0 for key in rarity_keys} + for counts in pack_counts: + for key in rarity_keys: + summed[key] += counts[key]["count"] - pl = await db_get("players/random", params=params) - - if pl["count"] != counts[key]["count"]: - mvp_flag = counts[key]["count"] - pl["count"] - logging.info( - f"Set mvp flag to {mvp_flag} / cardset_id: {all_packs[0]['pack_cardset']['id']}" - ) - - for x in pl["players"]: - this_pack_players.append(x) - all_players.append(x) - - if x["rarity"]["value"] >= 3: - pull_notifs.append(x) - - if mvp_flag and all_packs[0]["pack_cardset"]["id"] not in [23]: - logging.info(f"Adding {mvp_flag} MVPs for missing cards") - pl = await db_get( - "players/random", params=[("min_rarity", 5), ("limit", mvp_flag)] - ) - - for x in pl["players"]: - this_pack_players.append(x) - all_players.append(x) - - # Add dupes of Replacement/Reserve cards - elif mvp_flag: - logging.info(f"Adding {mvp_flag} duplicate pokemon cards") - for count in range(mvp_flag): - logging.info(f"Adding {pl['players'][0]['p_name']} to the pack") - this_pack_players.append(x) - all_players.append(pl["players"][0]) - - success = await db_post( - "cards", - payload={ - "cards": [ - { - "player_id": x["player_id"], - "team_id": pack["team"]["id"], - "pack_id": pack["id"], - } - for x in this_pack_players - ] - }, - timeout=10, + # Build shared filter params + base_params = [] + if all_packs[0]["pack_team"] is not None: + base_params.extend( + [ + ("franchise", all_packs[0]["pack_team"]["sname"]), + ("in_packs", True), + ] ) - if not success: - raise ConnectionError(f"Failed to create this pack of cards.") + elif all_packs[0]["pack_cardset"] is not None: + base_params.append(("cardset_id", all_packs[0]["pack_cardset"]["id"])) + else: + base_params.append(("in_packs", True)) - await db_patch( - "packs", - object_id=pack["id"], - params=[ - ( - "open_time", - int(datetime.datetime.timestamp(datetime.datetime.now()) * 1000), - ) - ], - ) - pack_ids.append(pack["id"]) + # Fire one request per non-zero rarity tier concurrently + rarity_values = { + "Rep": 0, + "Res": 1, + "Sta": 2, + "All": 3, + "MVP": 5, + "HoF": 8, + } + fetch_keys = [key for key in rarity_keys if summed[key] > 0] + fetch_coros = [] + for key in fetch_keys: + params = [ + ("min_rarity", rarity_values[key]), + ("max_rarity", rarity_values[key]), + ("limit", summed[key]), + ] + base_params + fetch_coros.append(db_get("players/random", params=params)) - for pull in pull_notifs: - logger.info(f"good pull: {pull}") - await db_post( - "notifs", - payload={ - "created": int( - datetime.datetime.timestamp(datetime.datetime.now()) * 1000 - ), - "title": "Rare Pull", - "field_name": f"{player_desc(pull)} ({pull['rarity']['name']})", - "message": f"Pulled by {team['abbrev']}", - "about": f"Player-{pull['player_id']}", - }, + fetch_results = await asyncio.gather(*fetch_coros) + + # Map results back: rarity key -> list of players + fetched_players = {} + for key, result in zip(fetch_keys, fetch_results): + fetched_players[key] = result.get("players", []) + + # Handle shortfalls — collect total MVP backfill needed + total_mvp_shortfall = 0 + # Track per-tier shortfall for dupe-branch (cardset 23 exclusion) + tier_shortfalls = {} + for key in fetch_keys: + returned = len(fetched_players[key]) + requested = summed[key] + if returned < requested: + shortfall = requested - returned + tier_shortfalls[key] = shortfall + total_mvp_shortfall += shortfall + logging.info( + f"Shortfall in {key}: requested {requested}, got {returned} " + f"(cardset_id: {all_packs[0]['pack_cardset']['id'] if all_packs[0]['pack_cardset'] else 'N/A'})" ) - return pack_ids + # Fetch MVP backfill or duplicate existing players + backfill_players = [] + is_dupe_cardset = all_packs[0]["pack_cardset"] is not None and all_packs[0][ + "pack_cardset" + ]["id"] in [23] + if total_mvp_shortfall > 0 and not is_dupe_cardset: + logging.info(f"Adding {total_mvp_shortfall} MVPs for missing cards") + mvp_result = await db_get( + "players/random", + params=[("min_rarity", 5), ("limit", total_mvp_shortfall)], + ) + backfill_players = mvp_result.get("players", []) + elif total_mvp_shortfall > 0 and is_dupe_cardset: + logging.info( + f"Adding {total_mvp_shortfall} duplicate cards for excluded cardset" + ) + # Duplicate from first available player in the fetched results + for key in fetch_keys: + if fetched_players[key]: + for _ in range(total_mvp_shortfall): + backfill_players.append(fetched_players[key][0]) + break + + # Slice fetched players back into per-pack groups + # Track consumption offset per rarity tier + tier_offsets = {key: 0 for key in rarity_keys} + backfill_offset = 0 + per_pack_players = [] + all_pull_notifs = [] + + for pack_idx, counts in enumerate(pack_counts): + this_pack_players = [] + pack_shortfall = 0 + + for key in rarity_keys: + needed = counts[key]["count"] + if needed == 0: + continue + + available = fetched_players.get(key, []) + start = tier_offsets[key] + end = start + needed + got = available[start:end] + this_pack_players.extend(got) + tier_offsets[key] = end + + # Track shortfall for this pack + if len(got) < needed: + pack_shortfall += needed - len(got) + + # Distribute backfill players to this pack + if pack_shortfall > 0 and backfill_offset < len(backfill_players): + bf_slice = backfill_players[ + backfill_offset : backfill_offset + pack_shortfall + ] + this_pack_players.extend(bf_slice) + backfill_offset += len(bf_slice) + + # Collect rare pull notifications + for player in this_pack_players: + if player["rarity"]["value"] >= 3: + all_pull_notifs.append(player) + + per_pack_players.append(this_pack_players) + + # --- Phase C: Write cards + mark packs opened, all gathered --- + open_time = int(datetime.datetime.timestamp(datetime.datetime.now()) * 1000) + + write_coros = [] + for pack, this_pack_players in zip(all_packs, per_pack_players): + write_coros.append( + db_post( + "cards", + payload={ + "cards": [ + { + "player_id": p["player_id"], + "team_id": pack["team"]["id"], + "pack_id": pack["id"], + } + for p in this_pack_players + ] + }, + timeout=10, + ) + ) + write_coros.append( + db_patch( + "packs", + object_id=pack["id"], + params=[("open_time", open_time)], + ) + ) + + write_results = await asyncio.gather(*write_coros) + + # Check card creation results (every other result starting at index 0) + for i in range(0, len(write_results), 2): + if not write_results[i]: + raise ConnectionError("Failed to create this pack of cards.") + + # --- Gather notification posts --- + if all_pull_notifs: + notif_coros = [] + for pull in all_pull_notifs: + logger.info(f"good pull: {pull}") + notif_coros.append( + db_post( + "notifs", + payload={ + "created": int( + datetime.datetime.timestamp(datetime.datetime.now()) * 1000 + ), + "title": "Rare Pull", + "field_name": f"{player_desc(pull)} ({pull['rarity']['name']})", + "message": f"Pulled by {team['abbrev']}", + "about": f"Player-{pull['player_id']}", + }, + ) + ) + await asyncio.gather(*notif_coros) + + return [pack["id"] for pack in all_packs] async def give_packs(team: dict, num_packs: int, pack_type: dict = None) -> dict: @@ -946,7 +1031,7 @@ def get_sheets(bot): except Exception as e: logger.error(f"Could not grab sheets auth: {e}") raise ConnectionError( - f"Bot has not authenticated with discord; please try again in 1 minute." + "Bot has not authenticated with discord; please try again in 1 minute." ) @@ -1056,7 +1141,7 @@ def get_blank_team_card(player): def get_rosters(team, bot, roster_num: Optional[int] = None) -> list: sheets = get_sheets(bot) this_sheet = sheets.open_by_key(team["gsheet"]) - r_sheet = this_sheet.worksheet_by_title(f"My Rosters") + r_sheet = this_sheet.worksheet_by_title("My Rosters") logger.debug(f"this_sheet: {this_sheet} / r_sheet = {r_sheet}") all_rosters = [None, None, None] @@ -1137,11 +1222,11 @@ def get_roster_lineups(team, bot, roster_num, lineup_num) -> list: try: lineup_cells = [(row[0].value, int(row[1].value)) for row in raw_cells] - except ValueError as e: + except ValueError: logger.error(f"Could not pull roster for {team['abbrev']} due to a ValueError") raise ValueError( - f"Uh oh. Looks like your roster might not be saved. I am reading blanks when I try to " - f"get the card IDs" + "Uh oh. Looks like your roster might not be saved. I am reading blanks when I try to " + "get the card IDs" ) logger.debug(f"lineup_cells: {lineup_cells}") @@ -1536,7 +1621,7 @@ def get_ratings_guide(sheets): } for x in p_data ] - except Exception as e: + except Exception: return {"valid": False} return {"valid": True, "batter_ratings": batters, "pitcher_ratings": pitchers} @@ -1748,7 +1833,7 @@ async def open_st_pr_packs(all_packs: list, team: dict, context): pack_ids = await roll_for_cards(all_packs) if not pack_ids: logger.error(f"open_packs - unable to roll_for_cards for packs: {all_packs}") - raise ValueError(f"I was not able to unpack these cards") + raise ValueError("I was not able to unpack these cards") all_cards = [] for p_id in pack_ids: @@ -1759,7 +1844,7 @@ async def open_st_pr_packs(all_packs: list, team: dict, context): if not all_cards: logger.error(f"open_packs - unable to get cards for packs: {pack_ids}") - raise ValueError(f"I was not able to display these cards") + raise ValueError("I was not able to display these cards") # Present cards to opening channel if type(context) == commands.Context: @@ -1818,7 +1903,7 @@ async def get_choice_from_cards( view = Pagination([interaction.user], timeout=30) view.left_button.disabled = True view.left_button.label = f"Prev: -/{len(card_embeds)}" - view.cancel_button.label = f"Take This Card" + view.cancel_button.label = "Take This Card" view.cancel_button.style = discord.ButtonStyle.success view.cancel_button.disabled = True view.right_button.label = f"Next: 1/{len(card_embeds)}" @@ -1836,7 +1921,7 @@ async def get_choice_from_cards( view = Pagination([interaction.user], timeout=30) view.left_button.label = f"Prev: -/{len(card_embeds)}" view.left_button.disabled = True - view.cancel_button.label = f"Take This Card" + view.cancel_button.label = "Take This Card" view.cancel_button.style = discord.ButtonStyle.success view.right_button.label = f"Next: {page_num + 1}/{len(card_embeds)}" @@ -1879,7 +1964,7 @@ async def get_choice_from_cards( view = Pagination([interaction.user], timeout=30) view.left_button.label = f"Prev: {page_num - 1}/{len(card_embeds)}" - view.cancel_button.label = f"Take This Card" + view.cancel_button.label = "Take This Card" view.cancel_button.style = discord.ButtonStyle.success view.right_button.label = f"Next: {page_num + 1}/{len(card_embeds)}" if page_num == 1: @@ -1925,7 +2010,7 @@ async def open_choice_pack( players = pl["players"] elif pack_type == "Team Choice": if this_pack["pack_team"] is None: - raise KeyError(f"Team not listed for Team Choice pack") + raise KeyError("Team not listed for Team Choice pack") d1000 = random.randint(1, 1000) pack_cover = this_pack["pack_team"]["logo"] @@ -1964,7 +2049,7 @@ async def open_choice_pack( rarity_id += 1 elif pack_type == "Promo Choice": if this_pack["pack_cardset"] is None: - raise KeyError(f"Cardset not listed for Promo Choice pack") + raise KeyError("Cardset not listed for Promo Choice pack") d1000 = random.randint(1, 1000) pack_cover = IMAGES["mvp-hype"] @@ -2021,8 +2106,8 @@ async def open_choice_pack( rarity_id += 3 if len(players) == 0: - logger.error(f"Could not create choice pack") - raise ConnectionError(f"Could not create choice pack") + logger.error("Could not create choice pack") + raise ConnectionError("Could not create choice pack") if type(context) == commands.Context: author = context.author @@ -2045,7 +2130,7 @@ async def open_choice_pack( view = Pagination([author], timeout=30) view.left_button.disabled = True view.left_button.label = f"Prev: -/{len(card_embeds)}" - view.cancel_button.label = f"Take This Card" + view.cancel_button.label = "Take This Card" view.cancel_button.style = discord.ButtonStyle.success view.cancel_button.disabled = True view.right_button.label = f"Next: 1/{len(card_embeds)}" @@ -2063,10 +2148,10 @@ async def open_choice_pack( ) if rarity_id >= 5: tmp_msg = await pack_channel.send( - content=f"<@&1163537676885033010> we've got an MVP!" + content="<@&1163537676885033010> we've got an MVP!" ) else: - tmp_msg = await pack_channel.send(content=f"We've got a choice pack here!") + tmp_msg = await pack_channel.send(content="We've got a choice pack here!") while True: await view.wait() @@ -2081,7 +2166,7 @@ async def open_choice_pack( ) except Exception as e: logger.error(f"failed to create cards: {e}") - raise ConnectionError(f"Failed to distribute these cards.") + raise ConnectionError("Failed to distribute these cards.") await db_patch( "packs", @@ -2115,7 +2200,7 @@ async def open_choice_pack( view = Pagination([author], timeout=30) view.left_button.label = f"Prev: {page_num - 1}/{len(card_embeds)}" - view.cancel_button.label = f"Take This Card" + view.cancel_button.label = "Take This Card" view.cancel_button.style = discord.ButtonStyle.success view.right_button.label = f"Next: {page_num + 1}/{len(card_embeds)}" if page_num == 1: diff --git a/in_game/gameplay_models.py b/in_game/gameplay_models.py index 77b76cd..4bd3599 100644 --- a/in_game/gameplay_models.py +++ b/in_game/gameplay_models.py @@ -8,26 +8,45 @@ import discord import pydantic from pydantic import field_validator -from sqlmodel import Session, SQLModel, UniqueConstraint, create_engine, select, or_, Field, Relationship, text, BigInteger +from sqlmodel import ( + Session, + SQLModel, + UniqueConstraint, + create_engine, + select, + or_, + Field, + Relationship, + text, + BigInteger, +) from sqlalchemy import Column, func, desc from exceptions import * -from in_game.managerai_responses import DefenseResponse, JumpResponse, RunResponse, TagResponse, ThrowResponse, UncappedRunResponse +from in_game.managerai_responses import ( + DefenseResponse, + JumpResponse, + RunResponse, + TagResponse, + ThrowResponse, + UncappedRunResponse, +) - -logger = logging.getLogger('discord_app') +logger = logging.getLogger("discord_app") # sqlite_url = 'sqlite:///storage/gameplay.db' # connect_args = {"check_same_thread": False} # engine = create_engine(sqlite_url, echo=False, connect_args=connect_args) -postgres_url = f'postgresql://{os.getenv('DB_USERNAME')}:{os.getenv('DB_PASSWORD')}@{os.getenv('DB_URL')}/{os.getenv('DB_NAME')}' +postgres_url = f"postgresql://{os.getenv('DB_USERNAME')}:{os.getenv('DB_PASSWORD')}@{os.getenv('DB_URL')}/{os.getenv('DB_NAME')}" engine = create_engine(postgres_url, pool_size=10, max_overflow=30) -CACHE_LIMIT = 259200 # 1209600 # in seconds -SBA_COLOR = 'a6ce39' -SBA_LOGO = 'https://paper-dynasty.s3.us-east-1.amazonaws.com/static-images/sba-logo.png' +CACHE_LIMIT = 259200 # 1209600 # in seconds +SBA_COLOR = "a6ce39" +SBA_LOGO = "https://paper-dynasty.s3.us-east-1.amazonaws.com/static-images/sba-logo.png" class ManagerAiBase(SQLModel): - id: int | None = Field(sa_column=Column(BigInteger(), primary_key=True, autoincrement=True)) + id: int | None = Field( + sa_column=Column(BigInteger(), primary_key=True, autoincrement=True) + ) name: str = Field(index=True) steal: int | None = Field(default=5) running: int | None = Field(default=5) @@ -43,26 +62,32 @@ class ManagerAiBase(SQLModel): class GameCardsetLink(SQLModel, table=True): - game_id: int | None = Field(default=None, foreign_key='game.id', primary_key=True) - cardset_id: int | None = Field(default=None, foreign_key='cardset.id', primary_key=True) + game_id: int | None = Field(default=None, foreign_key="game.id", primary_key=True) + cardset_id: int | None = Field( + default=None, foreign_key="cardset.id", primary_key=True + ) priority: int | None = Field(default=1, index=True) - game: 'Game' = Relationship(back_populates='cardset_links') - cardset: 'Cardset' = Relationship(back_populates='game_links') + game: "Game" = Relationship(back_populates="cardset_links") + cardset: "Cardset" = Relationship(back_populates="game_links") class RosterLink(SQLModel, table=True): - game_id: int | None = Field(default=None, foreign_key='game.id', primary_key=True) - card_id: int | None = Field(default=None, foreign_key='card.id', primary_key=True) - team_id: int = Field(index=True, foreign_key='team.id') + game_id: int | None = Field(default=None, foreign_key="game.id", primary_key=True) + card_id: int | None = Field(default=None, foreign_key="card.id", primary_key=True) + team_id: int = Field(index=True, foreign_key="team.id") - game: 'Game' = Relationship(back_populates='roster_links') - card: 'Card' = Relationship() - team: 'Team' = Relationship() + game: "Game" = Relationship(back_populates="roster_links") + card: "Card" = Relationship() + team: "Team" = Relationship() class TeamBase(SQLModel): - id: int = Field(sa_column=Column(BigInteger(), primary_key=True, autoincrement=False, unique=True)) + id: int = Field( + sa_column=Column( + BigInteger(), primary_key=True, autoincrement=False, unique=True + ) + ) abbrev: str = Field(index=True) sname: str lname: str @@ -79,7 +104,9 @@ class TeamBase(SQLModel): ranking: int has_guide: bool is_ai: bool - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) @property def description(self) -> str: @@ -87,26 +114,29 @@ class TeamBase(SQLModel): class Team(TeamBase, table=True): - cards: list['Card'] = Relationship(back_populates='team', cascade_delete=True) - lineups: list['Lineup'] = Relationship(back_populates='team', cascade_delete=True) + cards: list["Card"] = Relationship(back_populates="team", cascade_delete=True) + lineups: list["Lineup"] = Relationship(back_populates="team", cascade_delete=True) # away_games: list['Game'] = Relationship(back_populates='away_team') # home_games: list['Game'] = Relationship(back_populates='home_team') @property def embed(self) -> discord.Embed: embed = discord.Embed( - title=f'{self.lname}', - color=int(self.color, 16) if self.color else int(SBA_COLOR, 16) + title=f"{self.lname}", + color=int(self.color, 16) if self.color else int(SBA_COLOR, 16), ) - embed.set_footer(text=f'Paper Dynasty Season {self.season}', icon_url=SBA_LOGO) + embed.set_footer(text=f"Paper Dynasty Season {self.season}", icon_url=SBA_LOGO) embed.set_thumbnail(url=self.logo if self.logo else SBA_LOGO) return embed - + class Game(SQLModel, table=True): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=True)) - away_team_id: int = Field(foreign_key='team.id') - home_team_id: int = Field(foreign_key='team.id') + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=True), + ) + away_team_id: int = Field(foreign_key="team.id") + home_team_id: int = Field(foreign_key="team.id") channel_id: int = Field(sa_column=(Column(BigInteger(), index=True))) season: int active: bool | None = Field(default=True) @@ -123,8 +153,12 @@ class Game(SQLModel, table=True): roll_buttons: bool | None = Field(default=True) auto_roll: bool | None = Field(default=False) - cardset_links: list[GameCardsetLink] = Relationship(back_populates='game', cascade_delete=True) - roster_links: list[RosterLink] = Relationship(back_populates='game', cascade_delete=True) + cardset_links: list[GameCardsetLink] = Relationship( + back_populates="game", cascade_delete=True + ) + roster_links: list[RosterLink] = Relationship( + back_populates="game", cascade_delete=True + ) away_team: Team = Relationship( # back_populates='away_games', # sa_relationship_kwargs={ @@ -143,26 +177,31 @@ class Game(SQLModel, table=True): # } sa_relationship_kwargs=dict(foreign_keys="[Game.home_team_id]") ) - lineups: list['Lineup'] = Relationship(back_populates='game', cascade_delete=True) - plays: list['Play'] = Relationship(back_populates='game', cascade_delete=True) + lineups: list["Lineup"] = Relationship(back_populates="game", cascade_delete=True) + plays: list["Play"] = Relationship(back_populates="game", cascade_delete=True) - @field_validator('ai_team', 'game_type') + @field_validator("ai_team", "game_type") def lowercase_strings(cls, value: str) -> str: return value.lower() - + @property def cardset_param_string(self) -> str: - pri_cardsets = '' - back_cardsets = '' + pri_cardsets = "" + back_cardsets = "" for link in self.cardset_links: if link.priority == 1: - pri_cardsets += f'&cardset_id={link.cardset_id}' + pri_cardsets += f"&cardset_id={link.cardset_id}" else: - back_cardsets += f'&backup_cardset_id={link.cardset_id}' - return f'{pri_cardsets}{back_cardsets}' + back_cardsets += f"&backup_cardset_id={link.cardset_id}" + return f"{pri_cardsets}{back_cardsets}" def current_play_or_none(self, session: Session): - this_play = session.exec(select(Play).where(Play.game == self, Play.complete == False).order_by(Play.play_num.desc()).limit(1)).all() + this_play = session.exec( + select(Play) + .where(Play.game == self, Play.complete == False) + .order_by(Play.play_num.desc()) + .limit(1) + ).all() if len(this_play) == 1: return this_play[0] else: @@ -175,11 +214,15 @@ class Game(SQLModel, table=True): existing_play = self.current_play_or_none(session) if existing_play is not None: return existing_play - - all_plays = session.exec(select(func.count(Play.id)).where(Play.game == self)).one() + + all_plays = session.exec( + select(func.count(Play.id)).where(Play.game == self) + ).one() if all_plays > 0: - raise PlayInitException(f'{all_plays} plays for game {self.id} already exist, but all are complete.') - + raise PlayInitException( + f"{all_plays} plays for game {self.id} already exist, but all are complete." + ) + leadoff_batter, home_pitcher, home_catcher = None, None, None home_positions, away_positions = [], [] for line in [x for x in self.lineups if x.active]: @@ -191,25 +234,31 @@ class Game(SQLModel, table=True): else: if line.position not in home_positions: home_positions.append(line.position) - if line.position == 'P': + if line.position == "P": home_pitcher = line - elif line.position == 'C': + elif line.position == "C": home_catcher = line - + if len(home_positions) != 10: - e_msg = f'Only {len(home_positions)} players found on home team' + e_msg = f"Only {len(home_positions)} players found on home team" log_exception(LineupsMissingException, e_msg) if len(away_positions) != 10: - e_msg = f'Only {len(away_positions)} players found on away team' + e_msg = f"Only {len(away_positions)} players found on away team" log_exception(LineupsMissingException, e_msg) if None in [leadoff_batter, home_pitcher, home_catcher]: - e_msg = f'Could not set the initial pitcher, catcher, and batter' + e_msg = f"Could not set the initial pitcher, catcher, and batter" log_exception(LineupsMissingException, e_msg) - manager_ai_id = ((datetime.datetime.now().day * (self.away_team_id if self.ai_team == 'away' else self.home_team_id)) % 3) + 1 + manager_ai_id = ( + ( + datetime.datetime.now().day + * (self.away_team_id if self.ai_team == "away" else self.home_team_id) + ) + % 3 + ) + 1 if manager_ai_id > 3 or manager_ai_id < 1: manager_ai_id = 1 - + new_play = Play( game=self, play_num=1, @@ -219,7 +268,7 @@ class Game(SQLModel, table=True): catcher=home_catcher, is_tied=True, is_new_inning=True, - managerai_id=manager_ai_id + managerai_id=manager_ai_id, ) session.add(new_play) session.commit() @@ -230,30 +279,36 @@ class Game(SQLModel, table=True): return new_play def team_lineup(self, session: Session, team: Team, with_links: bool = True) -> str: - all_lineups = session.exec(select(Lineup).where(Lineup.team == team, Lineup.game == self, Lineup.active).order_by(Lineup.batting_order)).all() + all_lineups = session.exec( + select(Lineup) + .where(Lineup.team == team, Lineup.game == self, Lineup.active) + .order_by(Lineup.batting_order) + ).all() - logger.info(f'all_lineups: {all_lineups}') - lineup_val = '' + logger.info(f"all_lineups: {all_lineups}") + lineup_val = "" for line in all_lineups: - logger.info(f'line in all_lineups: {line}') + logger.info(f"line in all_lineups: {line}") if with_links: - name_string = line.player.name_card_link("batting" if line.position != "P" else "pitching") + name_string = line.player.name_card_link( + "batting" if line.position != "P" else "pitching" + ) else: - name_string = f'{line.player.name_with_desc}' + name_string = f"{line.player.name_with_desc}" - if line.position == 'P': + if line.position == "P": if line.card.pitcherscouting: this_hand = line.card.pitcherscouting.pitchingcard.hand elif line.card.batterscouting: # Fallback to batting hand if pitcherscouting is missing this_hand = line.card.batterscouting.battingcard.hand else: - this_hand = '?' + this_hand = "?" else: this_hand = line.card.batterscouting.battingcard.hand - lineup_val += f'{line.batting_order}. {this_hand.upper()} | {name_string}, {line.position}\n' - + lineup_val += f"{line.batting_order}. {this_hand.upper()} | {name_string}, {line.position}\n" + return lineup_val @property @@ -266,29 +321,28 @@ class Game(SQLModel, table=True): raise NoHumanTeamsException else: raise MultipleHumanTeamsException - + @property def league_name(self): - if 'gauntlet' in self.game_type: - parts = self.game_type.split('-') - return f'{parts[0]}-{parts[1]}' + if "gauntlet" in self.game_type: + parts = self.game_type.split("-") + return f"{parts[0]}-{parts[1]}" else: return self.game_type class ManagerAi(ManagerAiBase, table=True): - plays: list['Play'] = Relationship(back_populates='managerai') + plays: list["Play"] = Relationship(back_populates="managerai") + def create_ai(session: Session = None): def get_new_ai(this_session: Session): all_ai = this_session.exec(select(ManagerAi.id)).all() if len(all_ai) == 0: - logger.info(f'Creating ManagerAI records') + logger.info(f"Creating ManagerAI records") new_ai = [ + ManagerAi(name="Balanced"), ManagerAi( - name='Balanced' - ), - ManagerAi( - name='Yolo', + name="Yolo", steal=10, running=10, hold=5, @@ -299,10 +353,10 @@ class ManagerAi(ManagerAiBase, table=True): bullpen_matchup=3, behind_aggression=10, ahead_aggression=10, - decide_throw=10 + decide_throw=10, ), ManagerAi( - name='Safe', + name="Safe", steal=3, running=3, hold=8, @@ -313,44 +367,59 @@ class ManagerAi(ManagerAiBase, table=True): bullpen_matchup=8, behind_aggression=5, ahead_aggression=1, - decide_throw=1 - ) + decide_throw=1, + ), ] for x in new_ai: session.add(x) session.commit() - + if session is None: with Session(engine) as session: get_new_ai(session) else: get_new_ai(session) - + return True - def check_jump(self, session: Session, this_game: Game, to_base: Literal[2, 3, 4]) -> JumpResponse: - logger.info(f'Checking jump to {to_base} in Game {this_game.id}') + def check_jump( + self, session: Session, this_game: Game, to_base: Literal[2, 3, 4] + ) -> JumpResponse: + logger.info(f"Checking jump to {to_base} in Game {this_game.id}") this_resp = JumpResponse(min_safe=20) this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking for jump') - + raise GameException(f"No game found while checking for jump") + num_outs = this_play.starting_outs run_diff = this_play.away_score - this_play.home_score - if this_game.ai_team == 'home': + if this_game.ai_team == "home": run_diff = run_diff * -1 - + pitcher_hold = this_play.pitcher.card.pitcherscouting.pitchingcard.hold - catcher_defense = session.exec(select(PositionRating).where(PositionRating.player_id == this_play.catcher.player_id, PositionRating.position == 'C', PositionRating.variant == this_play.catcher.card.variant)).one() + catcher_defense = session.exec( + select(PositionRating).where( + PositionRating.player_id == this_play.catcher.player_id, + PositionRating.position == "C", + PositionRating.variant == this_play.catcher.card.variant, + ) + ).one() catcher_hold = catcher_defense.arm - battery_hold = pitcher_hold + catcher_hold - logger.info(f'game state: {num_outs} outs, {run_diff} run diff, battery_hold: {battery_hold}') - + battery_hold = pitcher_hold + catcher_hold + logger.info( + f"game state: {num_outs} outs, {run_diff} run diff, battery_hold: {battery_hold}" + ) + if to_base == 2: runner = this_play.on_first if runner is None: - log_exception(CardNotFoundException, f'Attempted to check a jump to 2nd base, but no runner found on first.') - logger.info(f'Checking steal numbers for {runner.player.name} in Game {this_game.id}') + log_exception( + CardNotFoundException, + f"Attempted to check a jump to 2nd base, but no runner found on first.", + ) + logger.info( + f"Checking steal numbers for {runner.player.name} in Game {this_game.id}" + ) match self.steal: case 10: @@ -365,34 +434,39 @@ class ManagerAi(ManagerAiBase, table=True): this_resp.min_safe = 16 + num_outs case _: this_resp.min_safe = 17 + num_outs - + if self.steal > 7 and num_outs < 2 and run_diff <= 5: this_resp.run_if_auto_jump = True elif self.steal < 5: this_resp.must_auto_jump = True - + runner_card = runner.card.batterscouting.battingcard if this_resp.run_if_auto_jump and runner_card.steal_auto: - this_resp.ai_note = f'- WILL SEND **{runner.player.name}** to second!' + this_resp.ai_note = f"- WILL SEND **{runner.player.name}** to second!" elif this_resp.must_auto_jump and not runner_card.steal_auto: - logger.info(f'No jump ai note') + logger.info(f"No jump ai note") else: jump_safe_range = runner_card.steal_high + battery_hold nojump_safe_range = runner_card.steal_low + battery_hold - logger.info(f'jump_safe_range: {jump_safe_range} / nojump_safe_range: {nojump_safe_range} / min_safe: {this_resp.min_safe}') + logger.info( + f"jump_safe_range: {jump_safe_range} / nojump_safe_range: {nojump_safe_range} / min_safe: {this_resp.min_safe}" + ) if this_resp.min_safe <= nojump_safe_range: - this_resp.ai_note = f'- SEND **{runner.player.name}** to second!' + this_resp.ai_note = f"- SEND **{runner.player.name}** to second!" elif this_resp.min_safe <= jump_safe_range: - this_resp.ai_note = f'- SEND **{runner.player.name}** to second if they get the jump' + this_resp.ai_note = f"- SEND **{runner.player.name}** to second if they get the jump" elif to_base == 3: runner = this_play.on_second if runner is None: - log_exception(CardNotFoundException, f'Attempted to check a jump to 3rd base, but no runner found on second.') + log_exception( + CardNotFoundException, + f"Attempted to check a jump to 3rd base, but no runner found on second.", + ) match self.steal: case 10: @@ -405,55 +479,66 @@ class ManagerAi(ManagerAiBase, table=True): if self.steal == 10 and num_outs < 2 and run_diff <= 5: this_resp.run_if_auto_jump = True elif self.steal <= 5: - this_resp.must_auto_jump = True - + this_resp.must_auto_jump = True + runner_card = runner.card.batterscouting.battingcard if this_resp.run_if_auto_jump and runner_card.steal_auto: - this_resp.ai_note = f'- SEND **{runner.player.name}** to third!' + this_resp.ai_note = f"- SEND **{runner.player.name}** to third!" - elif this_resp.must_auto_jump and not runner_card.steal_auto or this_resp.min_safe is None: - logger.info(f'No jump ai note') + elif ( + this_resp.must_auto_jump + and not runner_card.steal_auto + or this_resp.min_safe is None + ): + logger.info(f"No jump ai note") else: jump_safe_range = runner_card.steal_low + battery_hold - logger.info(f'jump_safe_range: {jump_safe_range} / min_safe: {this_resp.min_safe}') + logger.info( + f"jump_safe_range: {jump_safe_range} / min_safe: {this_resp.min_safe}" + ) if this_resp.min_safe <= jump_safe_range: - this_resp.ai_note = f'- SEND **{runner.player.name}** to third!' - + this_resp.ai_note = f"- SEND **{runner.player.name}** to third!" + elif run_diff in [-1, 0]: runner = this_play.on_third if runner is None: - log_exception(CardNotFoundException, f'Attempted to check a jump to home, but no runner found on third.') + log_exception( + CardNotFoundException, + f"Attempted to check a jump to home, but no runner found on third.", + ) if self.steal == 10: this_resp.min_safe = 5 elif this_play.inning_num > 7 and self.steal >= 5: - this_resp.min_safe = 6 + this_resp.min_safe = 6 elif self.steal > 5: - this_resp.min_safe = 7 + this_resp.min_safe = 7 elif self.steal > 2: - this_resp.min_safe = 8 + this_resp.min_safe = 8 else: - this_resp.min_safe = 10 - + this_resp.min_safe = 10 + runner_card = runner.card.batterscouting.battingcard jump_safe_range = runner_card.steal_low - 9 if this_resp.min_safe <= jump_safe_range: - this_resp.ai_note = f'- SEND **{runner.player.name}** to third!' - - logger.info(f'Returning jump resp to game {this_game.id}: {this_resp}') + this_resp.ai_note = f"- SEND **{runner.player.name}** to third!" + + logger.info(f"Returning jump resp to game {this_game.id}: {this_resp}") return this_resp - + def tag_from_second(self, session: Session, this_game: Game) -> TagResponse: this_resp = TagResponse() this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking tag_from_second') - + raise GameException(f"No game found while checking tag_from_second") + ai_rd = this_play.ai_run_diff - aggression_mod = abs(self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5) + aggression_mod = abs( + self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + ) adjusted_running = self.running + aggression_mod if adjusted_running >= 8: @@ -462,23 +547,25 @@ class ManagerAi(ManagerAiBase, table=True): this_resp.min_safe = 7 else: this_resp.min_safe = 10 - + if this_play.starting_outs == 1: this_resp.min_safe -= 2 else: this_resp.min_safe += 2 - - logger.info(f'tag_from_second response: {this_resp}') + + logger.info(f"tag_from_second response: {this_resp}") return this_resp def tag_from_third(self, session: Session, this_game: Game) -> TagResponse: this_resp = TagResponse() this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking tag_from_third') - + raise GameException(f"No game found while checking tag_from_third") + ai_rd = this_play.ai_run_diff - aggression_mod = abs(self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5) + aggression_mod = abs( + self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + ) adjusted_running = self.running + aggression_mod if adjusted_running >= 8: @@ -487,22 +574,22 @@ class ManagerAi(ManagerAiBase, table=True): this_resp.min_safe = 10 else: this_resp.min_safe = 12 - + if ai_rd in [-1, 0]: this_resp.min_safe -= 2 - + if this_play.starting_outs == 1: this_resp.min_safe -= 2 - - logger.info(f'tag_from_third response: {this_resp}') + + logger.info(f"tag_from_third response: {this_resp}") return this_resp def throw_at_uncapped(self, session: Session, this_game: Game) -> ThrowResponse: this_resp = ThrowResponse() this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking throw_at_uncapped') - + raise GameException(f"No game found while checking throw_at_uncapped") + ai_rd = this_play.ai_run_diff aggression = self.ahead_aggression if ai_rd > 0 else self.behind_aggression current_outs = this_play.starting_outs + this_play.outs @@ -533,24 +620,30 @@ class ManagerAi(ManagerAiBase, table=True): if self.behind_aggression < 5: this_resp.at_trail_runner = True this_resp.trail_max_safe_delta = -4 - - logger.info(f'throw_at_uncapped response: {this_resp}') + + logger.info(f"throw_at_uncapped response: {this_resp}") return this_resp - def uncapped_advance(self, session: Session, this_game: Game, lead_base: int, trail_base: int) -> UncappedRunResponse: + def uncapped_advance( + self, session: Session, this_game: Game, lead_base: int, trail_base: int + ) -> UncappedRunResponse: this_resp = UncappedRunResponse() this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking uncapped_advance_lead') - + raise GameException(f"No game found while checking uncapped_advance_lead") + ai_rd = this_play.ai_run_diff - aggression = self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + aggression = ( + self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + ) if ai_rd > 4: if lead_base == 4: this_resp.min_safe = 16 - this_play.starting_outs - aggression this_resp.send_trail = True - this_resp.trail_min_safe = 10 - aggression - this_play.starting_outs - this_play.outs + this_resp.trail_min_safe = ( + 10 - aggression - this_play.starting_outs - this_play.outs + ) elif lead_base == 3: this_resp.min_safe = 14 + (this_play.starting_outs * 2) - aggression if this_play.starting_outs + this_play.outs >= 2: @@ -559,9 +652,13 @@ class ManagerAi(ManagerAiBase, table=True): if lead_base == 4: this_resp.min_safe = 12 - this_play.starting_outs - aggression this_resp.send_trail = True - this_resp.trail_min_safe = 10 - aggression - this_play.starting_outs - this_play.outs + this_resp.trail_min_safe = ( + 10 - aggression - this_play.starting_outs - this_play.outs + ) elif lead_base == 3: - this_resp.min_safe = 12 + (this_play.starting_outs * 2) - (aggression * 2) + this_resp.min_safe = ( + 12 + (this_play.starting_outs * 2) - (aggression * 2) + ) if this_play.starting_outs + this_play.outs >= 2: this_resp.send_trail = False else: @@ -582,222 +679,328 @@ class ManagerAi(ManagerAiBase, table=True): this_resp.min_safe = 20 if this_resp.trail_min_safe < 1: this_resp.min_safe = 1 - - logger.info(f'Uncapped advance response: {this_resp}') + + logger.info(f"Uncapped advance response: {this_resp}") return this_resp def defense_alignment(self, session: Session, this_game: Game) -> DefenseResponse: - logger.info(f'checking defensive alignment in game {this_game.id}') + logger.info(f"checking defensive alignment in game {this_game.id}") this_resp = DefenseResponse() this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking defense_alignment') - - logger.info(f'defense_alignment - this_play: {this_play}') + raise GameException(f"No game found while checking defense_alignment") + + logger.info(f"defense_alignment - this_play: {this_play}") ai_rd = this_play.ai_run_diff - aggression = self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + aggression = ( + self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + ) pitcher_hold = this_play.pitcher.card.pitcherscouting.pitchingcard.hold - - catcher_defense = session.exec(select(PositionRating).where(PositionRating.player_id == this_play.catcher.player_id, PositionRating.position == 'C', PositionRating.variant == this_play.catcher.card.variant)).one() + + catcher_defense = session.exec( + select(PositionRating).where( + PositionRating.player_id == this_play.catcher.player_id, + PositionRating.position == "C", + PositionRating.variant == this_play.catcher.card.variant, + ) + ).one() catcher_hold = catcher_defense.arm - battery_hold = pitcher_hold + catcher_hold + battery_hold = pitcher_hold + catcher_hold if this_play.starting_outs == 2 and this_play.on_base_code > 0: - logger.info(f'Checking for holds with 2 outs') + logger.info(f"Checking for holds with 2 outs") if this_play.on_base_code == 1: this_resp.hold_first = True - this_resp.ai_note += f'- hold {this_play.on_first.player.name} on 1st\n' + this_resp.ai_note += f"- hold {this_play.on_first.player.name} on 1st\n" elif this_play.on_base_code == 2: this_resp.hold_second = True - this_resp.ai_note += f'- hold {this_play.on_second.player.name} on 2nd\n' + this_resp.ai_note += ( + f"- hold {this_play.on_second.player.name} on 2nd\n" + ) elif this_play.on_base_code in [4, 7]: this_resp.hold_first = True this_resp.hold_second = True - this_resp.ai_note += f'- hold {this_play.on_first.player.name} on 1st\n- hold {this_play.on_second.player.name} on 2nd\n' + this_resp.ai_note += f"- hold {this_play.on_first.player.name} on 1st\n- hold {this_play.on_second.player.name} on 2nd\n" elif this_play.on_base_code == 5: this_resp.hold_first = True - this_resp.ai_note += f'- hold {this_play.on_first.player.name} on first\n' + this_resp.ai_note += ( + f"- hold {this_play.on_first.player.name} on first\n" + ) elif this_play.on_base_code == 6: this_resp.hold_second = True - this_resp.ai_note += f'- hold {this_play.on_second.player.name} on 2nd\n' + this_resp.ai_note += ( + f"- hold {this_play.on_second.player.name} on 2nd\n" + ) elif this_play.on_base_code in [1, 5]: - logger.info(f'Checking for hold with runner on first') + logger.info(f"Checking for hold with runner on first") runner = this_play.on_first.player - if this_play.on_first.card.batterscouting.battingcard.steal_auto and ((this_play.on_first.card.batterscouting.battingcard.steal_high + battery_hold) >= (12 - aggression)): + if this_play.on_first.card.batterscouting.battingcard.steal_auto and ( + ( + this_play.on_first.card.batterscouting.battingcard.steal_high + + battery_hold + ) + >= (12 - aggression) + ): this_resp.hold_first = True - this_resp.ai_note += f'- hold {runner.name} on 1st\n' + this_resp.ai_note += f"- hold {runner.name} on 1st\n" elif this_play.on_base_code in [2, 4]: - logger.info(f'Checking for hold with runner on second') - if (this_play.on_second.card.batterscouting.battingcard.steal_low + max(battery_hold, 5)) >= (14 - aggression): + logger.info(f"Checking for hold with runner on second") + if ( + this_play.on_second.card.batterscouting.battingcard.steal_low + + max(battery_hold, 5) + ) >= (14 - aggression): this_resp.hold_second = True - this_resp.ai_note += f'- hold {this_play.on_second.player.name} on 2nd\n' + this_resp.ai_note += ( + f"- hold {this_play.on_second.player.name} on 2nd\n" + ) # Defensive Alignment if this_play.on_third and this_play.starting_outs < 2: if this_play.could_walkoff: this_resp.outfield_in = True this_resp.infield_in = True - this_resp.ai_note += f'- play the outfield and infield in' + this_resp.ai_note += f"- play the outfield and infield in" elif this_play.on_first and this_play.starting_outs == 1: this_resp.corners_in = True - this_resp.ai_note += f'- play the corners in\n' + this_resp.ai_note += f"- play the corners in\n" elif abs(this_play.away_score - this_play.home_score) <= 3: this_resp.infield_in = True - this_resp.ai_note += f'- play the whole infield in\n' + this_resp.ai_note += f"- play the whole infield in\n" else: this_resp.corners_in = True - this_resp.ai_note += f'- play the corners in\n' - + this_resp.ai_note += f"- play the corners in\n" + if len(this_resp.ai_note) == 0 and this_play.on_base_code > 0: - this_resp.ai_note += f'- play straight up\n' - - logger.info(f'Defense alignment response: {this_resp}') + this_resp.ai_note += f"- play straight up\n" + + logger.info(f"Defense alignment response: {this_resp}") return this_resp - + def gb_decide_run(self, session: Session, this_game: Game) -> RunResponse: this_resp = RunResponse() this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking gb_decide_run') - - ai_rd = this_play.ai_run_diff - aggression = self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + raise GameException(f"No game found while checking gb_decide_run") - this_resp.min_safe = 15 - aggression # TODO: write this algorithm - logger.info(f'gb_decide_run response: {this_resp}') + ai_rd = this_play.ai_run_diff + aggression_mod = abs( + self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + ) + adjusted_running = self.running + aggression_mod + + if adjusted_running >= 8: + this_resp.min_safe = 4 + elif adjusted_running >= 5: + this_resp.min_safe = 6 + else: + this_resp.min_safe = 8 + + if this_play.starting_outs == 2: + this_resp.min_safe -= 2 + elif this_play.starting_outs == 0: + this_resp.min_safe += 2 + logger.info(f"gb_decide_run response: {this_resp}") return this_resp - - def gb_decide_throw(self, session: Session, this_game: Game, runner_speed: int, defender_range: int) -> ThrowResponse: + + def gb_decide_throw( + self, session: Session, this_game: Game, runner_speed: int, defender_range: int + ) -> ThrowResponse: this_resp = ThrowResponse(at_lead_runner=True) this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking gb_decide_throw') - + raise GameException(f"No game found while checking gb_decide_throw") + ai_rd = this_play.ai_run_diff - aggression = self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + aggression = ( + self.ahead_aggression - 5 if ai_rd > 0 else self.behind_aggression - 5 + ) if (runner_speed - 4 + defender_range) <= (10 + aggression): this_resp.at_lead_runner = True - - logger.info(f'gb_decide_throw response: {this_resp}') + + logger.info(f"gb_decide_throw response: {this_resp}") return this_resp def replace_pitcher(self, session: Session, this_game: Game) -> bool: - logger.info(f'Checking if fatigued pitcher should be replaced') + logger.info(f"Checking if fatigued pitcher should be replaced") this_play = this_game.current_play_or_none(session) if this_play is None: - raise GameException(f'No game found while checking replace_pitcher') - - this_pitcher = this_play.pitcher - outs = session.exec(select(func.sum(Play.outs)).where( - Play.game == this_game, Play.pitcher == this_pitcher, Play.complete == True - )).one() - logger.info(f'Pitcher: {this_pitcher.card.player.name_with_desc} / Outs: {outs}') + raise GameException(f"No game found while checking replace_pitcher") - allowed_runners = session.exec(select(func.count(Play.id)).where( - Play.game == this_game, Play.pitcher == this_pitcher, or_(Play.hit == 1, Play.bb == 1) - )).one() + this_pitcher = this_play.pitcher + outs = session.exec( + select(func.sum(Play.outs)).where( + Play.game == this_game, + Play.pitcher == this_pitcher, + Play.complete == True, + ) + ).one() + logger.info( + f"Pitcher: {this_pitcher.card.player.name_with_desc} / Outs: {outs}" + ) + + allowed_runners = session.exec( + select(func.count(Play.id)).where( + Play.game == this_game, + Play.pitcher == this_pitcher, + or_(Play.hit == 1, Play.bb == 1), + ) + ).one() run_diff = this_play.ai_run_diff - logger.info(f'run diff: {run_diff} / allowed runners: {allowed_runners} / behind aggro: {self.behind_aggression} / ahead aggro: {self.ahead_aggression}') - logger.info(f'this play: {this_play}') + logger.info( + f"run diff: {run_diff} / allowed runners: {allowed_runners} / behind aggro: {self.behind_aggression} / ahead aggro: {self.ahead_aggression}" + ) + logger.info(f"this play: {this_play}") if this_pitcher.replacing_id is None: pitcher_pow = this_pitcher.card.pitcherscouting.pitchingcard.starter_rating - logger.info(f'Starter POW: {pitcher_pow}') + logger.info(f"Starter POW: {pitcher_pow}") if outs >= pitcher_pow * 3 + 6: - logger.info(f'Starter has thrown POW + 3 - being pulled') + logger.info(f"Starter has thrown POW + 3 - being pulled") return True - + elif allowed_runners < 5: - logger.info(f'Starter is cooking with {allowed_runners} runners allowed - staying in') + logger.info( + f"Starter is cooking with {allowed_runners} runners allowed - staying in" + ) return False - + elif this_pitcher.is_fatigued and this_play.on_base_code > 1: - logger.info(f'Starter is fatigued') + logger.info(f"Starter is fatigued") return True - - elif (run_diff > 5 or (run_diff > 2 and self.ahead_aggression > 5)) and (allowed_runners < run_diff or this_play.on_base_code <= 3): - logger.info(f'AI team has big lead of {run_diff} - staying in') + + elif (run_diff > 5 or (run_diff > 2 and self.ahead_aggression > 5)) and ( + allowed_runners < run_diff or this_play.on_base_code <= 3 + ): + logger.info(f"AI team has big lead of {run_diff} - staying in") return False - - elif (run_diff > 2 or (run_diff >= 0 and self.ahead_aggression > 5)) and (allowed_runners < run_diff or this_play.on_base_code <= 1): - logger.info(f'AI team has lead of {run_diff} - staying in') + + elif (run_diff > 2 or (run_diff >= 0 and self.ahead_aggression > 5)) and ( + allowed_runners < run_diff or this_play.on_base_code <= 1 + ): + logger.info(f"AI team has lead of {run_diff} - staying in") return False - - elif (run_diff >= 0 or (run_diff >= -2 and self.behind_aggression > 5)) and (allowed_runners < 5 and this_play.on_base_code <= run_diff): - logger.info(f'AI team in close game with run diff of {run_diff} - staying in') + + elif ( + run_diff >= 0 or (run_diff >= -2 and self.behind_aggression > 5) + ) and (allowed_runners < 5 and this_play.on_base_code <= run_diff): + logger.info( + f"AI team in close game with run diff of {run_diff} - staying in" + ) return False - - elif run_diff >= -3 and self.behind_aggression > 5 and allowed_runners < 5 and this_play.on_base_code <= 1: - logger.info(f'AI team is close behind with run diff of {run_diff} - staying in') + + elif ( + run_diff >= -3 + and self.behind_aggression > 5 + and allowed_runners < 5 + and this_play.on_base_code <= 1 + ): + logger.info( + f"AI team is close behind with run diff of {run_diff} - staying in" + ) return False - + elif run_diff <= -5 and this_play.inning_num <= 3: - logger.info(f'AI team is way behind and starter is going to wear it - staying in') + logger.info( + f"AI team is way behind and starter is going to wear it - staying in" + ) return False - + else: - logger.info(f'AI team found no exceptions - pull starter') + logger.info(f"AI team found no exceptions - pull starter") return True else: pitcher_pow = this_pitcher.card.pitcherscouting.pitchingcard.relief_rating - logger.info(f'Reliever POW: {pitcher_pow}') + logger.info(f"Reliever POW: {pitcher_pow}") if outs >= pitcher_pow * 3 + 3: - logger.info(f'Only allow POW + 1 IP - pull reliever') - return True - - elif this_pitcher.is_fatigued and this_play.is_new_inning: - logger.info(f'Reliever is fatigued to start the inning - pull reliever') + logger.info(f"Only allow POW + 1 IP - pull reliever") return True - elif (run_diff > 5 or (run_diff > 2 and self.ahead_aggression > 5)) and (this_play.starting_outs == 2 or allowed_runners <= run_diff or this_play.on_base_code <= 3 or this_play.starting_outs == 2): - logger.info(f'AI team has big lead of {run_diff} - staying in') + elif this_pitcher.is_fatigued and this_play.is_new_inning: + logger.info(f"Reliever is fatigued to start the inning - pull reliever") + return True + + elif (run_diff > 5 or (run_diff > 2 and self.ahead_aggression > 5)) and ( + this_play.starting_outs == 2 + or allowed_runners <= run_diff + or this_play.on_base_code <= 3 + or this_play.starting_outs == 2 + ): + logger.info(f"AI team has big lead of {run_diff} - staying in") return False - - elif (run_diff > 2 or (run_diff >= 0 and self.ahead_aggression > 5)) and (allowed_runners < run_diff or this_play.on_base_code <= 1 or this_play.starting_outs == 2): - logger.info(f'AI team has lead of {run_diff} - staying in') + + elif (run_diff > 2 or (run_diff >= 0 and self.ahead_aggression > 5)) and ( + allowed_runners < run_diff + or this_play.on_base_code <= 1 + or this_play.starting_outs == 2 + ): + logger.info(f"AI team has lead of {run_diff} - staying in") return False - - elif (run_diff >= 0 or (run_diff >= -2 and self.behind_aggression > 5)) and (allowed_runners < 5 or this_play.on_base_code <= run_diff or this_play.starting_outs == 2): - logger.info(f'AI team in close game with run diff of {run_diff} - staying in') + + elif ( + run_diff >= 0 or (run_diff >= -2 and self.behind_aggression > 5) + ) and ( + allowed_runners < 5 + or this_play.on_base_code <= run_diff + or this_play.starting_outs == 2 + ): + logger.info( + f"AI team in close game with run diff of {run_diff} - staying in" + ) return False - - elif run_diff >= -3 and self.behind_aggression > 5 and allowed_runners < 5 and this_play.on_base_code <= 1: - logger.info(f'AI team is close behind with run diff of {run_diff} - staying in') + + elif ( + run_diff >= -3 + and self.behind_aggression > 5 + and allowed_runners < 5 + and this_play.on_base_code <= 1 + ): + logger.info( + f"AI team is close behind with run diff of {run_diff} - staying in" + ) return False - + elif run_diff <= -5 and this_play.starting_outs != 0: - logger.info(f'AI team is way behind and reliever is going to wear it - staying in') + logger.info( + f"AI team is way behind and reliever is going to wear it - staying in" + ) return False - + else: - logger.info(f'AI team found no exceptions - pull reliever') + logger.info(f"AI team found no exceptions - pull reliever") return True class CardsetBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=False)) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=False), + ) name: str ranked_legal: bool | None = Field(default=False) class Cardset(CardsetBase, table=True): - game_links: list[GameCardsetLink] = Relationship(back_populates='cardset', cascade_delete=True) - players: list['Player'] = Relationship(back_populates='cardset') + game_links: list[GameCardsetLink] = Relationship( + back_populates="cardset", cascade_delete=True + ) + players: list["Player"] = Relationship(back_populates="cardset") class PlayerBase(SQLModel): - id: int | None = Field(sa_column=Column(BigInteger(), primary_key=True, autoincrement=False)) + id: int | None = Field( + sa_column=Column(BigInteger(), primary_key=True, autoincrement=False) + ) name: str cost: int image: str mlbclub: str franchise: str - cardset_id: int | None = Field(default=None, foreign_key='cardset.id') + cardset_id: int | None = Field(default=None, foreign_key="cardset.id") set_num: int rarity_id: int | None = Field(default=None) pos_1: str @@ -817,9 +1020,13 @@ class PlayerBase(SQLModel): bbref_id: str | None = Field(default=None) fangr_id: str | None = Field(default=None) mlbplayer_id: int | None = Field(default=None) - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) - @field_validator('pos_1', 'pos_2', 'pos_3', 'pos_4', 'pos_5', 'pos_6', 'pos_7', 'pos_8') + @field_validator( + "pos_1", "pos_2", "pos_3", "pos_4", "pos_5", "pos_6", "pos_7", "pos_8" + ) def uppercase_strings(cls, value: str) -> str: if value is not None: return value.upper() @@ -828,73 +1035,82 @@ class PlayerBase(SQLModel): @property def batter_card_url(self): - if self.image and 'batting' in self.image: + if self.image and "batting" in self.image: return self.image - elif self.image2 and 'batting' in self.image2: - return self.image2 - else: - return None - - @property - def pitcher_card_url(self): - if self.image and 'pitching' in self.image: - return self.image - elif self.image2 and 'pitching' in self.image2: + elif self.image2 and "batting" in self.image2: return self.image2 else: return None - def name_card_link(self, which: Literal['pitching', 'batting']): - if which == 'pitching': - return f'[{self.name}]({self.pitcher_card_url})' + @property + def pitcher_card_url(self): + if self.image and "pitching" in self.image: + return self.image + elif self.image2 and "pitching" in self.image2: + return self.image2 else: - return f'[{self.name}]({self.batter_card_url})' + return None + + def name_card_link(self, which: Literal["pitching", "batting"]): + if which == "pitching": + return f"[{self.name}]({self.pitcher_card_url})" + else: + return f"[{self.name}]({self.batter_card_url})" class Player(PlayerBase, table=True): - cardset: Cardset = Relationship(back_populates='players') - cards: list['Card'] = Relationship(back_populates='player', cascade_delete=True) - lineups: list['Lineup'] = Relationship(back_populates='player', cascade_delete=True) - positions: list['PositionRating'] = Relationship(back_populates='player', cascade_delete=True) + cardset: Cardset = Relationship(back_populates="players") + cards: list["Card"] = Relationship(back_populates="player", cascade_delete=True) + lineups: list["Lineup"] = Relationship(back_populates="player", cascade_delete=True) + positions: list["PositionRating"] = Relationship( + back_populates="player", cascade_delete=True + ) @property def name_with_desc(self): - return f'{self.description} {self.name}' + return f"{self.description} {self.name}" def player_description(player: Player = None, player_dict: dict = None) -> str: if player is None and player_dict is None: - err = 'One of "player" or "player_dict" must be included to get full description' - logger.error(f'gameplay_models - player_description - {err}') + err = ( + 'One of "player" or "player_dict" must be included to get full description' + ) + logger.error(f"gameplay_models - player_description - {err}") raise TypeError(err) - + if player is not None: - return f'{player.description} {player.name}' - - r_val = f'{player_dict['description']}' - if 'name' in player_dict: + return f"{player.description} {player.name}" + + r_val = f"{player_dict['description']}" + if "name" in player_dict: r_val += f' {player_dict["name"]}' - elif 'p_name' in player_dict: + elif "p_name" in player_dict: r_val += f' {player_dict["p_name"]}' return r_val class BattingCardBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=False)) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=False), + ) variant: int | None = Field(default=0) steal_low: int = Field(default=0, ge=0, le=20) steal_high: int = Field(default=0, ge=0, le=20) steal_auto: bool = Field(default=False) steal_jump: float = Field(default=0.0, ge=0.0, le=1.0) - bunting: str = Field(default='C') - hit_and_run: str = Field(default='C') + bunting: str = Field(default="C") + hit_and_run: str = Field(default="C") running: int = Field(default=10, ge=1, le=20) offense_col: int = Field(ge=1, le=3) hand: str - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) # created: datetime.datetime | None = Field(sa_column_kwargs={"server_default": text("CURRENT_TIMESTAMP"),}) - @field_validator('hand') + @field_validator("hand") def lowercase_hand(cls, value: str) -> str: return value.lower() @@ -904,7 +1120,10 @@ class BattingCard(BattingCardBase, table=True): class BattingRatingsBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=False)) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=False), + ) homerun: float = Field(default=0.0, ge=0.0, le=108.0) bp_homerun: float = Field(default=0.0, ge=0.0, le=108.0) triple: float = Field(default=0.0, ge=0.0, le=108.0) @@ -933,7 +1152,9 @@ class BattingRatingsBase(SQLModel): pull_rate: float = Field(default=0.0, ge=0.0, le=1.0) center_rate: float = Field(default=0.0, ge=0.0, le=1.0) slap_rate: float = Field(default=0.0, ge=0.0, le=1.0) - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) class BattingRatings(BattingRatingsBase, table=True): @@ -941,26 +1162,48 @@ class BattingRatings(BattingRatingsBase, table=True): class BatterScoutingBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=True)) - battingcard_id: int | None = Field(default=None, foreign_key='battingcard.id', ondelete='CASCADE') - ratings_vl_id: int | None = Field(default=None, foreign_key='battingratings.id', ondelete='CASCADE') - ratings_vr_id: int | None = Field(default=None, foreign_key='battingratings.id', ondelete='CASCADE') - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=True), + ) + battingcard_id: int | None = Field( + default=None, foreign_key="battingcard.id", ondelete="CASCADE" + ) + ratings_vl_id: int | None = Field( + default=None, foreign_key="battingratings.id", ondelete="CASCADE" + ) + ratings_vr_id: int | None = Field( + default=None, foreign_key="battingratings.id", ondelete="CASCADE" + ) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) class BatterScouting(BatterScoutingBase, table=True): - battingcard: BattingCard = Relationship() #back_populates='batterscouting') + battingcard: BattingCard = Relationship() # back_populates='batterscouting') ratings_vl: BattingRatings = Relationship( - sa_relationship_kwargs=dict(foreign_keys="[BatterScouting.ratings_vl_id]",single_parent=True), cascade_delete=True + sa_relationship_kwargs=dict( + foreign_keys="[BatterScouting.ratings_vl_id]", single_parent=True + ), + cascade_delete=True, ) ratings_vr: BattingRatings = Relationship( - sa_relationship_kwargs=dict(foreign_keys="[BatterScouting.ratings_vr_id]",single_parent=True), cascade_delete=True + sa_relationship_kwargs=dict( + foreign_keys="[BatterScouting.ratings_vr_id]", single_parent=True + ), + cascade_delete=True, + ) + cards: list["Card"] = Relationship( + back_populates="batterscouting", cascade_delete=False ) - cards: list['Card'] = Relationship(back_populates='batterscouting', cascade_delete=False) class PitchingCardBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=False)) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=False), + ) variant: int | None = Field(default=0) balk: int = Field(default=0, ge=0, le=20) wild_pitch: int = Field(default=0, ge=0, le=20) @@ -969,11 +1212,13 @@ class PitchingCardBase(SQLModel): relief_rating: int = Field(default=1, ge=1, le=10) closer_rating: int | None = Field(default=None, ge=0, le=9) offense_col: int = Field(ge=1, le=3) - batting: str = Field(default='#1WR-C') + batting: str = Field(default="#1WR-C") hand: str - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) - @field_validator('hand') + @field_validator("hand") def lowercase_hand(cls, value: str) -> str: return value.lower() @@ -983,7 +1228,10 @@ class PitchingCard(PitchingCardBase, table=True): class PitchingRatingsBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=False)) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=False), + ) homerun: float = Field(default=0.0, ge=0.0, le=108.0) bp_homerun: float = Field(default=0.0, ge=0.0, le=108.0) triple: float = Field(default=0.0, ge=0.0, le=108.0) @@ -1014,7 +1262,9 @@ class PitchingRatingsBase(SQLModel): avg: float = Field(default=0.0, ge=0.0, le=1.0) obp: float = Field(default=0.0, ge=0.0, le=1.0) slg: float = Field(default=0.0, ge=0.0, le=4.0) - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) class PitchingRatings(PitchingRatingsBase, table=True): @@ -1022,64 +1272,100 @@ class PitchingRatings(PitchingRatingsBase, table=True): class PitcherScoutingBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=True)) - pitchingcard_id: int | None = Field(default=None, foreign_key='pitchingcard.id', ondelete='CASCADE') - ratings_vl_id: int | None = Field(default=None, foreign_key='pitchingratings.id', ondelete='CASCADE') - ratings_vr_id: int | None = Field(default=None, foreign_key='pitchingratings.id', ondelete='CASCADE') - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=True), + ) + pitchingcard_id: int | None = Field( + default=None, foreign_key="pitchingcard.id", ondelete="CASCADE" + ) + ratings_vl_id: int | None = Field( + default=None, foreign_key="pitchingratings.id", ondelete="CASCADE" + ) + ratings_vr_id: int | None = Field( + default=None, foreign_key="pitchingratings.id", ondelete="CASCADE" + ) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) class PitcherScouting(PitcherScoutingBase, table=True): pitchingcard: PitchingCard = Relationship() ratings_vl: PitchingRatings = Relationship( - sa_relationship_kwargs=dict(foreign_keys="[PitcherScouting.ratings_vl_id]",single_parent=True), cascade_delete=True + sa_relationship_kwargs=dict( + foreign_keys="[PitcherScouting.ratings_vl_id]", single_parent=True + ), + cascade_delete=True, ) ratings_vr: PitchingRatings = Relationship( - sa_relationship_kwargs=dict(foreign_keys="[PitcherScouting.ratings_vr_id]",single_parent=True), cascade_delete=True + sa_relationship_kwargs=dict( + foreign_keys="[PitcherScouting.ratings_vr_id]", single_parent=True + ), + cascade_delete=True, ) - cards: list['Card'] = Relationship(back_populates='pitcherscouting') + cards: list["Card"] = Relationship(back_populates="pitcherscouting") class CardBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=False)) - player_id: int = Field(foreign_key='player.id', index=True, ondelete='CASCADE') - team_id: int = Field(foreign_key='team.id', index=True, ondelete='CASCADE') - batterscouting_id: int | None = Field(default=None, foreign_key='batterscouting.id', ondelete='CASCADE') - pitcherscouting_id: int | None = Field(default=None, foreign_key='pitcherscouting.id', ondelete='CASCADE') + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=False), + ) + player_id: int = Field(foreign_key="player.id", index=True, ondelete="CASCADE") + team_id: int = Field(foreign_key="team.id", index=True, ondelete="CASCADE") + batterscouting_id: int | None = Field( + default=None, foreign_key="batterscouting.id", ondelete="CASCADE" + ) + pitcherscouting_id: int | None = Field( + default=None, foreign_key="pitcherscouting.id", ondelete="CASCADE" + ) variant: int | None = Field(default=0) - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) - + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) + class Card(CardBase, table=True): - player: Player = Relationship(back_populates='cards') - team: Team = Relationship(back_populates='cards') - lineups: list['Lineup'] = Relationship(back_populates='card', cascade_delete=True) + player: Player = Relationship(back_populates="cards") + team: Team = Relationship(back_populates="cards") + lineups: list["Lineup"] = Relationship(back_populates="card", cascade_delete=True) variant: int = Field(default=0, index=True) - batterscouting: BatterScouting = Relationship(back_populates='cards') - pitcherscouting: PitcherScouting = Relationship(back_populates='cards') + batterscouting: BatterScouting = Relationship(back_populates="cards") + pitcherscouting: PitcherScouting = Relationship(back_populates="cards") class PositionRatingBase(SQLModel): __table_args__ = (UniqueConstraint("player_id", "variant", "position"),) - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=True)) - player_id: int = Field(foreign_key='player.id', index=True, ondelete='CASCADE') + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=True), + ) + player_id: int = Field(foreign_key="player.id", index=True, ondelete="CASCADE") variant: int = Field(default=0, index=True) - position: str = Field(index=True, include=['P', 'C', '1B', '2B', '3B', 'SS', 'LF', 'CF', 'RF']) + position: str = Field( + index=True, include=["P", "C", "1B", "2B", "3B", "SS", "LF", "CF", "RF"] + ) innings: int = Field(default=0) range: int = Field(default=5) error: int = Field(default=0) arm: int | None = Field(default=None) pb: int | None = Field(default=None) overthrow: int | None = Field(default=None) - created: datetime.datetime = Field(default_factory=datetime.datetime.now, nullable=True) + created: datetime.datetime = Field( + default_factory=datetime.datetime.now, nullable=True + ) class PositionRating(PositionRatingBase, table=True): - player: Player = Relationship(back_populates='positions') + player: Player = Relationship(back_populates="positions") class Lineup(SQLModel, table=True): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=True)) + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=True), + ) position: str = Field(index=True) batting_order: int = Field(index=True) after_play: int | None = Field(default=0) @@ -1087,31 +1373,34 @@ class Lineup(SQLModel, table=True): active: bool = Field(default=True, index=True) is_fatigued: bool | None = Field(default=None) - game_id: int = Field(foreign_key='game.id', index=True, ondelete='CASCADE') - game: Game = Relationship(back_populates='lineups') - - team_id: int = Field(foreign_key='team.id', index=True, ondelete='CASCADE') - team: Team = Relationship(back_populates='lineups') - - player_id: int = Field(foreign_key='player.id', index=True, ondelete='CASCADE') - player: Player = Relationship(back_populates='lineups') + game_id: int = Field(foreign_key="game.id", index=True, ondelete="CASCADE") + game: Game = Relationship(back_populates="lineups") - card_id: int = Field(foreign_key='card.id', index=True, ondelete='CASCADE') - card: Card = Relationship(back_populates='lineups') + team_id: int = Field(foreign_key="team.id", index=True, ondelete="CASCADE") + team: Team = Relationship(back_populates="lineups") - @field_validator('position') + player_id: int = Field(foreign_key="player.id", index=True, ondelete="CASCADE") + player: Player = Relationship(back_populates="lineups") + + card_id: int = Field(foreign_key="card.id", index=True, ondelete="CASCADE") + card: Card = Relationship(back_populates="lineups") + + @field_validator("position") def uppercase_strings(cls, value: str) -> str: return value.upper() class PlayBase(SQLModel): - id: int | None = Field(default=None, sa_column=Column(BigInteger(), primary_key=True, autoincrement=True)) - game_id: int = Field(foreign_key='game.id') + id: int | None = Field( + default=None, + sa_column=Column(BigInteger(), primary_key=True, autoincrement=True), + ) + game_id: int = Field(foreign_key="game.id") play_num: int - batter_id: int = Field(foreign_key='lineup.id') - pitcher_id: int = Field(foreign_key='lineup.id') + batter_id: int = Field(foreign_key="lineup.id") + pitcher_id: int = Field(foreign_key="lineup.id") on_base_code: int = Field(default=0) - inning_half: str = Field(default='top') + inning_half: str = Field(default="top") inning_num: int = Field(default=1, ge=1) batting_order: int = Field(default=1, ge=1, le=9) starting_outs: int = Field(default=0, ge=0, le=2) @@ -1120,13 +1409,13 @@ class PlayBase(SQLModel): batter_pos: str | None = Field(default=None) in_pow: bool = Field(default=False) - on_first_id: int | None = Field(default=None, foreign_key='lineup.id') - on_first_final: int | None = Field(default=None) # None = out, 1-4 = base - on_second_id: int | None = Field(default=None, foreign_key='lineup.id') - on_second_final: int | None = Field(default=None) # None = out, 1-4 = base - on_third_id: int | None = Field(default=None, foreign_key='lineup.id') - on_third_final: int | None = Field(default=None) # None = out, 1-4 = base - batter_final: int | None = Field(default=None) # None = out, 1-4 = base + on_first_id: int | None = Field(default=None, foreign_key="lineup.id") + on_first_final: int | None = Field(default=None) # None = out, 1-4 = base + on_second_id: int | None = Field(default=None, foreign_key="lineup.id") + on_second_final: int | None = Field(default=None) # None = out, 1-4 = base + on_third_id: int | None = Field(default=None, foreign_key="lineup.id") + on_third_final: int | None = Field(default=None) # None = out, 1-4 = base + batter_final: int | None = Field(default=None) # None = out, 1-4 = base pa: int = Field(default=1, ge=0, le=1) ab: int = Field(default=1, ge=0, le=1) @@ -1154,9 +1443,9 @@ class PlayBase(SQLModel): wpa: float = Field(default=0) re24: float = Field(default=0) - catcher_id: int = Field(foreign_key='lineup.id') - defender_id: int | None = Field(default=None, foreign_key='lineup.id') - runner_id: int | None = Field(default=None, foreign_key='lineup.id') + catcher_id: int = Field(foreign_key="lineup.id") + defender_id: int | None = Field(default=None, foreign_key="lineup.id") + runner_id: int | None = Field(default=None, foreign_key="lineup.id") check_pos: str | None = Field(default=None) error: int = Field(default=0) @@ -1169,26 +1458,26 @@ class PlayBase(SQLModel): is_go_ahead: bool = Field(default=False) is_tied: bool = Field(default=False) is_new_inning: bool = Field(default=False) - managerai_id: int | None = Field(default=None, foreign_key='managerai.id') + managerai_id: int | None = Field(default=None, foreign_key="managerai.id") - @field_validator('inning_half') + @field_validator("inning_half") def lowercase_strings(cls, value: str) -> str: return value.lower() - @field_validator('check_pos', 'batter_pos') + @field_validator("check_pos", "batter_pos") def uppercase_strings(cls, value: str) -> str: return value.upper() @property def ai_run_diff(self) -> int: - if self.game.ai_team == 'away': + if self.game.ai_team == "away": return self.away_score - self.home_score else: return self.home_score - self.away_score class Play(PlayBase, table=True): - game: Game = Relationship(back_populates='plays') + game: Game = Relationship(back_populates="plays") batter: Lineup = Relationship( sa_relationship_kwargs=dict(foreign_keys="[Play.batter_id]") ) @@ -1213,7 +1502,7 @@ class Play(PlayBase, table=True): runner: Lineup = Relationship( sa_relationship_kwargs=dict(foreign_keys="[Play.runner_id]") ) - managerai: ManagerAi = Relationship(back_populates='plays') + managerai: ManagerAi = Relationship(back_populates="plays") def init_ai(self, session: Session): id = ((datetime.datetime.now().day * self.batter.team.id) % 3) + 1 @@ -1221,48 +1510,52 @@ class Play(PlayBase, table=True): self.managerai_id = 1 else: self.managerai_id = id - + session.add(self) session.commit() @property def scorebug_ascii(self): - occupied = '●' - unoccupied = '○' + occupied = "●" + unoccupied = "○" first_base = unoccupied if not self.on_first else occupied second_base = unoccupied if not self.on_second else occupied third_base = unoccupied if not self.on_third else occupied - half = '▲' if self.inning_half == 'top' else '▼' + half = "▲" if self.inning_half == "top" else "▼" if self.game.active: - inning = f'{half} {self.inning_num}' + inning = f"{half} {self.inning_num}" outs = f'{self.starting_outs} Out{"s" if self.starting_outs != 1 else ""}' else: inning = f'F/{self.inning_num if self.inning_half == "bot" else self.inning_num - 1}' - outs = '' - - game_string = f'```\n' \ - f'{self.game.away_team.abbrev.replace("Gauntlet-", ""): ^5}{self.away_score: ^3} {second_base}{inning: >10}\n' \ - f'{self.game.home_team.abbrev.replace("Gauntlet-", ""): ^5}{self.home_score: ^3} {third_base} {first_base}{outs: >8}\n```' - + outs = "" + + game_string = ( + f"```\n" + f'{self.game.away_team.abbrev.replace("Gauntlet-", ""): ^5}{self.away_score: ^3} {second_base}{inning: >10}\n' + f'{self.game.home_team.abbrev.replace("Gauntlet-", ""): ^5}{self.home_score: ^3} {third_base} {first_base}{outs: >8}\n```' + ) + return game_string @property def ai_is_batting(self) -> bool: if self.game.ai_team is None: return False - - if (self.game.ai_team == 'away' and self.inning_half == 'top') or (self.game.ai_team == 'home' and self.inning_half == 'bot'): + + if (self.game.ai_team == "away" and self.inning_half == "top") or ( + self.game.ai_team == "home" and self.inning_half == "bot" + ): return True else: return False @property def could_walkoff(self) -> bool: - if self.inning_half == 'bot' and self.on_third is not None: + if self.inning_half == "bot" and self.on_third is not None: runs_needed = self.away_score - self.home_score + 1 - + if runs_needed == 2 or (self.home_score - self.away_score == 9): return True @@ -1295,20 +1588,60 @@ def create_test_games(): season=9, ) - cardset_2024 = Cardset(name='2024 Season', ranked_legal=True) - cardset_2022 = Cardset(name='2022 Season', ranked_legal=False) + cardset_2024 = Cardset(name="2024 Season", ranked_legal=True) + cardset_2022 = Cardset(name="2022 Season", ranked_legal=False) - game_1_cardset_2024_link = GameCardsetLink(game=game_1, cardset=cardset_2024, priority=1) - game_1_cardset_2022_link = GameCardsetLink(game=game_1, cardset=cardset_2022, priority=2) - game_2_cardset_2024_link = GameCardsetLink(game=game_2, cardset=cardset_2024, priority=1) + game_1_cardset_2024_link = GameCardsetLink( + game=game_1, cardset=cardset_2024, priority=1 + ) + game_1_cardset_2022_link = GameCardsetLink( + game=game_1, cardset=cardset_2022, priority=2 + ) + game_2_cardset_2024_link = GameCardsetLink( + game=game_2, cardset=cardset_2024, priority=1 + ) for team_id in [1, 2]: - for (order, pos) in [(1, 'C'), (2, '1B'), (3, '2B'), (4, '3B'), (5, 'SS'), (6, 'LF'), (7, 'CF'), (8, 'RF'), (9, 'DH')]: - this_lineup = Lineup(team_id=team_id, card_id=order, player_id=68+order, position=pos, batting_order=order, game=game_1) - + for order, pos in [ + (1, "C"), + (2, "1B"), + (3, "2B"), + (4, "3B"), + (5, "SS"), + (6, "LF"), + (7, "CF"), + (8, "RF"), + (9, "DH"), + ]: + this_lineup = Lineup( + team_id=team_id, + card_id=order, + player_id=68 + order, + position=pos, + batting_order=order, + game=game_1, + ) + for team_id in [3, 4]: - for (order, pos) in [(1, 'C'), (2, '1B'), (3, '2B'), (4, '3B'), (5, 'SS'), (6, 'LF'), (7, 'CF'), (8, 'RF'), (9, 'DH')]: - this_lineup = Lineup(team_id=team_id, card_id=order, player_id=100+order, position=pos, batting_order=order, game=game_2) + for order, pos in [ + (1, "C"), + (2, "1B"), + (3, "2B"), + (4, "3B"), + (5, "SS"), + (6, "LF"), + (7, "CF"), + (8, "RF"), + (9, "DH"), + ]: + this_lineup = Lineup( + team_id=team_id, + card_id=order, + player_id=100 + order, + position=pos, + batting_order=order, + game=game_2, + ) session.add(game_1) session.add(game_2) @@ -1319,29 +1652,33 @@ def select_speed_testing(): with Session(engine) as session: game_1 = session.exec(select(Game).where(Game.id == 1)).one() ss_search_start = datetime.datetime.now() - man_ss = [x for x in game_1.lineups if x.position == 'SS' and x.active] + man_ss = [x for x in game_1.lineups if x.position == "SS" and x.active] ss_search_end = datetime.datetime.now() ss_query_start = datetime.datetime.now() - query_ss = session.exec(select(Lineup).where(Lineup.game == game_1, Lineup.position == 'SS', Lineup.active == True)).all() + query_ss = session.exec( + select(Lineup).where( + Lineup.game == game_1, Lineup.position == "SS", Lineup.active == True + ) + ).all() ss_query_end = datetime.datetime.now() manual_time = ss_search_end - ss_search_start query_time = ss_query_end - ss_query_start - print(f'Manual Shortstops: time: {manual_time.microseconds} ms / {man_ss}') - print(f'Query Shortstops: time: {query_time.microseconds} ms / {query_ss}') - print(f'Game: {game_1}') + print(f"Manual Shortstops: time: {manual_time.microseconds} ms / {man_ss}") + print(f"Query Shortstops: time: {query_time.microseconds} ms / {query_ss}") + print(f"Game: {game_1}") games = session.exec(select(Game).where(Game.active == True)).all() - print(f'len(games): {len(games)}') + print(f"len(games): {len(games)}") def select_all_testing(): with Session(engine) as session: game_search = session.exec(select(Team)).all() for game in game_search: - print(f'Game: {game}') + print(f"Game: {game}") # def select_specic_fields(): diff --git a/tests/gameplay_models/test_managerai_model.py b/tests/gameplay_models/test_managerai_model.py index 1749ab1..cff07b7 100644 --- a/tests/gameplay_models/test_managerai_model.py +++ b/tests/gameplay_models/test_managerai_model.py @@ -7,42 +7,60 @@ from in_game.managerai_responses import JumpResponse def test_create_ai(session: Session): all_ai = session.exec(select(ManagerAi)).all() - + assert len(all_ai) == 3 assert ManagerAi.create_ai(session) == True - + all_ai = session.exec(select(ManagerAi)).all() - + assert len(all_ai) == 3 def test_check_jump(session: Session): - balanced_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Balanced')).one() - aggressive_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Yolo')).one() + balanced_ai = session.exec( + select(ManagerAi).where(ManagerAi.name == "Balanced") + ).one() + aggressive_ai = session.exec( + select(ManagerAi).where(ManagerAi.name == "Yolo") + ).one() this_game = session.get(Game, 1) runner = session.get(Lineup, 5) - this_play = session.get(Play, 2) + this_play = session.get(Play, 2) this_play.on_first = runner assert this_play.starting_outs == 1 - assert balanced_ai.check_jump(session, this_game, to_base=2) == JumpResponse(ai_note='- SEND **Player 4** to second if they get the jump', min_safe=16) - assert aggressive_ai.check_jump(session, this_game, to_base=2) == JumpResponse(ai_note='- SEND **Player 4** to second if they get the jump', min_safe=13, run_if_auto_jump=True) - + assert balanced_ai.check_jump(session, this_game, to_base=2) == JumpResponse( + ai_note="- SEND **Player 4** to second if they get the jump", min_safe=16 + ) + assert aggressive_ai.check_jump(session, this_game, to_base=2) == JumpResponse( + ai_note="- SEND **Player 4** to second if they get the jump", + min_safe=13, + run_if_auto_jump=True, + ) + this_play.on_third = runner - assert balanced_ai.check_jump(session, this_game, to_base=4) == JumpResponse(min_safe=8) - assert aggressive_ai.check_jump(session, this_game, to_base=4) == JumpResponse(min_safe=5) + assert balanced_ai.check_jump(session, this_game, to_base=4) == JumpResponse( + min_safe=8 + ) + assert aggressive_ai.check_jump(session, this_game, to_base=4) == JumpResponse( + min_safe=5 + ) def test_tag_from_second(session: Session): - balanced_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Balanced')).one() - aggressive_ai = session.exec(select(ManagerAi).where(ManagerAi.name == 'Yolo')).one() + balanced_ai = session.exec( + select(ManagerAi).where(ManagerAi.name == "Balanced") + ).one() + aggressive_ai = session.exec( + select(ManagerAi).where(ManagerAi.name == "Yolo") + ).one() this_game = session.get(Game, 1) runner = session.get(Lineup, 5) - this_play = session.get(Play, 2) + this_play = session.get(Play, 2) this_play.on_second = runner assert this_play.starting_outs == 1 @@ -53,4 +71,30 @@ def test_tag_from_second(session: Session): assert balanced_resp.min_safe == 5 assert aggressive_resp.min_safe == 2 - \ No newline at end of file + +def test_gb_decide_run(session: Session): + """ + Verifies that gb_decide_run returns a min_safe threshold based on self.running + plus an aggression modifier, with outs adjustment applied. + + With 1 out (no outs adjustment): + - Balanced (running=5, behind_aggression=5): adjusted_running=5 → tier ≥5 → min_safe=6 + - Yolo (running=10, behind_aggression=10): adjusted_running=15 → tier ≥8 → min_safe=4 + """ + balanced_ai = session.exec( + select(ManagerAi).where(ManagerAi.name == "Balanced") + ).one() + aggressive_ai = session.exec( + select(ManagerAi).where(ManagerAi.name == "Yolo") + ).one() + + this_game = session.get(Game, 1) + this_play = session.get(Play, 2) + + assert this_play.starting_outs == 1 + + balanced_resp = balanced_ai.gb_decide_run(session, this_game) + aggressive_resp = aggressive_ai.gb_decide_run(session, this_game) + + assert balanced_resp.min_safe == 6 + assert aggressive_resp.min_safe == 4 diff --git a/tests/test_card_embed_evolution.py b/tests/test_card_embed_evolution.py new file mode 100644 index 0000000..5a3b9e2 --- /dev/null +++ b/tests/test_card_embed_evolution.py @@ -0,0 +1,315 @@ +""" +Tests for WP-12: Tier Badge on Card Embed. + +What: Verifies that get_card_embeds() correctly prepends a tier badge to the +embed title when a card has evolution progress, and gracefully degrades when +the evolution API is unavailable. + +Why: The tier badge is a non-blocking UI enhancement. Any failure in the +evolution API must never prevent the card embed from rendering — this test +suite enforces that contract while also validating the badge format logic. +""" + +import pytest +from unittest.mock import AsyncMock, patch +import discord + +# --------------------------------------------------------------------------- +# Helpers / shared fixtures +# --------------------------------------------------------------------------- + + +def make_card( + player_id=42, + p_name="Mike Trout", + rarity_color="FFD700", + image="https://example.com/card.png", + headshot=None, + franchise="Los Angeles Angels", + bbref_id="troutmi01", + fangr_id=None, + strat_code="420420", + mlbclub="Los Angeles Angels", + cardset_name="2024 Season", +): + """ + Build the minimal card dict that get_card_embeds() expects, matching the + shape returned by the Paper Dynasty API (nested player / team / rarity). + + Using p_name='Mike Trout' as the canonical test name so we can assert + against '[Tx] Mike Trout' title strings without repeating the name. + """ + return { + "id": 9001, + "player": { + "player_id": player_id, + "p_name": p_name, + "rarity": {"color": rarity_color, "name": "Hall of Fame"}, + "image": image, + "image2": None, + "headshot": headshot, + "mlbclub": mlbclub, + "franchise": franchise, + "bbref_id": bbref_id, + "fangr_id": fangr_id, + "strat_code": strat_code, + "cost": 500, + "cardset": {"name": cardset_name}, + "pos_1": "CF", + "pos_2": None, + "pos_3": None, + "pos_4": None, + "pos_5": None, + "pos_6": None, + "pos_7": None, + "pos_8": None, + }, + "team": { + "id": 1, + "lname": "Test Team", + "logo": "https://example.com/logo.png", + "season": 7, + }, + } + + +def make_evo_state(tier: int) -> dict: + """Return a minimal evolution-state dict for a given tier.""" + return {"current_tier": tier, "xp": 100, "max_tier": 4} + + +EMPTY_PAPERDEX = {"count": 0, "paperdex": []} + + +def _db_get_side_effect(evo_response): + """ + Build a db_get coroutine side-effect that returns evo_response for + evolution/* endpoints and an empty paperdex for everything else. + """ + + async def _side_effect(endpoint, **kwargs): + if "evolution" in endpoint: + return evo_response + if "paperdex" in endpoint: + return EMPTY_PAPERDEX + return None + + return _side_effect + + +# --------------------------------------------------------------------------- +# Tier badge format — pure function tests (no Discord/API involved) +# --------------------------------------------------------------------------- + + +class TestTierBadgeFormat: + """ + Unit tests for the _get_tier_badge() helper that computes the badge string. + + Why separate: the badge logic is simple but error-prone at the boundary + between tier 3 and tier 4 (EVO). Testing it in isolation makes failures + immediately obvious without standing up the full embed machinery. + """ + + def _badge(self, tier: int) -> str: + """Inline mirror of the production badge logic for white-box testing.""" + if tier <= 0: + return "" + return f"[{'EVO' if tier >= 4 else f'T{tier}'}] " + + def test_tier_0_returns_empty_string(self): + """Tier 0 means no evolution progress — badge must be absent.""" + assert self._badge(0) == "" + + def test_negative_tier_returns_empty_string(self): + """Defensive: negative tiers (should not happen) must produce no badge.""" + assert self._badge(-1) == "" + + def test_tier_1_shows_T1(self): + assert self._badge(1) == "[T1] " + + def test_tier_2_shows_T2(self): + assert self._badge(2) == "[T2] " + + def test_tier_3_shows_T3(self): + assert self._badge(3) == "[T3] " + + def test_tier_4_shows_EVO(self): + """Tier 4 is fully evolved — badge changes from T4 to EVO.""" + assert self._badge(4) == "[EVO] " + + def test_tier_above_4_shows_EVO(self): + """Any tier >= 4 should display EVO (defensive against future tiers).""" + assert self._badge(5) == "[EVO] " + assert self._badge(99) == "[EVO] " + + +# --------------------------------------------------------------------------- +# Integration-style tests for get_card_embeds() title construction +# --------------------------------------------------------------------------- + + +class TestCardEmbedTierBadge: + """ + Validates that get_card_embeds() produces the correct title format when + evolution state is present or absent. + + Strategy: patch helpers.main.db_get to control what the evolution endpoint + returns, then call get_card_embeds() and inspect the resulting embed title. + """ + + @pytest.mark.asyncio + @pytest.mark.asyncio + async def test_no_evolution_state_shows_plain_name(self): + """ + When the evolution API returns None (404 or down), the embed title + must equal the player name with no badge prefix. + """ + from helpers.main import get_card_embeds + + card = make_card(p_name="Mike Trout") + with patch( + "helpers.main.db_get", new=AsyncMock(side_effect=_db_get_side_effect(None)) + ): + embeds = await get_card_embeds(card) + + assert len(embeds) > 0 + assert embeds[0].title == "Mike Trout" + + @pytest.mark.asyncio + async def test_tier_0_shows_plain_name(self): + """ + Tier 0 in the evolution state means no progress yet — no badge shown. + """ + from helpers.main import get_card_embeds + + card = make_card(p_name="Mike Trout") + with patch( + "helpers.main.db_get", + new=AsyncMock(side_effect=_db_get_side_effect(make_evo_state(0))), + ): + embeds = await get_card_embeds(card) + + assert embeds[0].title == "Mike Trout" + + @pytest.mark.asyncio + async def test_tier_1_badge_in_title(self): + """Tier 1 card shows [T1] prefix in the embed title.""" + from helpers.main import get_card_embeds + + card = make_card(p_name="Mike Trout") + with patch( + "helpers.main.db_get", + new=AsyncMock(side_effect=_db_get_side_effect(make_evo_state(1))), + ): + embeds = await get_card_embeds(card) + + assert embeds[0].title == "[T1] Mike Trout" + + @pytest.mark.asyncio + async def test_tier_2_badge_in_title(self): + """Tier 2 card shows [T2] prefix in the embed title.""" + from helpers.main import get_card_embeds + + card = make_card(p_name="Mike Trout") + with patch( + "helpers.main.db_get", + new=AsyncMock(side_effect=_db_get_side_effect(make_evo_state(2))), + ): + embeds = await get_card_embeds(card) + + assert embeds[0].title == "[T2] Mike Trout" + + @pytest.mark.asyncio + async def test_tier_3_badge_in_title(self): + """Tier 3 card shows [T3] prefix in the embed title.""" + from helpers.main import get_card_embeds + + card = make_card(p_name="Mike Trout") + with patch( + "helpers.main.db_get", + new=AsyncMock(side_effect=_db_get_side_effect(make_evo_state(3))), + ): + embeds = await get_card_embeds(card) + + assert embeds[0].title == "[T3] Mike Trout" + + @pytest.mark.asyncio + async def test_tier_4_shows_evo_badge(self): + """Fully evolved card (tier 4) shows [EVO] prefix instead of [T4].""" + from helpers.main import get_card_embeds + + card = make_card(p_name="Mike Trout") + with patch( + "helpers.main.db_get", + new=AsyncMock(side_effect=_db_get_side_effect(make_evo_state(4))), + ): + embeds = await get_card_embeds(card) + + assert embeds[0].title == "[EVO] Mike Trout" + + @pytest.mark.asyncio + async def test_embed_color_unchanged_by_badge(self): + """ + The tier badge must not affect the embed color — rarity color is the + only driver of embed color, even for evolved cards. + + Why: embed color communicates card rarity to players. Silently breaking + it via evolution would confuse users. + """ + from helpers.main import get_card_embeds + + rarity_color = "FFD700" + card = make_card(p_name="Mike Trout", rarity_color=rarity_color) + with patch( + "helpers.main.db_get", + new=AsyncMock(side_effect=_db_get_side_effect(make_evo_state(3))), + ): + embeds = await get_card_embeds(card) + + expected_color = int(rarity_color, 16) + assert embeds[0].colour.value == expected_color + + @pytest.mark.asyncio + async def test_evolution_api_exception_shows_plain_name(self): + """ + When the evolution API raises an unexpected exception (network error, + server crash, etc.), the embed must still render with the plain player + name — no badge, no crash. + + This is the critical non-blocking contract for the feature. + """ + from helpers.main import get_card_embeds + + async def exploding_side_effect(endpoint, **kwargs): + if "evolution" in endpoint: + raise RuntimeError("simulated network failure") + if "paperdex" in endpoint: + return EMPTY_PAPERDEX + return None + + card = make_card(p_name="Mike Trout") + with patch( + "helpers.main.db_get", new=AsyncMock(side_effect=exploding_side_effect) + ): + embeds = await get_card_embeds(card) + + assert embeds[0].title == "Mike Trout" + + @pytest.mark.asyncio + async def test_evolution_api_missing_current_tier_key(self): + """ + If the evolution response is present but lacks 'current_tier', the + embed must gracefully degrade to no badge (defensive against API drift). + """ + from helpers.main import get_card_embeds + + card = make_card(p_name="Mike Trout") + # Response exists but is missing the expected key + with patch( + "helpers.main.db_get", + new=AsyncMock(side_effect=_db_get_side_effect({"xp": 50})), + ): + embeds = await get_card_embeds(card) + + assert embeds[0].title == "Mike Trout" diff --git a/tests/test_complete_game_hook.py b/tests/test_complete_game_hook.py new file mode 100644 index 0000000..6b6f07f --- /dev/null +++ b/tests/test_complete_game_hook.py @@ -0,0 +1,201 @@ +""" +Tests for the WP-13 post-game callback integration hook. + +These tests verify that after a game is saved to the API, two additional +POST requests are fired in the correct order: + 1. POST season-stats/update-game/{game_id} — update player_season_stats + 2. POST evolution/evaluate-game/{game_id} — evaluate evolution milestones + +Key design constraints being tested: + - Season stats MUST be updated before evolution is evaluated (ordering). + - Failure of either evolution call must NOT propagate — the game result has + already been committed; evolution will self-heal on the next evaluate pass. + - Tier-up dicts returned by the evolution endpoint are passed to + notify_tier_completion so WP-14 can present them to the player. +""" + +import asyncio +import logging +import pytest +from unittest.mock import AsyncMock, MagicMock, call, patch + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_channel(channel_id: int = 999) -> MagicMock: + ch = MagicMock() + ch.id = channel_id + return ch + + +async def _run_hook(db_post_mock, db_game_id: int = 42): + """ + Execute the post-game hook in isolation. + + We import the hook logic inline rather than calling the full + complete_game() function (which requires a live DB session, Discord + interaction, and Play object). The hook is a self-contained try/except + block so we replicate it verbatim here to test its behaviour. + """ + channel = _make_channel() + from command_logic.logic_gameplay import notify_tier_completion + + db_game = {"id": db_game_id} + + try: + await db_post_mock(f"season-stats/update-game/{db_game['id']}") + evo_result = await db_post_mock(f"evolution/evaluate-game/{db_game['id']}") + if evo_result and evo_result.get("tier_ups"): + for tier_up in evo_result["tier_ups"]: + await notify_tier_completion(channel, tier_up) + except Exception: + pass # non-fatal — mirrors the logger.warning in production + + return channel + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_hook_posts_to_both_endpoints_in_order(): + """ + Both evolution endpoints are called, and season-stats comes first. + + The ordering is critical: player_season_stats must be populated before the + evolution engine tries to read them for milestone evaluation. + """ + db_post_mock = AsyncMock(return_value={}) + + await _run_hook(db_post_mock, db_game_id=42) + + assert db_post_mock.call_count == 2 + calls = db_post_mock.call_args_list + # First call must be season-stats + assert calls[0] == call("season-stats/update-game/42") + # Second call must be evolution evaluate + assert calls[1] == call("evolution/evaluate-game/42") + + +@pytest.mark.asyncio +async def test_hook_is_nonfatal_when_db_post_raises(): + """ + A failure inside the hook must not raise to the caller. + + The game result is already persisted when the hook runs. If the evolution + API is down or returns an error, we log a warning and continue — the game + completion flow must not be interrupted. + """ + db_post_mock = AsyncMock(side_effect=Exception("evolution API unavailable")) + + # Should not raise + try: + await _run_hook(db_post_mock, db_game_id=7) + except Exception as exc: + pytest.fail(f"Hook raised unexpectedly: {exc}") + + +@pytest.mark.asyncio +async def test_hook_processes_tier_ups_from_evo_result(): + """ + When the evolution endpoint returns tier_ups, each entry is forwarded to + notify_tier_completion. + + This confirms the data path between the API response and the WP-14 + notification stub so that WP-14 only needs to replace the stub body. + """ + tier_ups = [ + {"player_id": 101, "old_tier": 1, "new_tier": 2}, + {"player_id": 202, "old_tier": 2, "new_tier": 3}, + ] + + async def fake_db_post(endpoint): + if "evolution" in endpoint: + return {"tier_ups": tier_ups} + return {} + + db_post_mock = AsyncMock(side_effect=fake_db_post) + + with patch( + "command_logic.logic_gameplay.notify_tier_completion", + new_callable=AsyncMock, + ) as mock_notify: + channel = _make_channel() + db_game = {"id": 99} + + try: + await db_post_mock(f"season-stats/update-game/{db_game['id']}") + evo_result = await db_post_mock(f"evolution/evaluate-game/{db_game['id']}") + if evo_result and evo_result.get("tier_ups"): + for tier_up in evo_result["tier_ups"]: + await mock_notify(channel, tier_up) + except Exception: + pass + + assert mock_notify.call_count == 2 + # Verify both tier_up dicts were forwarded + forwarded = [c.args[1] for c in mock_notify.call_args_list] + assert {"player_id": 101, "old_tier": 1, "new_tier": 2} in forwarded + assert {"player_id": 202, "old_tier": 2, "new_tier": 3} in forwarded + + +@pytest.mark.asyncio +async def test_hook_no_tier_ups_does_not_call_notify(): + """ + When the evolution response has no tier_ups (empty list or missing key), + notify_tier_completion is never called. + + Avoids spurious Discord messages for routine game completions. + """ + + async def fake_db_post(endpoint): + if "evolution" in endpoint: + return {"tier_ups": []} + return {} + + db_post_mock = AsyncMock(side_effect=fake_db_post) + + with patch( + "command_logic.logic_gameplay.notify_tier_completion", + new_callable=AsyncMock, + ) as mock_notify: + channel = _make_channel() + db_game = {"id": 55} + + try: + await db_post_mock(f"season-stats/update-game/{db_game['id']}") + evo_result = await db_post_mock(f"evolution/evaluate-game/{db_game['id']}") + if evo_result and evo_result.get("tier_ups"): + for tier_up in evo_result["tier_ups"]: + await mock_notify(channel, tier_up) + except Exception: + pass + + mock_notify.assert_not_called() + + +@pytest.mark.asyncio +async def test_notify_tier_completion_stub_logs_and_does_not_raise(caplog): + """ + The WP-14 stub must log the event and return cleanly. + + Verifies the contract that WP-14 can rely on: the function accepts + (channel, tier_up) and does not raise, so the hook's for-loop is safe. + """ + from command_logic.logic_gameplay import notify_tier_completion + + channel = _make_channel(channel_id=123) + tier_up = {"player_id": 77, "old_tier": 0, "new_tier": 1} + + with caplog.at_level(logging.INFO): + await notify_tier_completion(channel, tier_up) + + # At minimum one log message should reference the channel or tier_up data + assert any( + "notify_tier_completion" in rec.message or "77" in rec.message + for rec in caplog.records + ) diff --git a/tests/test_evolution_commands.py b/tests/test_evolution_commands.py new file mode 100644 index 0000000..eb65458 --- /dev/null +++ b/tests/test_evolution_commands.py @@ -0,0 +1,173 @@ +"""Tests for the evolution status command helpers (WP-11). + +Unit tests for progress bar rendering, entry formatting, tier display +names, close-to-tierup filtering, and edge cases. No Discord bot or +API calls required — these test pure functions only. +""" + +import pytest +from cogs.players_new.evolution import ( + render_progress_bar, + format_evo_entry, + is_close_to_tierup, + TIER_NAMES, + FORMULA_SHORTHANDS, +) + +# --------------------------------------------------------------------------- +# render_progress_bar +# --------------------------------------------------------------------------- + + +class TestRenderProgressBar: + def test_80_percent_filled(self): + """120/149 should be ~80% filled (8 of 10 chars).""" + result = render_progress_bar(120, 149, width=10) + assert "[========--]" in result + assert "120/149" in result + + def test_zero_progress(self): + """0/37 should be empty bar.""" + result = render_progress_bar(0, 37, width=10) + assert "[----------]" in result + assert "0/37" in result + + def test_full_progress_not_evolved(self): + """Value at threshold shows full bar.""" + result = render_progress_bar(149, 149, width=10) + assert "[==========]" in result + assert "149/149" in result + + def test_fully_evolved(self): + """next_threshold=None means fully evolved.""" + result = render_progress_bar(900, None, width=10) + assert "FULLY EVOLVED" in result + assert "[==========]" in result + + def test_over_threshold_capped(self): + """Value exceeding threshold still caps at 100%.""" + result = render_progress_bar(200, 149, width=10) + assert "[==========]" in result + + +# --------------------------------------------------------------------------- +# format_evo_entry +# --------------------------------------------------------------------------- + + +class TestFormatEvoEntry: + def test_batter_t1_to_t2(self): + """Batter at T1 progressing toward T2.""" + state = { + "current_tier": 1, + "current_value": 120.0, + "next_threshold": 149, + "fully_evolved": False, + "track": {"card_type": "batter"}, + } + result = format_evo_entry(state) + assert "(PA+TB×2)" in result + assert "Initiate → Rising" in result + + def test_pitcher_sp(self): + """SP track shows IP+K formula.""" + state = { + "current_tier": 0, + "current_value": 5.0, + "next_threshold": 10, + "fully_evolved": False, + "track": {"card_type": "sp"}, + } + result = format_evo_entry(state) + assert "(IP+K)" in result + assert "Unranked → Initiate" in result + + def test_fully_evolved_entry(self): + """Fully evolved card shows T4 — Evolved.""" + state = { + "current_tier": 4, + "current_value": 900.0, + "next_threshold": None, + "fully_evolved": True, + "track": {"card_type": "batter"}, + } + result = format_evo_entry(state) + assert "FULLY EVOLVED" in result + assert "Evolved" in result + + +# --------------------------------------------------------------------------- +# is_close_to_tierup +# --------------------------------------------------------------------------- + + +class TestIsCloseToTierup: + def test_at_80_percent(self): + """Exactly 80% of threshold counts as close.""" + state = {"current_value": 119.2, "next_threshold": 149} + assert is_close_to_tierup(state, threshold_pct=0.80) + + def test_below_80_percent(self): + """Below 80% is not close.""" + state = {"current_value": 100, "next_threshold": 149} + assert not is_close_to_tierup(state, threshold_pct=0.80) + + def test_fully_evolved_not_close(self): + """Fully evolved (no next threshold) is not close.""" + state = {"current_value": 900, "next_threshold": None} + assert not is_close_to_tierup(state) + + def test_zero_threshold(self): + """Zero threshold edge case returns False.""" + state = {"current_value": 0, "next_threshold": 0} + assert not is_close_to_tierup(state) + + +# --------------------------------------------------------------------------- +# Tier names and formula shorthands +# --------------------------------------------------------------------------- + + +class TestConstants: + def test_all_tier_names_present(self): + """All 5 tiers (0-4) have display names.""" + assert len(TIER_NAMES) == 5 + for i in range(5): + assert i in TIER_NAMES + + def test_tier_name_values(self): + assert TIER_NAMES[0] == "Unranked" + assert TIER_NAMES[1] == "Initiate" + assert TIER_NAMES[2] == "Rising" + assert TIER_NAMES[3] == "Ascendant" + assert TIER_NAMES[4] == "Evolved" + + def test_formula_shorthands(self): + assert FORMULA_SHORTHANDS["batter"] == "PA+TB×2" + assert FORMULA_SHORTHANDS["sp"] == "IP+K" + assert FORMULA_SHORTHANDS["rp"] == "IP+K" + + +# --------------------------------------------------------------------------- +# Empty / edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_missing_track_defaults(self): + """State with missing track info still formats without error.""" + state = { + "current_tier": 0, + "current_value": 0, + "next_threshold": 37, + "fully_evolved": False, + "track": {}, + } + result = format_evo_entry(state) + assert isinstance(result, str) + + def test_state_with_no_keys(self): + """Completely empty state dict doesn't crash.""" + state = {} + result = format_evo_entry(state) + assert isinstance(result, str) diff --git a/tests/test_evolution_notifications.py b/tests/test_evolution_notifications.py new file mode 100644 index 0000000..1f1256c --- /dev/null +++ b/tests/test_evolution_notifications.py @@ -0,0 +1,259 @@ +""" +Tests for Evolution Tier Completion Notification embeds. + +These tests verify that: +1. Tier-up embeds are correctly formatted for tiers 1-3 (title, description, color). +2. Tier 4 (Fully Evolved) embeds include the special title, description, and note field. +3. Multiple tier-up events each produce a separate embed. +4. An empty tier-up list results in no channel sends. + +The channel interaction is mocked because we are testing the embed content, not Discord +network I/O. Notification failure must never affect game flow, so the non-fatal path +is also exercised. +""" + +import pytest +from unittest.mock import AsyncMock + +import discord + +from helpers.evolution_notifs import build_tier_up_embed, notify_tier_completion + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def make_tier_up( + player_name="Mike Trout", + old_tier=1, + new_tier=2, + track_name="Batter", + current_value=150, +): + """Return a minimal tier_up dict matching the expected shape.""" + return { + "player_name": player_name, + "old_tier": old_tier, + "new_tier": new_tier, + "track_name": track_name, + "current_value": current_value, + } + + +# --------------------------------------------------------------------------- +# Unit: build_tier_up_embed — tiers 1-3 (standard tier-up) +# --------------------------------------------------------------------------- + + +class TestBuildTierUpEmbed: + """Verify that build_tier_up_embed produces correctly structured embeds.""" + + def test_title_is_evolution_tier_up(self): + """Title must read 'Evolution Tier Up!' for any non-max tier.""" + tier_up = make_tier_up(new_tier=2) + embed = build_tier_up_embed(tier_up) + assert embed.title == "Evolution Tier Up!" + + def test_description_contains_player_name(self): + """Description must contain the player's name.""" + tier_up = make_tier_up(player_name="Mike Trout", new_tier=2) + embed = build_tier_up_embed(tier_up) + assert "Mike Trout" in embed.description + + def test_description_contains_new_tier_name(self): + """Description must include the human-readable tier name for the new tier.""" + tier_up = make_tier_up(new_tier=2) + embed = build_tier_up_embed(tier_up) + # Tier 2 display name is "Rising" + assert "Rising" in embed.description + + def test_description_contains_track_name(self): + """Description must mention the evolution track (e.g., 'Batter').""" + tier_up = make_tier_up(track_name="Batter", new_tier=2) + embed = build_tier_up_embed(tier_up) + assert "Batter" in embed.description + + def test_tier1_color_is_green(self): + """Tier 1 uses green (0x2ecc71).""" + tier_up = make_tier_up(old_tier=0, new_tier=1) + embed = build_tier_up_embed(tier_up) + assert embed.color.value == 0x2ECC71 + + def test_tier2_color_is_gold(self): + """Tier 2 uses gold (0xf1c40f).""" + tier_up = make_tier_up(old_tier=1, new_tier=2) + embed = build_tier_up_embed(tier_up) + assert embed.color.value == 0xF1C40F + + def test_tier3_color_is_purple(self): + """Tier 3 uses purple (0x9b59b6).""" + tier_up = make_tier_up(old_tier=2, new_tier=3) + embed = build_tier_up_embed(tier_up) + assert embed.color.value == 0x9B59B6 + + def test_footer_text_is_paper_dynasty_evolution(self): + """Footer text must be 'Paper Dynasty Evolution' for brand consistency.""" + tier_up = make_tier_up(new_tier=2) + embed = build_tier_up_embed(tier_up) + assert embed.footer.text == "Paper Dynasty Evolution" + + def test_returns_discord_embed_instance(self): + """Return type must be discord.Embed so it can be sent directly.""" + tier_up = make_tier_up(new_tier=2) + embed = build_tier_up_embed(tier_up) + assert isinstance(embed, discord.Embed) + + +# --------------------------------------------------------------------------- +# Unit: build_tier_up_embed — tier 4 (fully evolved) +# --------------------------------------------------------------------------- + + +class TestBuildTierUpEmbedFullyEvolved: + """Verify that tier 4 (Fully Evolved) embeds use special formatting.""" + + def test_title_is_fully_evolved(self): + """Tier 4 title must be 'FULLY EVOLVED!' to emphasise max achievement.""" + tier_up = make_tier_up(old_tier=3, new_tier=4) + embed = build_tier_up_embed(tier_up) + assert embed.title == "FULLY EVOLVED!" + + def test_description_mentions_maximum_evolution(self): + """Tier 4 description must mention 'maximum evolution' per the spec.""" + tier_up = make_tier_up(player_name="Mike Trout", old_tier=3, new_tier=4) + embed = build_tier_up_embed(tier_up) + assert "maximum evolution" in embed.description.lower() + + def test_description_contains_player_name(self): + """Player name must appear in the tier 4 description.""" + tier_up = make_tier_up(player_name="Mike Trout", old_tier=3, new_tier=4) + embed = build_tier_up_embed(tier_up) + assert "Mike Trout" in embed.description + + def test_description_contains_track_name(self): + """Track name must appear in the tier 4 description.""" + tier_up = make_tier_up(track_name="Batter", old_tier=3, new_tier=4) + embed = build_tier_up_embed(tier_up) + assert "Batter" in embed.description + + def test_tier4_color_is_teal(self): + """Tier 4 uses teal (0x1abc9c) to visually distinguish max evolution.""" + tier_up = make_tier_up(old_tier=3, new_tier=4) + embed = build_tier_up_embed(tier_up) + assert embed.color.value == 0x1ABC9C + + def test_note_field_present(self): + """Tier 4 must include a note field about future rating boosts.""" + tier_up = make_tier_up(old_tier=3, new_tier=4) + embed = build_tier_up_embed(tier_up) + field_names = [f.name for f in embed.fields] + assert any( + "rating" in name.lower() + or "boost" in name.lower() + or "note" in name.lower() + for name in field_names + ), "Expected a field mentioning rating boosts for tier 4 embed" + + def test_note_field_value_mentions_future_update(self): + """The note field value must reference the future rating boost update.""" + tier_up = make_tier_up(old_tier=3, new_tier=4) + embed = build_tier_up_embed(tier_up) + note_field = next( + ( + f + for f in embed.fields + if "rating" in f.name.lower() + or "boost" in f.name.lower() + or "note" in f.name.lower() + ), + None, + ) + assert note_field is not None + assert ( + "future" in note_field.value.lower() or "update" in note_field.value.lower() + ) + + def test_footer_text_is_paper_dynasty_evolution(self): + """Footer must remain 'Paper Dynasty Evolution' for tier 4 as well.""" + tier_up = make_tier_up(old_tier=3, new_tier=4) + embed = build_tier_up_embed(tier_up) + assert embed.footer.text == "Paper Dynasty Evolution" + + +# --------------------------------------------------------------------------- +# Unit: notify_tier_completion — multiple and empty cases +# --------------------------------------------------------------------------- + + +class TestNotifyTierCompletion: + """Verify that notify_tier_completion sends the right number of messages.""" + + @pytest.mark.asyncio + async def test_single_tier_up_sends_one_message(self): + """A single tier-up event sends exactly one embed to the channel.""" + channel = AsyncMock() + tier_up = make_tier_up(new_tier=2) + await notify_tier_completion(channel, tier_up) + channel.send.assert_called_once() + + @pytest.mark.asyncio + async def test_sends_embed_not_plain_text(self): + """The channel.send call must use the embed= keyword, not content=.""" + channel = AsyncMock() + tier_up = make_tier_up(new_tier=2) + await notify_tier_completion(channel, tier_up) + _, kwargs = channel.send.call_args + assert ( + "embed" in kwargs + ), "notify_tier_completion must send an embed, not plain text" + + @pytest.mark.asyncio + async def test_embed_type_is_discord_embed(self): + """The embed passed to channel.send must be a discord.Embed instance.""" + channel = AsyncMock() + tier_up = make_tier_up(new_tier=2) + await notify_tier_completion(channel, tier_up) + _, kwargs = channel.send.call_args + assert isinstance(kwargs["embed"], discord.Embed) + + @pytest.mark.asyncio + async def test_notification_failure_does_not_raise(self): + """If channel.send raises, notify_tier_completion must swallow it so game flow is unaffected.""" + channel = AsyncMock() + channel.send.side_effect = Exception("Discord API unavailable") + tier_up = make_tier_up(new_tier=2) + # Should not raise + await notify_tier_completion(channel, tier_up) + + @pytest.mark.asyncio + async def test_multiple_tier_ups_caller_sends_multiple_embeds(self): + """ + Callers are responsible for iterating tier-up events; each call to + notify_tier_completion sends a separate embed. This test simulates + three consecutive calls (3 events) and asserts 3 sends occurred. + """ + channel = AsyncMock() + events = [ + make_tier_up(player_name="Mike Trout", new_tier=2), + make_tier_up(player_name="Aaron Judge", new_tier=1), + make_tier_up(player_name="Shohei Ohtani", new_tier=3), + ] + for event in events: + await notify_tier_completion(channel, event) + assert ( + channel.send.call_count == 3 + ), "Each tier-up event must produce its own embed (no batching)" + + @pytest.mark.asyncio + async def test_no_tier_ups_means_no_sends(self): + """ + When the caller has an empty list of tier-up events and simply + does not call notify_tier_completion, zero sends happen. + This explicitly guards against any accidental unconditional send. + """ + channel = AsyncMock() + tier_up_events = [] + for event in tier_up_events: + await notify_tier_completion(channel, event) + channel.send.assert_not_called() diff --git a/tests/test_roll_for_cards.py b/tests/test_roll_for_cards.py new file mode 100644 index 0000000..38aff2c --- /dev/null +++ b/tests/test_roll_for_cards.py @@ -0,0 +1,530 @@ +"""Tests for roll_for_cards parallelized implementation. + +Validates dice rolling, batched player fetches, card creation, +pack marking, MVP backfill, cardset-23 dupe fallback, and notifications. +""" + +import os +from unittest.mock import AsyncMock, patch + +import pytest + + +def _make_player( + player_id, rarity_value=0, rarity_name="Replacement", p_name="Test Player" +): + """Factory for player dicts matching API shape.""" + return { + "player_id": player_id, + "rarity": {"value": rarity_value, "name": rarity_name}, + "p_name": p_name, + "description": f"2024 {p_name}", + } + + +_UNSET = object() + + +def _make_pack( + pack_id, team_id=1, pack_type="Standard", pack_team=None, pack_cardset=_UNSET +): + """Factory for pack dicts matching API shape.""" + return { + "id": pack_id, + "team": {"id": team_id, "abbrev": "TST"}, + "pack_type": {"name": pack_type}, + "pack_team": pack_team, + "pack_cardset": {"id": 10} if pack_cardset is _UNSET else pack_cardset, + } + + +def _random_response(players): + """Wrap a list of player dicts in the API response shape.""" + return {"count": len(players), "players": players} + + +@pytest.fixture +def mock_db(): + """Patch db_get, db_post, db_patch in helpers.main.""" + with ( + patch("helpers.main.db_get", new_callable=AsyncMock) as mock_get, + patch("helpers.main.db_post", new_callable=AsyncMock) as mock_post, + patch("helpers.main.db_patch", new_callable=AsyncMock) as mock_patch, + ): + mock_post.return_value = True + yield mock_get, mock_post, mock_patch + + +class TestSinglePack: + """Single pack opening — verifies basic flow.""" + + async def test_single_standard_pack_creates_cards_and_marks_opened(self, mock_db): + """A single standard pack should fetch players, create cards, and patch open_time. + + Why: Validates the core happy path — dice roll → fetch → create → mark opened. + """ + mock_get, mock_post, mock_patch = mock_db + + # Return enough players for any rarity tier requested + mock_get.return_value = _random_response([_make_player(i) for i in range(10)]) + + pack = _make_pack(100) + from helpers.main import roll_for_cards + + result = await roll_for_cards([pack]) + + assert result == [100] + # At least one db_get for player fetches + assert mock_get.call_count >= 1 + # Exactly one db_post for cards (may have notif posts too) + card_posts = [c for c in mock_post.call_args_list if c.args[0] == "cards"] + assert len(card_posts) == 1 + # Exactly one db_patch for marking pack opened + assert mock_patch.call_count == 1 + assert mock_patch.call_args.kwargs["object_id"] == 100 + + async def test_checkin_pack_uses_extra_val(self, mock_db): + """Check-In Player packs should apply extra_val modifier to dice range. + + Why: extra_val shifts the d1000 ceiling, affecting rarity odds for check-in rewards. + """ + mock_get, mock_post, mock_patch = mock_db + mock_get.return_value = _random_response([_make_player(1)]) + + pack = _make_pack(200, pack_type="Check-In Player") + from helpers.main import roll_for_cards + + result = await roll_for_cards([pack], extra_val=500) + + assert result == [200] + assert mock_get.call_count >= 1 + + async def test_unknown_pack_type_raises(self, mock_db): + """Unrecognized pack types must raise TypeError. + + Why: Guards against silent failures if a new pack type is added without dice logic. + """ + mock_get, mock_post, mock_patch = mock_db + + pack = _make_pack(300, pack_type="Unknown") + from helpers.main import roll_for_cards + + with pytest.raises(TypeError, match="Pack type not recognized"): + await roll_for_cards([pack]) + + +class TestMultiplePacks: + """Multiple packs — verifies batching and distribution.""" + + async def test_multiple_packs_return_all_ids(self, mock_db): + """Opening multiple packs should return all pack IDs. + + Why: Callers use the returned IDs to know which packs were successfully opened. + """ + mock_get, mock_post, mock_patch = mock_db + mock_get.return_value = _random_response([_make_player(i) for i in range(50)]) + + packs = [_make_pack(i) for i in range(5)] + from helpers.main import roll_for_cards + + result = await roll_for_cards(packs) + + assert result == [0, 1, 2, 3, 4] + + async def test_multiple_packs_batch_fetches(self, mock_db): + """Multiple packs should batch fetches — one db_get per rarity tier, not per pack. + + Why: This is the core performance optimization. 5 packs should NOT make 20-30 calls. + """ + mock_get, mock_post, mock_patch = mock_db + mock_get.return_value = _random_response([_make_player(i) for i in range(50)]) + + packs = [_make_pack(i) for i in range(5)] + from helpers.main import roll_for_cards + + await roll_for_cards(packs) + + # Standard packs have up to 6 rarity tiers, but typically fewer are non-zero. + # The key assertion: far fewer fetches than 5 packs * ~4 tiers = 20. + player_fetches = [ + c for c in mock_get.call_args_list if c.args[0] == "players/random" + ] + # At most 6 tier fetches + possible 1 MVP backfill = 7 + assert len(player_fetches) <= 7 + + async def test_multiple_packs_create_cards_per_pack(self, mock_db): + """Each pack should get its own db_post('cards') call with correct pack_id. + + Why: Cards must be associated with the correct pack for display and tracking. + """ + mock_get, mock_post, mock_patch = mock_db + mock_get.return_value = _random_response([_make_player(i) for i in range(50)]) + + packs = [_make_pack(i) for i in range(3)] + from helpers.main import roll_for_cards + + await roll_for_cards(packs) + + card_posts = [c for c in mock_post.call_args_list if c.args[0] == "cards"] + assert len(card_posts) == 3 + # Each card post should reference the correct pack_id + for i, post_call in enumerate(card_posts): + payload = post_call.kwargs["payload"] + pack_ids_in_cards = {card["pack_id"] for card in payload["cards"]} + assert pack_ids_in_cards == {i} + + +class TestMVPBackfill: + """MVP fallback when a rarity tier returns fewer players than requested.""" + + async def test_shortfall_triggers_mvp_backfill(self, mock_db): + """When a tier returns fewer players than needed, MVP backfill should fire. + + Why: Packs must always contain the expected number of cards. Shortfalls are + filled with MVP-tier players as a fallback. + """ + mock_get, mock_post, mock_patch = mock_db + + call_count = 0 + + async def side_effect(endpoint, params=None): + nonlocal call_count + call_count += 1 + if params and any( + p[0] == "min_rarity" + and p[1] == 5 + and any(q[0] == "max_rarity" for q in params) is False + for p in params + ): + # MVP backfill call (no max_rarity) + return _random_response([_make_player(900, 5, "MVP")]) + + # For tier-specific calls, check if this is the MVP backfill + if params: + param_dict = dict(params) + if "max_rarity" not in param_dict: + return _random_response([_make_player(900, 5, "MVP")]) + + # Return fewer than requested to trigger shortfall + requested = 5 + if params: + for key, val in params: + if key == "limit": + requested = val + break + return _random_response( + [_make_player(i) for i in range(max(0, requested - 1))] + ) + + mock_get.side_effect = side_effect + + pack = _make_pack(100) + from helpers.main import roll_for_cards + + result = await roll_for_cards([pack]) + + assert result == [100] + # Should have at least the tier fetch + backfill call + assert mock_get.call_count >= 2 + + +class TestCardsetExclusion: + """Cardset 23 should duplicate existing players instead of MVP backfill.""" + + async def test_cardset_23_duplicates_instead_of_mvp(self, mock_db): + """For cardset 23, shortfalls should duplicate existing players, not fetch MVPs. + + Why: Cardset 23 (special/limited cardset) shouldn't pull from the MVP pool — + it should fill gaps by duplicating from what's already available. + """ + mock_get, mock_post, mock_patch = mock_db + + async def side_effect(endpoint, params=None): + if params: + param_dict = dict(params) + # If this is a backfill call (no max_rarity), it shouldn't happen + if "max_rarity" not in param_dict: + pytest.fail("Should not make MVP backfill call for cardset 23") + # Return fewer than requested + return _random_response([_make_player(1)]) + + mock_get.side_effect = side_effect + + pack = _make_pack(100, pack_cardset={"id": 23}) + from helpers.main import roll_for_cards + + # Force specific dice rolls to ensure a shortfall + with patch("helpers.main.random.randint", return_value=1): + # d1000=1 for Standard: Rep, Rep, Rep, Rep, Rep → 5 Reps needed + result = await roll_for_cards([pack]) + + assert result == [100] + + +class TestNotifications: + """Rare pull notifications should be gathered and sent.""" + + async def test_rare_pulls_generate_notifications(self, mock_db): + """Players with rarity >= 3 should trigger notification posts. + + Why: Rare pulls are announced to the community — all notifs should be sent. + """ + mock_get, mock_post, mock_patch = mock_db + + rare_player = _make_player( + 42, rarity_value=3, rarity_name="All-Star", p_name="Mike Trout" + ) + mock_get.return_value = _random_response([rare_player]) + + pack = _make_pack(100) + # Force all dice to land on All-Star tier (d1000=951 for card 3) + from helpers.main import roll_for_cards + + with patch("helpers.main.random.randint", return_value=960): + await roll_for_cards([pack]) + + notif_posts = [c for c in mock_post.call_args_list if c.args[0] == "notifs"] + assert len(notif_posts) >= 1 + payload = notif_posts[0].kwargs["payload"] + assert payload["title"] == "Rare Pull" + assert "Mike Trout" in payload["field_name"] + + async def test_no_notifications_for_common_pulls(self, mock_db): + """Players with rarity < 3 should NOT trigger notifications. + + Why: Only rare pulls are noteworthy — common cards would spam the notif feed. + """ + mock_get, mock_post, mock_patch = mock_db + + common_player = _make_player(1, rarity_value=0, rarity_name="Replacement") + mock_get.return_value = _random_response([common_player]) + + pack = _make_pack(100) + from helpers.main import roll_for_cards + + # Force low dice rolls (all Replacement) + with patch("helpers.main.random.randint", return_value=1): + await roll_for_cards([pack]) + + notif_posts = [c for c in mock_post.call_args_list if c.args[0] == "notifs"] + assert len(notif_posts) == 0 + + +class TestErrorHandling: + """Error propagation from gathered writes.""" + + async def test_card_creation_failure_raises(self, mock_db): + """If db_post('cards') returns falsy, ConnectionError must propagate. + + Why: Card creation failure means the pack wasn't properly opened — caller + needs to know so it can report the error to the user. + """ + mock_get, mock_post, mock_patch = mock_db + mock_get.return_value = _random_response([_make_player(1)]) + mock_post.return_value = False # Simulate failure + + pack = _make_pack(100) + from helpers.main import roll_for_cards + + with pytest.raises(ConnectionError, match="Failed to create"): + await roll_for_cards([pack]) + + +class TestPackTeamFiltering: + """Verify correct filter params are passed to player fetch.""" + + async def test_pack_team_adds_franchise_filter(self, mock_db): + """When pack has a pack_team, franchise filter should be applied. + + Why: Team-specific packs should only contain players from that franchise. + """ + mock_get, mock_post, mock_patch = mock_db + mock_get.return_value = _random_response([_make_player(1)]) + + pack = _make_pack( + 100, + pack_team={"sname": "NYY"}, + pack_cardset=None, + ) + from helpers.main import roll_for_cards + + with patch("helpers.main.random.randint", return_value=1): + await roll_for_cards([pack]) + + # Check that tier-fetch calls (those with max_rarity) include franchise filter + tier_calls = [ + c + for c in mock_get.call_args_list + if any(p[0] == "max_rarity" for p in (c.kwargs.get("params") or [])) + ] + assert len(tier_calls) >= 1 + for c in tier_calls: + param_dict = dict(c.kwargs.get("params") or []) + assert param_dict.get("franchise") == "NYY" + assert param_dict.get("in_packs") is True + + async def test_no_team_no_cardset_adds_in_packs(self, mock_db): + """When pack has no team or cardset, in_packs filter should be applied. + + Why: Generic packs still need the in_packs filter to exclude non-packable players. + """ + mock_get, mock_post, mock_patch = mock_db + mock_get.return_value = _random_response([_make_player(1)]) + + pack = _make_pack(100, pack_team=None, pack_cardset=None) + from helpers.main import roll_for_cards + + with patch("helpers.main.random.randint", return_value=1): + await roll_for_cards([pack]) + + # Check that tier-fetch calls (those with max_rarity) include in_packs filter + tier_calls = [ + c + for c in mock_get.call_args_list + if any(p[0] == "max_rarity" for p in (c.kwargs.get("params") or [])) + ] + assert len(tier_calls) >= 1 + for c in tier_calls: + param_dict = dict(c.kwargs.get("params") or []) + assert param_dict.get("in_packs") is True + + +# --------------------------------------------------------------------------- +# Integration tests — hit real dev API for reads, mock all writes +# --------------------------------------------------------------------------- +requires_api = pytest.mark.skipif( + not os.environ.get("API_TOKEN"), + reason="API_TOKEN not set — skipping integration tests", +) + + +@requires_api +class TestIntegrationRealFetches: + """Integration tests that hit the real dev API for player fetches. + + Only db_get is real — db_post and db_patch are mocked to prevent writes. + Run with: API_TOKEN= python -m pytest tests/test_roll_for_cards.py -k integration -v + """ + + @pytest.fixture + def mock_writes(self): + """Mock only write operations, let reads hit the real API.""" + with ( + patch("helpers.main.db_post", new_callable=AsyncMock) as mock_post, + patch("helpers.main.db_patch", new_callable=AsyncMock) as mock_patch, + ): + mock_post.return_value = True + yield mock_post, mock_patch + + async def test_integration_single_pack_fetches_real_players(self, mock_writes): + """A single standard pack should fetch real players from the dev API. + + Why: Validates that the batched fetch params (min_rarity, max_rarity, limit, + in_packs) produce valid responses from the real API and that the returned + players have the expected structure. + """ + mock_post, mock_patch = mock_writes + + pack = _make_pack(9999) + from helpers.main import roll_for_cards + + result = await roll_for_cards([pack]) + + assert result == [9999] + # Cards were "created" (mocked) + card_posts = [c for c in mock_post.call_args_list if c.args[0] == "cards"] + assert len(card_posts) == 1 + payload = card_posts[0].kwargs["payload"] + # Standard pack produces 5 cards + assert len(payload["cards"]) == 5 + # Each card has the expected structure + for card in payload["cards"]: + assert "player_id" in card + assert card["team_id"] == 1 + assert card["pack_id"] == 9999 + + async def test_integration_multiple_packs_batch_correctly(self, mock_writes): + """Multiple packs should batch fetches and distribute players correctly. + + Why: Validates the core optimization — summing counts across packs, making + fewer API calls, and slicing players back into per-pack groups with real data. + """ + mock_post, mock_patch = mock_writes + + packs = [_make_pack(i + 9000) for i in range(3)] + from helpers.main import roll_for_cards + + result = await roll_for_cards(packs) + + assert result == [9000, 9001, 9002] + card_posts = [c for c in mock_post.call_args_list if c.args[0] == "cards"] + assert len(card_posts) == 3 + # Each pack should have exactly 5 cards (Standard packs) + total_cards = 0 + for post_call in card_posts: + cards = post_call.kwargs["payload"]["cards"] + assert len(cards) == 5 + total_cards += len(cards) + assert total_cards == 15 + + async def test_integration_players_have_valid_rarity(self, mock_writes): + """Fetched players should have rarity values matching their requested tier. + + Why: Confirms the API respects min_rarity/max_rarity filters and that + the player distribution logic assigns correct-tier players to each pack. + """ + mock_post, mock_patch = mock_writes + + pack = _make_pack(9999) + from helpers.main import roll_for_cards + + # Use fixed dice to get known rarity distribution + # d1000=500 for Standard: Rep, Res, Sta, Res, Sta (mix of low tiers) + with patch("helpers.main.random.randint", return_value=500): + await roll_for_cards([pack]) + + card_posts = [c for c in mock_post.call_args_list if c.args[0] == "cards"] + assert len(card_posts) == 1 + cards = card_posts[0].kwargs["payload"]["cards"] + # All cards should have valid player_ids (positive integers from real API) + for card in cards: + assert isinstance(card["player_id"], int) + assert card["player_id"] > 0 + + async def test_integration_cardset_filter(self, mock_writes): + """Packs with a specific cardset should only fetch players from that cardset. + + Why: Validates that the cardset_id parameter is correctly passed through + the batched fetch and the API filters accordingly. + """ + mock_post, mock_patch = mock_writes + + pack = _make_pack(9999, pack_cardset={"id": 24}) + from helpers.main import roll_for_cards + + with patch("helpers.main.random.randint", return_value=500): + result = await roll_for_cards([pack]) + + assert result == [9999] + card_posts = [c for c in mock_post.call_args_list if c.args[0] == "cards"] + assert len(card_posts) == 1 + assert len(card_posts[0].kwargs["payload"]["cards"]) == 5 + + async def test_integration_checkin_pack(self, mock_writes): + """Check-In Player pack should fetch exactly 1 player from the real API. + + Why: Check-in packs produce a single card — validates the simplest + path through the batched fetch logic with real data. + """ + mock_post, mock_patch = mock_writes + + pack = _make_pack(9999, pack_type="Check-In Player") + from helpers.main import roll_for_cards + + result = await roll_for_cards([pack]) + + assert result == [9999] + card_posts = [c for c in mock_post.call_args_list if c.args[0] == "cards"] + assert len(card_posts) == 1 + # Check-in packs produce exactly 1 card + assert len(card_posts[0].kwargs["payload"]["cards"]) == 1