From 18ae65a8e22ecd2af4385c3d09dc70904f82203f Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Fri, 6 Mar 2026 14:04:05 -0600 Subject: [PATCH 1/7] fix: replace create_item_in_table placeholder with direct endpoint call (#30) Remove the generic placeholder method from BaseService and replace the single call site in CustomCommandsService.get_or_create_creator with a direct client.post("custom_commands/creators", ...) call, consistent with how _update_creator_stats and _update_creator_info already work. Co-Authored-By: Claude Sonnet 4.6 --- services/base_service.py | 332 +++++++++++++++------------- services/custom_commands_service.py | 5 +- 2 files changed, 175 insertions(+), 162 deletions(-) diff --git a/services/base_service.py b/services/base_service.py index e919e6b..faf0dba 100644 --- a/services/base_service.py +++ b/services/base_service.py @@ -3,6 +3,7 @@ Base service class for Discord Bot v2.0 Provides common CRUD operations and error handling for all data services. """ + import logging import hashlib from typing import Optional, Type, TypeVar, Generic, Dict, Any, List, Tuple @@ -12,15 +13,15 @@ from models.base import SBABaseModel from exceptions import APIException from utils.cache import CacheManager -logger = logging.getLogger(f'{__name__}.BaseService') +logger = logging.getLogger(f"{__name__}.BaseService") -T = TypeVar('T', bound=SBABaseModel) +T = TypeVar("T", bound=SBABaseModel) class BaseService(Generic[T]): """ Base service class providing common CRUD operations for SBA models. - + Features: - Generic type support for any SBABaseModel subclass - Automatic model validation and conversion @@ -28,15 +29,17 @@ class BaseService(Generic[T]): - API response format handling (count + list format) - Connection management via global client """ - - def __init__(self, - model_class: Type[T], - endpoint: str, - client: Optional[APIClient] = None, - cache_manager: Optional[CacheManager] = None): + + def __init__( + self, + model_class: Type[T], + endpoint: str, + client: Optional[APIClient] = None, + cache_manager: Optional[CacheManager] = None, + ): """ Initialize base service. - + Args: model_class: Pydantic model class for this service endpoint: API endpoint path (e.g., 'players', 'teams') @@ -48,40 +51,44 @@ class BaseService(Generic[T]): self._client = client self._cached_client: Optional[APIClient] = None self.cache = cache_manager or CacheManager() - - logger.debug(f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'") - - def _generate_cache_key(self, method: str, params: Optional[List[Tuple[str, Any]]] = None) -> str: + + logger.debug( + f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'" + ) + + def _generate_cache_key( + self, method: str, params: Optional[List[Tuple[str, Any]]] = None + ) -> str: """ Generate consistent cache key for API calls. - + Args: method: API method name params: Query parameters as list of tuples - + Returns: SHA256-hashed cache key """ key_parts = [self.endpoint, method] - + if params: # Sort parameters for consistent key generation sorted_params = sorted(params, key=lambda x: str(x[0])) param_str = "&".join([f"{k}={v}" for k, v in sorted_params]) key_parts.append(param_str) - + key_data = ":".join(key_parts) key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16] # First 16 chars - + return self.cache.cache_key("sba", f"{self.endpoint}_{key_hash}") - + async def _get_cached_items(self, cache_key: str) -> Optional[List[T]]: """ Get cached list of model items. - + Args: cache_key: Cache key to lookup - + Returns: List of model instances or None if not cached """ @@ -91,13 +98,15 @@ class BaseService(Generic[T]): return [self.model_class.from_api_data(item) for item in cached_data] except Exception as e: logger.warning(f"Error deserializing cached data for {cache_key}: {e}") - + return None - - async def _cache_items(self, cache_key: str, items: List[T], ttl: Optional[int] = None) -> None: + + async def _cache_items( + self, cache_key: str, items: List[T], ttl: Optional[int] = None + ) -> None: """ Cache list of model items. - + Args: cache_key: Cache key to store under items: List of model instances to cache @@ -105,40 +114,40 @@ class BaseService(Generic[T]): """ if not items: return - + try: # Convert to JSON-serializable format cache_data = [item.model_dump() for item in items] await self.cache.set(cache_key, cache_data, ttl) except Exception as e: logger.warning(f"Error caching items for {cache_key}: {e}") - + async def get_client(self) -> APIClient: """ Get API client instance with caching to reduce async overhead. - + Returns: APIClient instance (cached after first access) """ if self._client: return self._client - + # Cache the global client to avoid repeated async calls if self._cached_client is None: self._cached_client = await get_global_client() - + return self._cached_client - + async def get_by_id(self, object_id: int) -> Optional[T]: """ Get single object by ID. - + Args: object_id: Unique identifier for the object - + Returns: Model instance or None if not found - + Raises: APIException: For API errors ValueError: For invalid data @@ -146,80 +155,90 @@ class BaseService(Generic[T]): try: client = await self.get_client() data = await client.get(self.endpoint, object_id=object_id) - + if not data: logger.debug(f"{self.model_class.__name__} {object_id} not found") return None - + model = self.model_class.from_api_data(data) logger.debug(f"Retrieved {self.model_class.__name__} {object_id}: {model}") return model - + except APIException: - logger.error(f"API error retrieving {self.model_class.__name__} {object_id}") + logger.error( + f"API error retrieving {self.model_class.__name__} {object_id}" + ) raise except Exception as e: - logger.error(f"Error retrieving {self.model_class.__name__} {object_id}: {e}") + logger.error( + f"Error retrieving {self.model_class.__name__} {object_id}: {e}" + ) raise APIException(f"Failed to retrieve {self.model_class.__name__}: {e}") - - async def get_all(self, params: Optional[List[tuple]] = None) -> Tuple[List[T], int]: + + async def get_all( + self, params: Optional[List[tuple]] = None + ) -> Tuple[List[T], int]: """ Get all objects with optional query parameters. - + Args: params: Query parameters as list of (key, value) tuples - + Returns: Tuple of (list of model instances, total count) - + Raises: APIException: For API errors """ try: client = await self.get_client() data = await client.get(self.endpoint, params=params) - + if not data: logger.debug(f"No {self.model_class.__name__} objects found") return [], 0 - + # Handle API response format: {'count': int, '': [...]} items, count = self._extract_items_and_count_from_response(data) - + models = [self.model_class.from_api_data(item) for item in items] - logger.debug(f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects") + logger.debug( + f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects" + ) return models, count - + except APIException: logger.error(f"API error retrieving {self.model_class.__name__} list") raise except Exception as e: logger.error(f"Error retrieving {self.model_class.__name__} list: {e}") - raise APIException(f"Failed to retrieve {self.model_class.__name__} list: {e}") - + raise APIException( + f"Failed to retrieve {self.model_class.__name__} list: {e}" + ) + async def get_all_items(self, params: Optional[List[tuple]] = None) -> List[T]: """ Get all objects (convenience method that only returns the list). - + Args: params: Query parameters as list of (key, value) tuples - + Returns: List of model instances """ items, _ = await self.get_all(params=params) return items - + async def create(self, model_data: Dict[str, Any]) -> Optional[T]: """ Create new object from data dictionary. - + Args: model_data: Dictionary of model fields - + Returns: Created model instance or None - + Raises: APIException: For API errors ValueError: For invalid data @@ -227,86 +246,90 @@ class BaseService(Generic[T]): try: client = await self.get_client() response = await client.post(self.endpoint, model_data) - + if not response: logger.warning(f"No response from {self.model_class.__name__} creation") return None - + model = self.model_class.from_api_data(response) logger.debug(f"Created {self.model_class.__name__}: {model}") return model - + except APIException: logger.error(f"API error creating {self.model_class.__name__}") raise except Exception as e: logger.error(f"Error creating {self.model_class.__name__}: {e}") raise APIException(f"Failed to create {self.model_class.__name__}: {e}") - + async def create_from_model(self, model: T) -> Optional[T]: """ Create new object from model instance. - + Args: model: Model instance to create - + Returns: Created model instance or None """ return await self.create(model.to_dict(exclude_none=True)) - + async def update(self, object_id: int, model_data: Dict[str, Any]) -> Optional[T]: """ Update existing object. - + Args: object_id: ID of object to update model_data: Dictionary of fields to update - + Returns: Updated model instance or None if not found - + Raises: APIException: For API errors """ try: client = await self.get_client() response = await client.put(self.endpoint, model_data, object_id=object_id) - + if not response: - logger.debug(f"{self.model_class.__name__} {object_id} not found for update") + logger.debug( + f"{self.model_class.__name__} {object_id} not found for update" + ) return None - + model = self.model_class.from_api_data(response) logger.debug(f"Updated {self.model_class.__name__} {object_id}: {model}") return model - + except APIException: logger.error(f"API error updating {self.model_class.__name__} {object_id}") raise except Exception as e: logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}") raise APIException(f"Failed to update {self.model_class.__name__}: {e}") - + async def update_from_model(self, model: T) -> Optional[T]: """ Update object from model instance. - + Args: model: Model instance to update (must have ID) - + Returns: Updated model instance or None - + Raises: ValueError: If model has no ID """ if not model.id: raise ValueError(f"Cannot update {self.model_class.__name__} without ID") - + return await self.update(model.id, model.to_dict(exclude_none=True)) - - async def patch(self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False) -> Optional[T]: + + async def patch( + self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False + ) -> Optional[T]: """ Update existing object with HTTP PATCH. @@ -323,10 +346,14 @@ class BaseService(Generic[T]): """ try: client = await self.get_client() - response = await client.patch(self.endpoint, model_data, object_id, use_query_params=use_query_params) + response = await client.patch( + self.endpoint, model_data, object_id, use_query_params=use_query_params + ) if not response: - logger.debug(f"{self.model_class.__name__} {object_id} not found for update") + logger.debug( + f"{self.model_class.__name__} {object_id} not found for update" + ) return None model = self.model_class.from_api_data(response) @@ -340,134 +367,142 @@ class BaseService(Generic[T]): logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}") raise APIException(f"Failed to update {self.model_class.__name__}: {e}") - async def delete(self, object_id: int) -> bool: """ Delete object by ID. - + Args: object_id: ID of object to delete - + Returns: True if deleted, False if not found - + Raises: APIException: For API errors """ try: client = await self.get_client() success = await client.delete(self.endpoint, object_id=object_id) - + if success: logger.debug(f"Deleted {self.model_class.__name__} {object_id}") else: - logger.debug(f"{self.model_class.__name__} {object_id} not found for deletion") - + logger.debug( + f"{self.model_class.__name__} {object_id} not found for deletion" + ) + return success - + except APIException: logger.error(f"API error deleting {self.model_class.__name__} {object_id}") raise except Exception as e: logger.error(f"Error deleting {self.model_class.__name__} {object_id}: {e}") raise APIException(f"Failed to delete {self.model_class.__name__}: {e}") - - + async def get_by_field(self, field: str, value: Any) -> List[T]: """ Get objects by specific field value. - + Args: field: Field name to search value: Field value to match - + Returns: List of matching model instances """ params = [(field, str(value))] return await self.get_all_items(params=params) - + async def count(self, params: Optional[List[tuple]] = None) -> int: """ Get count of objects matching parameters. - + Args: params: Query parameters - + Returns: Number of matching objects (from API count field) """ _, count = await self.get_all(params=params) return count - - def _extract_items_and_count_from_response(self, data: Any) -> Tuple[List[Dict[str, Any]], int]: + + def _extract_items_and_count_from_response( + self, data: Any + ) -> Tuple[List[Dict[str, Any]], int]: """ Extract items list and count from API response with optimized parsing. - + Expected format: {'count': int, '': [...]} Single object format: {'id': 1, 'name': '...'} - + Args: data: API response data - + Returns: Tuple of (items list, total count) """ if isinstance(data, list): return data, len(data) - + if not isinstance(data, dict): - logger.warning(f"Unexpected response format for {self.model_class.__name__}: {type(data)}") + logger.warning( + f"Unexpected response format for {self.model_class.__name__}: {type(data)}" + ) return [], 0 - + # Single pass through the response dict - get count first - count = data.get('count', 0) - + count = data.get("count", 0) + # Priority order for finding items list (most common first) - field_candidates = [self.endpoint, 'items', 'data', 'results'] + field_candidates = [self.endpoint, "items", "data", "results"] for field_name in field_candidates: if field_name in data and isinstance(data[field_name], list): return data[field_name], count or len(data[field_name]) - + # Single object response (check for common identifying fields) - if any(key in data for key in ['id', 'name', 'abbrev']): + if any(key in data for key in ["id", "name", "abbrev"]): return [data], 1 - + return [], count - - async def get_items_with_params(self, params: Optional[List[tuple]] = None) -> List[T]: + + async def get_items_with_params( + self, params: Optional[List[tuple]] = None + ) -> List[T]: """ Get all items with parameters (alias for get_all_items for compatibility). - + Args: params: Query parameters as list of (key, value) tuples - + Returns: List of model instances """ return await self.get_all_items(params=params) - + async def create_item(self, model_data: Dict[str, Any]) -> Optional[T]: """ Create item (alias for create for compatibility). - + Args: model_data: Dictionary of model fields - + Returns: Created model instance or None """ return await self.create(model_data) - - async def update_item_by_field(self, field: str, value: Any, update_data: Dict[str, Any]) -> Optional[T]: + + async def update_item_by_field( + self, field: str, value: Any, update_data: Dict[str, Any] + ) -> Optional[T]: """ Update item by field value. - + Args: field: Field name to search by value: Field value to match update_data: Data to update - + Returns: Updated model instance or None if not found """ @@ -475,22 +510,22 @@ class BaseService(Generic[T]): items = await self.get_by_field(field, value) if not items: return None - + # Update the first matching item item = items[0] if not item.id: return None - + return await self.update(item.id, update_data) - + async def delete_item_by_field(self, field: str, value: Any) -> bool: """ Delete item by field value. - + Args: field: Field name to search by value: Field value to match - + Returns: True if deleted, False if not found """ @@ -498,62 +533,41 @@ class BaseService(Generic[T]): items = await self.get_by_field(field, value) if not items: return False - + # Delete the first matching item item = items[0] if not item.id: return False - + return await self.delete(item.id) - - async def create_item_in_table(self, table_name: str, item_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """ - Create item in a specific table (simplified for custom commands service). - This is a placeholder - real implementation would need table-specific endpoints. - - Args: - table_name: Name of the table - item_data: Data to create - - Returns: - Created item data or None - """ - # For now, use the main endpoint - this would need proper implementation - # for different tables like 'custom_command_creators' - try: - client = await self.get_client() - # Use table name as endpoint for now - response = await client.post(table_name, item_data) - return response - except Exception as e: - logger.error(f"Error creating item in table {table_name}: {e}") - return None - - async def get_items_from_table_with_params(self, table_name: str, params: List[tuple]) -> List[Dict[str, Any]]: + + async def get_items_from_table_with_params( + self, table_name: str, params: List[tuple] + ) -> List[Dict[str, Any]]: """ Get items from a specific table with parameters. - + Args: table_name: Name of the table params: Query parameters - + Returns: List of item dictionaries """ try: client = await self.get_client() data = await client.get(table_name, params=params) - + if not data: return [] - + # Handle response format items, _ = self._extract_items_and_count_from_response(data) return items - + except Exception as e: logger.error(f"Error getting items from table {table_name}: {e}") return [] def __repr__(self) -> str: - return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')" \ No newline at end of file + return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')" diff --git a/services/custom_commands_service.py b/services/custom_commands_service.py index 72838bb..03445bc 100644 --- a/services/custom_commands_service.py +++ b/services/custom_commands_service.py @@ -552,9 +552,8 @@ class CustomCommandsService(BaseService[CustomCommand]): "active_commands": 0, } - result = await self.create_item_in_table( - "custom_commands/creators", creator_data - ) + client = await self.get_client() + result = await client.post("custom_commands/creators", creator_data) if not result: raise BotException("Failed to create command creator") From e98a658fdea078a8eaa38a1d650452178c997639 Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Sun, 8 Mar 2026 11:21:51 -0500 Subject: [PATCH 2/7] fix: show actual validation errors in trade embed Quick Status Quick Status previously only showed "X errors found" with no details. Now lists each error and suggestion inline. Also stripped all emoji from embed titles, field names, values, buttons, and messages. Co-Authored-By: Claude Opus 4.6 --- views/trade_embed.py | 449 +++++++++++++++++++++---------------------- 1 file changed, 222 insertions(+), 227 deletions(-) diff --git a/views/trade_embed.py b/views/trade_embed.py index 507299f..acccfa7 100644 --- a/views/trade_embed.py +++ b/views/trade_embed.py @@ -3,6 +3,7 @@ Interactive Trade Embed Views Handles the Discord embed and button interfaces for the multi-team trade builder. """ + import discord from typing import Optional, List from datetime import datetime, timezone @@ -31,60 +32,56 @@ class TradeEmbedView(discord.ui.View): """Check if user has permission to interact with this view.""" if interaction.user.id != self.user_id: await interaction.response.send_message( - "āŒ You don't have permission to use this trade builder.", - ephemeral=True + "You don't have permission to use this trade builder.", + ephemeral=True, ) return False return True async def on_timeout(self) -> None: """Handle view timeout.""" - # Disable all buttons when timeout occurs for item in self.children: if isinstance(item, discord.ui.Button): item.disabled = True - @discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red, emoji="āž–") - async def remove_move_button(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red) + async def remove_move_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Handle remove move button click.""" if self.builder.is_empty: await interaction.response.send_message( - "āŒ No moves to remove. Add some moves first!", - ephemeral=True + "No moves to remove. Add some moves first!", ephemeral=True ) return - # Create select menu for move removal select_view = RemoveTradeMovesView(self.builder, self.user_id) embed = await create_trade_embed(self.builder) await interaction.response.edit_message(embed=embed, view=select_view) - @discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary, emoji="šŸ”") - async def validate_button(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary) + async def validate_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Handle validate trade button click.""" await interaction.response.defer(ephemeral=True) - # Perform detailed validation validation = await self.builder.validate_trade() - # Create validation report if validation.is_legal: - status_emoji = "āœ…" status_text = "**Trade is LEGAL**" color = EmbedColors.SUCCESS else: - status_emoji = "āŒ" status_text = "**Trade has ERRORS**" color = EmbedColors.ERROR embed = EmbedTemplate.create_base_embed( - title=f"{status_emoji} Trade Validation Report", + title="Trade Validation Report", description=status_text, - color=color + color=color, ) - # Add team-by-team validation for participant in self.builder.trade.participants: team_validation = validation.get_participant_validation(participant.team.id) if team_validation: @@ -98,72 +95,65 @@ class TradeEmbedView(discord.ui.View): team_status.append(team_validation.pre_existing_transactions_note) embed.add_field( - name=f"šŸŸļø {participant.team.abbrev} - {participant.team.sname}", + name=f"{participant.team.abbrev} - {participant.team.sname}", value="\n".join(team_status), - inline=False + inline=False, ) - # Add overall errors and suggestions if validation.all_errors: - error_text = "\n".join([f"• {error}" for error in validation.all_errors]) - embed.add_field( - name="āŒ Errors", - value=error_text, - inline=False - ) + error_text = "\n".join([f"- {error}" for error in validation.all_errors]) + embed.add_field(name="Errors", value=error_text, inline=False) if validation.all_suggestions: - suggestion_text = "\n".join([f"šŸ’” {suggestion}" for suggestion in validation.all_suggestions]) - embed.add_field( - name="šŸ’” Suggestions", - value=suggestion_text, - inline=False + suggestion_text = "\n".join( + [f"- {suggestion}" for suggestion in validation.all_suggestions] ) + embed.add_field(name="Suggestions", value=suggestion_text, inline=False) await interaction.followup.send(embed=embed, ephemeral=True) - @discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary, emoji="šŸ“¤") - async def submit_button(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary) + async def submit_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Handle submit trade button click.""" if self.builder.is_empty: await interaction.response.send_message( - "āŒ Cannot submit empty trade. Add some moves first!", - ephemeral=True + "Cannot submit empty trade. Add some moves first!", ephemeral=True ) return - # Validate before submission validation = await self.builder.validate_trade() if not validation.is_legal: - error_msg = "āŒ **Cannot submit illegal trade:**\n" - error_msg += "\n".join([f"• {error}" for error in validation.all_errors]) + error_msg = "**Cannot submit illegal trade:**\n" + error_msg += "\n".join([f"- {error}" for error in validation.all_errors]) if validation.all_suggestions: error_msg += "\n\n**Suggestions:**\n" - error_msg += "\n".join([f"šŸ’” {suggestion}" for suggestion in validation.all_suggestions]) + error_msg += "\n".join( + [f"- {suggestion}" for suggestion in validation.all_suggestions] + ) await interaction.response.send_message(error_msg, ephemeral=True) return - # Show confirmation modal modal = SubmitTradeConfirmationModal(self.builder) await interaction.response.send_modal(modal) - @discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary, emoji="āŒ") - async def cancel_button(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary) + async def cancel_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Handle cancel trade button click.""" self.builder.clear_trade() embed = await create_trade_embed(self.builder) - # Disable all buttons after cancellation for item in self.children: if isinstance(item, discord.ui.Button): item.disabled = True await interaction.response.edit_message( - content="āŒ **Trade cancelled and cleared.**", - embed=embed, - view=self + content="**Trade cancelled and cleared.**", embed=embed, view=self ) self.stop() @@ -176,12 +166,12 @@ class RemoveTradeMovesView(discord.ui.View): self.builder = builder self.user_id = user_id - # Create select menu with current moves if not builder.is_empty: self.add_item(RemoveTradeMovesSelect(builder)) - # Add back button - back_button = discord.ui.Button(label="Back", style=discord.ButtonStyle.secondary, emoji="ā¬…ļø") + back_button = discord.ui.Button( + label="Back", style=discord.ButtonStyle.secondary + ) back_button.callback = self.back_callback self.add_item(back_button) @@ -202,35 +192,36 @@ class RemoveTradeMovesSelect(discord.ui.Select): def __init__(self, builder: TradeBuilder): self.builder = builder - # Create options from all moves (cross-team and supplementary) options = [] move_count = 0 - # Add cross-team moves - for move in builder.trade.cross_team_moves[:20]: # Limit to avoid Discord's 25 option limit - options.append(discord.SelectOption( - label=f"{move.player.name}", - description=move.description[:100], # Discord description limit - value=str(move.player.id), - emoji="šŸ”„" - )) + for move in builder.trade.cross_team_moves[ + :20 + ]: # Limit to avoid Discord's 25 option limit + options.append( + discord.SelectOption( + label=f"{move.player.name}", + description=move.description[:100], + value=str(move.player.id), + ) + ) move_count += 1 - # Add supplementary moves if there's room remaining_slots = 25 - move_count for move in builder.trade.supplementary_moves[:remaining_slots]: - options.append(discord.SelectOption( - label=f"{move.player.name}", - description=move.description[:100], - value=str(move.player.id), - emoji="āš™ļø" - )) + options.append( + discord.SelectOption( + label=f"{move.player.name}", + description=move.description[:100], + value=str(move.player.id), + ) + ) super().__init__( placeholder="Select a move to remove...", min_values=1, max_values=1, - options=options + options=options, ) async def callback(self, interaction: discord.Interaction): @@ -241,27 +232,25 @@ class RemoveTradeMovesSelect(discord.ui.Select): if success: await interaction.response.send_message( - f"āœ… Removed move for player ID {player_id}", - ephemeral=True + f"Removed move for player ID {player_id}", ephemeral=True ) - # Update the embed main_view = TradeEmbedView(self.builder, interaction.user.id) embed = await create_trade_embed(self.builder) - # Edit the original message await interaction.edit_original_response(embed=embed, view=main_view) else: await interaction.response.send_message( - f"āŒ Could not remove move: {error_msg}", - ephemeral=True + f"Could not remove move: {error_msg}", ephemeral=True ) class SubmitTradeConfirmationModal(discord.ui.Modal): """Modal for confirming trade submission - posts acceptance request to trade channel.""" - def __init__(self, builder: TradeBuilder, trade_channel: Optional[discord.TextChannel] = None): + def __init__( + self, builder: TradeBuilder, trade_channel: Optional[discord.TextChannel] = None + ): super().__init__(title="Confirm Trade Submission") self.builder = builder self.trade_channel = trade_channel @@ -270,7 +259,7 @@ class SubmitTradeConfirmationModal(discord.ui.Modal): label="Type 'CONFIRM' to submit for approval", placeholder="CONFIRM", required=True, - max_length=7 + max_length=7, ) self.add_item(self.confirmation) @@ -279,56 +268,52 @@ class SubmitTradeConfirmationModal(discord.ui.Modal): """Handle confirmation submission - posts acceptance view to trade channel.""" if self.confirmation.value.upper() != "CONFIRM": await interaction.response.send_message( - "āŒ Trade not submitted. You must type 'CONFIRM' exactly.", - ephemeral=True + "Trade not submitted. You must type 'CONFIRM' exactly.", + ephemeral=True, ) return await interaction.response.defer(ephemeral=True) try: - # Update trade status to PROPOSED from models.trade import TradeStatus + self.builder.trade.status = TradeStatus.PROPOSED - # Create acceptance embed and view acceptance_embed = await create_trade_acceptance_embed(self.builder) acceptance_view = TradeAcceptanceView(self.builder) - # Find the trade channel to post to channel = self.trade_channel if not channel: - # Try to find trade channel by name pattern - trade_channel_name = f"trade-{'-'.join(t.abbrev.lower() for t in self.builder.participating_teams)}" for ch in interaction.guild.text_channels: # type: ignore - if ch.name.startswith("trade-") and self.builder.trade_id[:4] in ch.name: + if ( + ch.name.startswith("trade-") + and self.builder.trade_id[:4] in ch.name + ): channel = ch break if channel: - # Post acceptance request to trade channel await channel.send( - content="šŸ“‹ **Trade submitted for approval!** All teams must accept to complete the trade.", + content="**Trade submitted for approval.** All teams must accept to complete the trade.", embed=acceptance_embed, - view=acceptance_view + view=acceptance_view, ) await interaction.followup.send( - f"āœ… Trade submitted for approval!\n\nThe acceptance request has been posted to {channel.mention}.\n" + f"Trade submitted for approval.\n\nThe acceptance request has been posted to {channel.mention}.\n" f"All participating teams must click **Accept Trade** to finalize.", - ephemeral=True + ephemeral=True, ) else: - # No trade channel found, post in current channel await interaction.followup.send( - content="šŸ“‹ **Trade submitted for approval!** All teams must accept to complete the trade.", + content="**Trade submitted for approval.** All teams must accept to complete the trade.", embed=acceptance_embed, - view=acceptance_view + view=acceptance_view, ) except Exception as e: await interaction.followup.send( - f"āŒ Error submitting trade: {str(e)}", - ephemeral=True + f"Error submitting trade: {str(e)}", ephemeral=True ) @@ -343,8 +328,11 @@ class TradeAcceptanceView(discord.ui.View): """Get the team owned by the interacting user.""" from services.team_service import team_service from config import get_config + config = get_config() - return await team_service.get_team_by_owner(interaction.user.id, config.sba_season) + return await team_service.get_team_by_owner( + interaction.user.id, config.sba_season + ) async def interaction_check(self, interaction: discord.Interaction) -> bool: """Check if user is a GM of a participating team.""" @@ -352,17 +340,14 @@ class TradeAcceptanceView(discord.ui.View): if not user_team: await interaction.response.send_message( - "āŒ You don't own a team in the league.", - ephemeral=True + "You don't own a team in the league.", ephemeral=True ) return False - # Check if their team (or organization) is participating participant = self.builder.trade.get_participant_by_organization(user_team) if not participant: await interaction.response.send_message( - "āŒ Your team is not part of this trade.", - ephemeral=True + "Your team is not part of this trade.", ephemeral=True ) return False @@ -374,47 +359,45 @@ class TradeAcceptanceView(discord.ui.View): if isinstance(item, discord.ui.Button): item.disabled = True - @discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success, emoji="āœ…") - async def accept_button(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success) + async def accept_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Handle accept button click.""" user_team = await self._get_user_team(interaction) if not user_team: return - # Find the participating team (could be org affiliate) participant = self.builder.trade.get_participant_by_organization(user_team) if not participant: return team_id = participant.team.id - # Check if already accepted if self.builder.has_team_accepted(team_id): await interaction.response.send_message( - f"āœ… {participant.team.abbrev} has already accepted this trade.", - ephemeral=True + f"{participant.team.abbrev} has already accepted this trade.", + ephemeral=True, ) return - # Record acceptance all_accepted = self.builder.accept_trade(team_id) if all_accepted: - # All teams accepted - finalize the trade await self._finalize_trade(interaction) else: - # Update embed to show new acceptance status embed = await create_trade_acceptance_embed(self.builder) await interaction.response.edit_message(embed=embed, view=self) - # Send confirmation to channel await interaction.followup.send( - f"āœ… **{participant.team.abbrev}** has accepted the trade! " + f"**{participant.team.abbrev}** has accepted the trade. " f"({len(self.builder.accepted_teams)}/{self.builder.team_count} teams)" ) - @discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger, emoji="āŒ") - async def reject_button(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger) + async def reject_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Handle reject button click - moves trade back to DRAFT.""" user_team = await self._get_user_team(interaction) if not user_team: @@ -424,20 +407,16 @@ class TradeAcceptanceView(discord.ui.View): if not participant: return - # Reject the trade self.builder.reject_trade() - # Disable buttons self.accept_button.disabled = True self.reject_button.disabled = True - # Update embed to show rejection embed = await create_trade_rejection_embed(self.builder, participant.team) await interaction.response.edit_message(embed=embed, view=self) - # Notify the channel await interaction.followup.send( - f"āŒ **{participant.team.abbrev}** has rejected the trade.\n\n" + f"**{participant.team.abbrev}** has rejected the trade.\n\n" f"The trade has been moved back to **DRAFT** status. " f"Teams can continue negotiating using `/trade` commands." ) @@ -459,41 +438,52 @@ class TradeAcceptanceView(discord.ui.View): config = get_config() - # Get next week for transactions current = await league_service.get_current_state() next_week = current.week + 1 if current else 1 - # Create FA team for reference fa_team = Team( id=config.free_agent_team_id, abbrev="FA", sname="Free Agents", lname="Free Agency", - season=self.builder.trade.season + season=self.builder.trade.season, ) # type: ignore - # Create transactions from all moves transactions: List[Transaction] = [] move_id = f"Trade-{self.builder.trade_id}-{int(datetime.now(timezone.utc).timestamp())}" - # Process cross-team moves for move in self.builder.trade.cross_team_moves: - # Get actual team affiliates for from/to based on roster type if move.from_roster == RosterType.MAJOR_LEAGUE: old_team = move.source_team elif move.from_roster == RosterType.MINOR_LEAGUE: - old_team = await move.source_team.minor_league_affiliate() if move.source_team else None + old_team = ( + await move.source_team.minor_league_affiliate() + if move.source_team + else None + ) elif move.from_roster == RosterType.INJURED_LIST: - old_team = await move.source_team.injured_list_affiliate() if move.source_team else None + old_team = ( + await move.source_team.injured_list_affiliate() + if move.source_team + else None + ) else: old_team = move.source_team if move.to_roster == RosterType.MAJOR_LEAGUE: new_team = move.destination_team elif move.to_roster == RosterType.MINOR_LEAGUE: - new_team = await move.destination_team.minor_league_affiliate() if move.destination_team else None + new_team = ( + await move.destination_team.minor_league_affiliate() + if move.destination_team + else None + ) elif move.to_roster == RosterType.INJURED_LIST: - new_team = await move.destination_team.injured_list_affiliate() if move.destination_team else None + new_team = ( + await move.destination_team.injured_list_affiliate() + if move.destination_team + else None + ) else: new_team = move.destination_team @@ -507,18 +497,25 @@ class TradeAcceptanceView(discord.ui.View): oldteam=old_team, newteam=new_team, cancelled=False, - frozen=False # Trades are NOT frozen - immediately effective + frozen=False, ) transactions.append(transaction) - # Process supplementary moves for move in self.builder.trade.supplementary_moves: if move.from_roster == RosterType.MAJOR_LEAGUE: old_team = move.source_team elif move.from_roster == RosterType.MINOR_LEAGUE: - old_team = await move.source_team.minor_league_affiliate() if move.source_team else None + old_team = ( + await move.source_team.minor_league_affiliate() + if move.source_team + else None + ) elif move.from_roster == RosterType.INJURED_LIST: - old_team = await move.source_team.injured_list_affiliate() if move.source_team else None + old_team = ( + await move.source_team.injured_list_affiliate() + if move.source_team + else None + ) elif move.from_roster == RosterType.FREE_AGENCY: old_team = fa_team else: @@ -527,9 +524,17 @@ class TradeAcceptanceView(discord.ui.View): if move.to_roster == RosterType.MAJOR_LEAGUE: new_team = move.destination_team elif move.to_roster == RosterType.MINOR_LEAGUE: - new_team = await move.destination_team.minor_league_affiliate() if move.destination_team else None + new_team = ( + await move.destination_team.minor_league_affiliate() + if move.destination_team + else None + ) elif move.to_roster == RosterType.INJURED_LIST: - new_team = await move.destination_team.injured_list_affiliate() if move.destination_team else None + new_team = ( + await move.destination_team.injured_list_affiliate() + if move.destination_team + else None + ) elif move.to_roster == RosterType.FREE_AGENCY: new_team = fa_team else: @@ -545,45 +550,42 @@ class TradeAcceptanceView(discord.ui.View): oldteam=old_team, newteam=new_team, cancelled=False, - frozen=False # Trades are NOT frozen - immediately effective + frozen=False, ) transactions.append(transaction) - # POST transactions to database if transactions: - created_transactions = await transaction_service.create_transaction_batch(transactions) + created_transactions = ( + await transaction_service.create_transaction_batch(transactions) + ) else: created_transactions = [] - # Post to #transaction-log channel if created_transactions and interaction.client: await post_trade_to_log( bot=interaction.client, builder=self.builder, transactions=created_transactions, - effective_week=next_week + effective_week=next_week, ) - # Update trade status self.builder.trade.status = TradeStatus.ACCEPTED - # Disable buttons self.accept_button.disabled = True self.reject_button.disabled = True - # Update embed to show completion - embed = await create_trade_complete_embed(self.builder, len(created_transactions), next_week) + embed = await create_trade_complete_embed( + self.builder, len(created_transactions), next_week + ) await interaction.edit_original_response(embed=embed, view=self) - # Send completion message await interaction.followup.send( - f"šŸŽ‰ **Trade Complete!**\n\n" + f"**Trade Complete!**\n\n" f"All {self.builder.team_count} teams have accepted the trade.\n" f"**{len(created_transactions)} transactions** have been created for **Week {next_week}**.\n\n" f"Trade ID: `{self.builder.trade_id}`" ) - # Clear the trade builder for team in self.builder.participating_teams: clear_trade_builder_by_team(team.id) @@ -591,81 +593,79 @@ class TradeAcceptanceView(discord.ui.View): except Exception as e: await interaction.followup.send( - f"āŒ Error finalizing trade: {str(e)}", - ephemeral=True + f"Error finalizing trade: {str(e)}", ephemeral=True ) async def create_trade_acceptance_embed(builder: TradeBuilder) -> discord.Embed: """Create embed showing trade details and acceptance status.""" embed = EmbedTemplate.create_base_embed( - title=f"šŸ“‹ Trade Pending Acceptance - {builder.trade.get_trade_summary()}", + title=f"Trade Pending Acceptance - {builder.trade.get_trade_summary()}", description="All participating teams must accept to complete the trade.", - color=EmbedColors.WARNING + color=EmbedColors.WARNING, ) - # Show participating teams - team_list = [f"• {team.abbrev} - {team.sname}" for team in builder.participating_teams] + team_list = [ + f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams + ] embed.add_field( - name=f"šŸŸļø Participating Teams ({builder.team_count})", + name=f"Participating Teams ({builder.team_count})", value="\n".join(team_list), - inline=False + inline=False, ) - # Show cross-team moves if builder.trade.cross_team_moves: moves_text = "" for move in builder.trade.cross_team_moves[:10]: - moves_text += f"• {move.description}\n" + moves_text += f"- {move.description}\n" if len(builder.trade.cross_team_moves) > 10: moves_text += f"... and {len(builder.trade.cross_team_moves) - 10} more" embed.add_field( - name=f"šŸ”„ Player Exchanges ({len(builder.trade.cross_team_moves)})", + name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})", value=moves_text, - inline=False + inline=False, ) - # Show supplementary moves if any if builder.trade.supplementary_moves: supp_text = "" for move in builder.trade.supplementary_moves[:5]: - supp_text += f"• {move.description}\n" + supp_text += f"- {move.description}\n" if len(builder.trade.supplementary_moves) > 5: supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more" embed.add_field( - name=f"āš™ļø Supplementary Moves ({len(builder.trade.supplementary_moves)})", + name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})", value=supp_text, - inline=False + inline=False, ) - # Show acceptance status status_lines = [] for team in builder.participating_teams: if team.id in builder.accepted_teams: - status_lines.append(f"āœ… **{team.abbrev}** - Accepted") + status_lines.append(f"**{team.abbrev}** - Accepted") else: - status_lines.append(f"ā³ **{team.abbrev}** - Pending") + status_lines.append(f"**{team.abbrev}** - Pending") embed.add_field( - name="šŸ“Š Acceptance Status", - value="\n".join(status_lines), - inline=False + name="Acceptance Status", value="\n".join(status_lines), inline=False ) - # Add footer - embed.set_footer(text=f"Trade ID: {builder.trade_id} • {len(builder.accepted_teams)}/{builder.team_count} teams accepted") + embed.set_footer( + text=f"Trade ID: {builder.trade_id} | {len(builder.accepted_teams)}/{builder.team_count} teams accepted" + ) return embed -async def create_trade_rejection_embed(builder: TradeBuilder, rejecting_team: Team) -> discord.Embed: +async def create_trade_rejection_embed( + builder: TradeBuilder, rejecting_team: Team +) -> discord.Embed: """Create embed showing trade was rejected.""" embed = EmbedTemplate.create_base_embed( - title=f"āŒ Trade Rejected - {builder.trade.get_trade_summary()}", + title=f"Trade Rejected - {builder.trade.get_trade_summary()}", description=f"**{rejecting_team.abbrev}** has rejected the trade.\n\n" - f"The trade has been moved back to **DRAFT** status.\n" - f"Teams can continue negotiating using `/trade` commands.", - color=EmbedColors.ERROR + f"The trade has been moved back to **DRAFT** status.\n" + f"Teams can continue negotiating using `/trade` commands.", + color=EmbedColors.ERROR, ) embed.set_footer(text=f"Trade ID: {builder.trade_id}") @@ -673,37 +673,33 @@ async def create_trade_rejection_embed(builder: TradeBuilder, rejecting_team: Te return embed -async def create_trade_complete_embed(builder: TradeBuilder, transaction_count: int, effective_week: int) -> discord.Embed: +async def create_trade_complete_embed( + builder: TradeBuilder, transaction_count: int, effective_week: int +) -> discord.Embed: """Create embed showing trade was completed.""" embed = EmbedTemplate.create_base_embed( - title=f"šŸŽ‰ Trade Complete! - {builder.trade.get_trade_summary()}", - description=f"All {builder.team_count} teams have accepted the trade!\n\n" - f"**{transaction_count} transactions** created for **Week {effective_week}**.", - color=EmbedColors.SUCCESS + title=f"Trade Complete - {builder.trade.get_trade_summary()}", + description=f"All {builder.team_count} teams have accepted the trade.\n\n" + f"**{transaction_count} transactions** created for **Week {effective_week}**.", + color=EmbedColors.SUCCESS, ) - # Show final acceptance status (all green) - status_lines = [f"āœ… **{team.abbrev}** - Accepted" for team in builder.participating_teams] - embed.add_field( - name="šŸ“Š Final Status", - value="\n".join(status_lines), - inline=False - ) + status_lines = [ + f"**{team.abbrev}** - Accepted" for team in builder.participating_teams + ] + embed.add_field(name="Final Status", value="\n".join(status_lines), inline=False) - # Show cross-team moves if builder.trade.cross_team_moves: moves_text = "" for move in builder.trade.cross_team_moves[:8]: - moves_text += f"• {move.description}\n" + moves_text += f"- {move.description}\n" if len(builder.trade.cross_team_moves) > 8: moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more" - embed.add_field( - name=f"šŸ”„ Player Exchanges", - value=moves_text, - inline=False - ) + embed.add_field(name="Player Exchanges", value=moves_text, inline=False) - embed.set_footer(text=f"Trade ID: {builder.trade_id} • Effective: Week {effective_week}") + embed.set_footer( + text=f"Trade ID: {builder.trade_id} | Effective: Week {effective_week}" + ) return embed @@ -718,7 +714,6 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed: Returns: Discord embed with current trade state """ - # Determine embed color based on trade status if builder.is_empty: color = EmbedColors.SECONDARY else: @@ -726,79 +721,79 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed: color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING embed = EmbedTemplate.create_base_embed( - title=f"šŸ“‹ Trade Builder - {builder.trade.get_trade_summary()}", - description=f"Build your multi-team trade", - color=color + title=f"Trade Builder - {builder.trade.get_trade_summary()}", + description="Build your multi-team trade", + color=color, ) - # Add participating teams section - team_list = [f"• {team.abbrev} - {team.sname}" for team in builder.participating_teams] + team_list = [ + f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams + ] embed.add_field( - name=f"šŸŸļø Participating Teams ({builder.team_count})", + name=f"Participating Teams ({builder.team_count})", value="\n".join(team_list) if team_list else "*No teams yet*", - inline=False + inline=False, ) - # Add current moves section if builder.is_empty: embed.add_field( name="Current Moves", value="*No moves yet. Use the `/trade` commands to build your trade.*", - inline=False + inline=False, ) else: - # Show cross-team moves if builder.trade.cross_team_moves: moves_text = "" - for i, move in enumerate(builder.trade.cross_team_moves[:8], 1): # Limit display + for i, move in enumerate(builder.trade.cross_team_moves[:8], 1): moves_text += f"{i}. {move.description}\n" if len(builder.trade.cross_team_moves) > 8: moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more" embed.add_field( - name=f"šŸ”„ Player Exchanges ({len(builder.trade.cross_team_moves)})", + name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})", value=moves_text, - inline=False + inline=False, ) - # Show supplementary moves if builder.trade.supplementary_moves: supp_text = "" - for i, move in enumerate(builder.trade.supplementary_moves[:5], 1): # Limit display + for i, move in enumerate(builder.trade.supplementary_moves[:5], 1): supp_text += f"{i}. {move.description}\n" if len(builder.trade.supplementary_moves) > 5: - supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more" + supp_text += ( + f"... and {len(builder.trade.supplementary_moves) - 5} more" + ) embed.add_field( - name=f"āš™ļø Supplementary Moves ({len(builder.trade.supplementary_moves)})", + name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})", value=supp_text, - inline=False + inline=False, ) - # Add quick validation summary validation = await builder.validate_trade() if validation.is_legal: - status_text = "āœ… Trade appears legal" + status_text = "Trade appears legal" else: error_count = len(validation.all_errors) - status_text = f"āŒ {error_count} error{'s' if error_count != 1 else ''} found" + status_text = f"{error_count} error{'s' if error_count != 1 else ''} found\n" + status_text += "\n".join(f"- {error}" for error in validation.all_errors) + if validation.all_suggestions: + status_text += "\n" + "\n".join( + f"- {s}" for s in validation.all_suggestions + ) + + embed.add_field(name="Quick Status", value=status_text, inline=False) embed.add_field( - name="šŸ” Quick Status", - value=status_text, - inline=False + name="Build Your Trade", + value="- `/trade add-player` - Add player exchanges\n- `/trade supplementary` - Add internal moves\n- `/trade add-team` - Add more teams", + inline=False, ) - # Add instructions for adding more moves - embed.add_field( - name="āž• Build Your Trade", - value="• `/trade add-player` - Add player exchanges\n• `/trade supplementary` - Add internal moves\n• `/trade add-team` - Add more teams", - inline=False + embed.set_footer( + text=f"Trade ID: {builder.trade_id} | Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}" ) - # Add footer with trade ID and timestamp - embed.set_footer(text=f"Trade ID: {builder.trade_id} • Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}") - - return embed \ No newline at end of file + return embed From 58fe9f22deda0341151d6b886dceb8a63795e1b5 Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Sun, 8 Mar 2026 11:25:31 -0500 Subject: [PATCH 3/7] fix: trade validation now checks against next week's projected roster validate_trade() was passing next_week=None to each team's validate_transaction(), which skipped load_existing_transactions() entirely. Trades were validated against the current roster only, ignoring pending /dropadd transactions for next week. Now auto-fetches current week from league_service and passes next_week=current_week+1, matching /dropadd validation behavior. Co-Authored-By: Claude Opus 4.6 --- services/trade_builder.py | 138 ++++++++++++++++++++++++++++---------- 1 file changed, 102 insertions(+), 36 deletions(-) diff --git a/services/trade_builder.py b/services/trade_builder.py index 879959f..8ab1d5d 100644 --- a/services/trade_builder.py +++ b/services/trade_builder.py @@ -3,6 +3,7 @@ Trade Builder Service Extends the TransactionBuilder to support multi-team trades and player exchanges. """ + import logging from typing import Dict, List, Optional, Set from datetime import datetime, timezone @@ -12,10 +13,14 @@ from config import get_config from models.trade import Trade, TradeMove, TradeStatus from models.team import Team, RosterType from models.player import Player -from services.transaction_builder import TransactionBuilder, RosterValidationResult, TransactionMove +from services.transaction_builder import ( + TransactionBuilder, + RosterValidationResult, + TransactionMove, +) from services.team_service import team_service -logger = logging.getLogger(f'{__name__}.TradeBuilder') +logger = logging.getLogger(f"{__name__}.TradeBuilder") class TradeValidationResult: @@ -52,7 +57,9 @@ class TradeValidationResult: suggestions.extend(validation.suggestions) return suggestions - def get_participant_validation(self, team_id: int) -> Optional[RosterValidationResult]: + def get_participant_validation( + self, team_id: int + ) -> Optional[RosterValidationResult]: """Get validation result for a specific team.""" return self.participant_validations.get(team_id) @@ -64,7 +71,12 @@ class TradeBuilder: Extends the functionality of TransactionBuilder to support trades between teams. """ - def __init__(self, initiated_by: int, initiating_team: Team, season: int = get_config().sba_season): + def __init__( + self, + initiated_by: int, + initiating_team: Team, + season: int = get_config().sba_season, + ): """ Initialize trade builder. @@ -79,7 +91,7 @@ class TradeBuilder: status=TradeStatus.DRAFT, initiated_by=initiated_by, created_at=datetime.now(timezone.utc).isoformat(), - season=season + season=season, ) # Add the initiating team as first participant @@ -91,7 +103,9 @@ class TradeBuilder: # Track which teams have accepted the trade (team_id -> True) self.accepted_teams: Set[int] = set() - logger.info(f"TradeBuilder initialized: {self.trade.trade_id} by user {initiated_by} for {initiating_team.abbrev}") + logger.info( + f"TradeBuilder initialized: {self.trade.trade_id} by user {initiated_by} for {initiating_team.abbrev}" + ) @property def trade_id(self) -> str: @@ -127,7 +141,11 @@ class TradeBuilder: @property def pending_teams(self) -> List[Team]: """Get list of teams that haven't accepted yet.""" - return [team for team in self.participating_teams if team.id not in self.accepted_teams] + return [ + team + for team in self.participating_teams + if team.id not in self.accepted_teams + ] def accept_trade(self, team_id: int) -> bool: """ @@ -140,7 +158,9 @@ class TradeBuilder: True if all teams have now accepted, False otherwise """ self.accepted_teams.add(team_id) - logger.info(f"Team {team_id} accepted trade {self.trade_id}. Accepted: {len(self.accepted_teams)}/{self.team_count}") + logger.info( + f"Team {team_id} accepted trade {self.trade_id}. Accepted: {len(self.accepted_teams)}/{self.team_count}" + ) return self.all_teams_accepted def reject_trade(self) -> None: @@ -160,7 +180,9 @@ class TradeBuilder: Returns: Dict mapping team_id to acceptance status (True/False) """ - return {team.id: team.id in self.accepted_teams for team in self.participating_teams} + return { + team.id: team.id in self.accepted_teams for team in self.participating_teams + } def has_team_accepted(self, team_id: int) -> bool: """Check if a specific team has accepted.""" @@ -184,7 +206,9 @@ class TradeBuilder: participant = self.trade.add_participant(team) # Create transaction builder for this team - self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season) + self._team_builders[team.id] = TransactionBuilder( + team, self.trade.initiated_by, self.trade.season + ) # Register team in secondary index for multi-GM access trade_key = f"{self.trade.initiated_by}:trade" @@ -209,7 +233,10 @@ class TradeBuilder: # Check if team has moves - prevent removal if they do if participant.all_moves: - return False, f"{participant.team.abbrev} has moves in this trade and cannot be removed" + return ( + False, + f"{participant.team.abbrev} has moves in this trade and cannot be removed", + ) # Remove team removed = self.trade.remove_participant(team_id) @@ -229,7 +256,7 @@ class TradeBuilder: from_team: Team, to_team: Team, from_roster: RosterType, - to_roster: RosterType + to_roster: RosterType, ) -> tuple[bool, str]: """ Add a player move to the trade. @@ -246,7 +273,10 @@ class TradeBuilder: """ # Validate player is not from Free Agency if player.team_id == get_config().free_agent_team_id: - return False, f"Cannot add {player.name} from Free Agency. Players must be traded from teams within the organizations involved in the trade." + return ( + False, + f"Cannot add {player.name} from Free Agency. Players must be traded from teams within the organizations involved in the trade.", + ) # Validate player has a valid team assignment if not player.team_id: @@ -259,7 +289,10 @@ class TradeBuilder: # Check if player's team is in the same organization as from_team if not player_team.is_same_organization(from_team): - return False, f"{player.name} is on {player_team.abbrev}, they are not eligible to be added to the trade." + return ( + False, + f"{player.name} is on {player_team.abbrev}, they are not eligible to be added to the trade.", + ) # Ensure both teams are participating (check by organization for ML authority) from_participant = self.trade.get_participant_by_organization(from_team) @@ -274,7 +307,10 @@ class TradeBuilder: for participant in self.trade.participants: for existing_move in participant.all_moves: if existing_move.player.id == player.id: - return False, f"{player.name} is already involved in a move in this trade" + return ( + False, + f"{player.name} is already involved in a move in this trade", + ) # Create trade move trade_move = TradeMove( @@ -284,7 +320,7 @@ class TradeBuilder: from_team=from_team, to_team=to_team, source_team=from_team, - destination_team=to_team + destination_team=to_team, ) # Add to giving team's moves @@ -303,7 +339,7 @@ class TradeBuilder: from_roster=from_roster, to_roster=RosterType.FREE_AGENCY, # Conceptually leaving the org from_team=from_team, - to_team=None + to_team=None, ) # Move for receiving team (player joining) @@ -312,19 +348,23 @@ class TradeBuilder: from_roster=RosterType.FREE_AGENCY, # Conceptually joining from outside to_roster=to_roster, from_team=None, - to_team=to_team + to_team=to_team, ) # Add moves to respective builders # Skip pending transaction check for trades - they have their own validation workflow - from_success, from_error = await from_builder.add_move(from_move, check_pending_transactions=False) + from_success, from_error = await from_builder.add_move( + from_move, check_pending_transactions=False + ) if not from_success: # Remove from trade if builder failed from_participant.moves_giving.remove(trade_move) to_participant.moves_receiving.remove(trade_move) return False, f"Error adding move to {from_team.abbrev}: {from_error}" - to_success, to_error = await to_builder.add_move(to_move, check_pending_transactions=False) + to_success, to_error = await to_builder.add_move( + to_move, check_pending_transactions=False + ) if not to_success: # Rollback both if second failed from_builder.remove_move(player.id) @@ -332,15 +372,13 @@ class TradeBuilder: to_participant.moves_receiving.remove(trade_move) return False, f"Error adding move to {to_team.abbrev}: {to_error}" - logger.info(f"Added player move to trade {self.trade_id}: {trade_move.description}") + logger.info( + f"Added player move to trade {self.trade_id}: {trade_move.description}" + ) return True, "" async def add_supplementary_move( - self, - team: Team, - player: Player, - from_roster: RosterType, - to_roster: RosterType + self, team: Team, player: Player, from_roster: RosterType, to_roster: RosterType ) -> tuple[bool, str]: """ Add a supplementary move (internal organizational move) for roster legality. @@ -366,7 +404,7 @@ class TradeBuilder: from_team=team, to_team=team, source_team=team, - destination_team=team + destination_team=team, ) # Add to participant's supplementary moves @@ -379,16 +417,20 @@ class TradeBuilder: from_roster=from_roster, to_roster=to_roster, from_team=team, - to_team=team + to_team=team, ) # Skip pending transaction check for trade supplementary moves - success, error = await builder.add_move(trans_move, check_pending_transactions=False) + success, error = await builder.add_move( + trans_move, check_pending_transactions=False + ) if not success: participant.supplementary_moves.remove(supp_move) return False, error - logger.info(f"Added supplementary move for {team.abbrev}: {supp_move.description}") + logger.info( + f"Added supplementary move for {team.abbrev}: {supp_move.description}" + ) return True, "" async def remove_move(self, player_id: int) -> tuple[bool, str]: @@ -432,21 +474,41 @@ class TradeBuilder: for builder in self._team_builders.values(): builder.remove_move(player_id) - logger.info(f"Removed move from trade {self.trade_id}: {removed_move.description}") + logger.info( + f"Removed move from trade {self.trade_id}: {removed_move.description}" + ) return True, "" - async def validate_trade(self, next_week: Optional[int] = None) -> TradeValidationResult: + async def validate_trade( + self, next_week: Optional[int] = None + ) -> TradeValidationResult: """ Validate the entire trade including all teams' roster legality. + Validates against next week's projected roster (current roster + pending + transactions), matching the behavior of /dropadd validation. + Args: - next_week: Week to validate for (optional) + next_week: Week to validate for (auto-fetched if not provided) Returns: TradeValidationResult with comprehensive validation """ result = TradeValidationResult() + # Auto-fetch next week so validation includes pending transactions + if next_week is None: + try: + from services.league_service import league_service + + current_state = await league_service.get_current_state() + next_week = (current_state.week + 1) if current_state else 1 + except Exception as e: + logger.warning( + f"Could not determine next week for trade validation: {e}" + ) + next_week = None + # Validate trade structure is_balanced, balance_errors = self.trade.validate_trade_balance() if not is_balanced: @@ -472,13 +534,17 @@ class TradeBuilder: if self.team_count < 2: result.trade_suggestions.append("Add another team to create a trade") - logger.debug(f"Trade validation for {self.trade_id}: Legal={result.is_legal}, Errors={len(result.all_errors)}") + logger.debug( + f"Trade validation for {self.trade_id}: Legal={result.is_legal}, Errors={len(result.all_errors)}" + ) return result def _get_or_create_builder(self, team: Team) -> TransactionBuilder: """Get or create a transaction builder for a team.""" if team.id not in self._team_builders: - self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season) + self._team_builders[team.id] = TransactionBuilder( + team, self.trade.initiated_by, self.trade.season + ) return self._team_builders[team.id] def clear_trade(self) -> None: @@ -592,4 +658,4 @@ def clear_trade_builder_by_team(team_id: int) -> bool: def get_active_trades() -> Dict[str, TradeBuilder]: """Get all active trade builders (for debugging/admin purposes).""" - return _active_trade_builders.copy() \ No newline at end of file + return _active_trade_builders.copy() From 9379ba587a7e99503efdc9b6667a9c8cbe4b559e Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Mon, 9 Mar 2026 17:25:08 -0500 Subject: [PATCH 4/7] fix: add trailing slashes to API URLs to prevent 307 redirects dropping POST bodies The FastAPI server returns 307 redirects for URLs without trailing slashes. aiohttp follows these redirects but converts POST to GET, silently dropping the request body. This caused play-by-play and decision data from /submit-scorecard to never be persisted to the database despite the API returning success. Co-Authored-By: Claude Opus 4.6 --- api/client.py | 4 + tests/test_api_client.py | 365 +++++++++++++++++++-------------------- 2 files changed, 183 insertions(+), 186 deletions(-) diff --git a/api/client.py b/api/client.py index 4ce5425..b262fb9 100644 --- a/api/client.py +++ b/api/client.py @@ -88,6 +88,10 @@ class APIClient: encoded_id = quote(str(object_id), safe="") path += f"/{encoded_id}" + # Ensure trailing slash to prevent 307 redirects that drop POST bodies + if not path.endswith("/"): + path += "/" + return urljoin(self.base_url.rstrip("/") + "/", path) def _add_params(self, url: str, params: Optional[List[tuple]] = None) -> str: diff --git a/tests/test_api_client.py b/tests/test_api_client.py index 14edf0f..b0b16cc 100644 --- a/tests/test_api_client.py +++ b/tests/test_api_client.py @@ -1,18 +1,24 @@ """ API client tests using aioresponses for clean HTTP mocking """ + import pytest import asyncio from unittest.mock import MagicMock, patch from aioresponses import aioresponses -from api.client import APIClient, get_api_client, get_global_client, cleanup_global_client +from api.client import ( + APIClient, + get_api_client, + get_global_client, + cleanup_global_client, +) from exceptions import APIException class TestAPIClientWithAioresponses: """Test API client with aioresponses for HTTP mocking.""" - + @pytest.fixture def mock_config(self): """Mock configuration for testing.""" @@ -20,202 +26,185 @@ class TestAPIClientWithAioresponses: config.db_url = "https://api.example.com" config.api_token = "test-token" return config - + @pytest.fixture def api_client(self, mock_config): """Create API client with mocked config.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): return APIClient() - + @pytest.mark.asyncio async def test_get_request_success(self, api_client): """Test successful GET request.""" expected_data = {"id": 1, "name": "Test Player"} - + with aioresponses() as m: m.get( - "https://api.example.com/v3/players/1", + "https://api.example.com/v3/players/1/", payload=expected_data, - status=200 + status=200, ) - + result = await api_client.get("players", object_id=1) - + assert result == expected_data - + @pytest.mark.asyncio async def test_get_request_404(self, api_client): """Test GET request returning 404.""" with aioresponses() as m: - m.get( - "https://api.example.com/v3/players/999", - status=404 - ) - + m.get("https://api.example.com/v3/players/999/", status=404) + result = await api_client.get("players", object_id=999) - + assert result is None - + @pytest.mark.asyncio async def test_get_request_401_auth_error(self, api_client): """Test GET request with authentication error.""" with aioresponses() as m: - m.get( - "https://api.example.com/v3/players", - status=401 - ) - + m.get("https://api.example.com/v3/players/", status=401) + with pytest.raises(APIException, match="Authentication failed"): await api_client.get("players") - + @pytest.mark.asyncio async def test_get_request_403_forbidden(self, api_client): """Test GET request with forbidden error.""" with aioresponses() as m: - m.get( - "https://api.example.com/v3/players", - status=403 - ) - + m.get("https://api.example.com/v3/players/", status=403) + with pytest.raises(APIException, match="Access forbidden"): await api_client.get("players") - + @pytest.mark.asyncio async def test_get_request_500_server_error(self, api_client): """Test GET request with server error.""" with aioresponses() as m: m.get( - "https://api.example.com/v3/players", + "https://api.example.com/v3/players/", status=500, - body="Internal Server Error" + body="Internal Server Error", ) - - with pytest.raises(APIException, match="API request failed with status 500"): + + with pytest.raises( + APIException, match="API request failed with status 500" + ): await api_client.get("players") - + @pytest.mark.asyncio async def test_get_request_with_params(self, api_client): """Test GET request with query parameters.""" expected_data = {"count": 2, "players": [{"id": 1}, {"id": 2}]} - + with aioresponses() as m: m.get( - "https://api.example.com/v3/players?team_id=5&season=12", + "https://api.example.com/v3/players/?team_id=5&season=12", payload=expected_data, - status=200 + status=200, ) - - result = await api_client.get("players", params=[("team_id", "5"), ("season", "12")]) - + + result = await api_client.get( + "players", params=[("team_id", "5"), ("season", "12")] + ) + assert result == expected_data - + @pytest.mark.asyncio async def test_post_request_success(self, api_client): """Test successful POST request.""" input_data = {"name": "New Player", "position": "C"} expected_response = {"id": 1, "name": "New Player", "position": "C"} - + with aioresponses() as m: m.post( - "https://api.example.com/v3/players", + "https://api.example.com/v3/players/", payload=expected_response, - status=201 + status=201, ) - + result = await api_client.post("players", input_data) - + assert result == expected_response - + @pytest.mark.asyncio async def test_post_request_400_error(self, api_client): """Test POST request with validation error.""" input_data = {"invalid": "data"} - + with aioresponses() as m: m.post( - "https://api.example.com/v3/players", - status=400, - body="Invalid data" + "https://api.example.com/v3/players/", status=400, body="Invalid data" ) - - with pytest.raises(APIException, match="POST request failed with status 400"): + + with pytest.raises( + APIException, match="POST request failed with status 400" + ): await api_client.post("players", input_data) - + @pytest.mark.asyncio async def test_put_request_success(self, api_client): """Test successful PUT request.""" update_data = {"name": "Updated Player"} expected_response = {"id": 1, "name": "Updated Player"} - + with aioresponses() as m: m.put( - "https://api.example.com/v3/players/1", + "https://api.example.com/v3/players/1/", payload=expected_response, - status=200 + status=200, ) - + result = await api_client.put("players", update_data, object_id=1) - + assert result == expected_response - + @pytest.mark.asyncio async def test_put_request_404(self, api_client): """Test PUT request with 404.""" update_data = {"name": "Updated Player"} - + with aioresponses() as m: - m.put( - "https://api.example.com/v3/players/999", - status=404 - ) - + m.put("https://api.example.com/v3/players/999/", status=404) + result = await api_client.put("players", update_data, object_id=999) - + assert result is None - + @pytest.mark.asyncio async def test_delete_request_success(self, api_client): """Test successful DELETE request.""" with aioresponses() as m: - m.delete( - "https://api.example.com/v3/players/1", - status=204 - ) - + m.delete("https://api.example.com/v3/players/1/", status=204) + result = await api_client.delete("players", object_id=1) - + assert result is True - + @pytest.mark.asyncio async def test_delete_request_404(self, api_client): """Test DELETE request with 404.""" with aioresponses() as m: - m.delete( - "https://api.example.com/v3/players/999", - status=404 - ) - + m.delete("https://api.example.com/v3/players/999/", status=404) + result = await api_client.delete("players", object_id=999) - + assert result is False - + @pytest.mark.asyncio async def test_delete_request_200_success(self, api_client): """Test DELETE request with 200 success.""" with aioresponses() as m: - m.delete( - "https://api.example.com/v3/players/1", - status=200 - ) - + m.delete("https://api.example.com/v3/players/1/", status=200) + result = await api_client.delete("players", object_id=1) - + assert result is True class TestAPIClientHelpers: """Test API client helper functions.""" - + @pytest.fixture def mock_config(self): """Mock configuration for testing.""" @@ -223,49 +212,49 @@ class TestAPIClientHelpers: config.db_url = "https://api.example.com" config.api_token = "test-token" return config - + @pytest.mark.asyncio async def test_get_api_client_context_manager(self, mock_config): """Test get_api_client context manager.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: m.get( - "https://api.example.com/v3/test", + "https://api.example.com/v3/test/", payload={"success": True}, - status=200 + status=200, ) - + async with get_api_client() as client: assert isinstance(client, APIClient) result = await client.get("test") assert result == {"success": True} - + @pytest.mark.asyncio async def test_global_client_management(self, mock_config): """Test global client getter and cleanup.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): # Get global client client1 = await get_global_client() client2 = await get_global_client() - + # Should return same instance assert client1 is client2 assert isinstance(client1, APIClient) - + # Test cleanup await cleanup_global_client() - + # New client should be different instance client3 = await get_global_client() assert client3 is not client1 - + # Clean up for other tests await cleanup_global_client() class TestIntegrationScenarios: """Test realistic integration scenarios.""" - + @pytest.fixture def mock_config(self): """Mock configuration for testing.""" @@ -273,11 +262,11 @@ class TestIntegrationScenarios: config.db_url = "https://api.example.com" config.api_token = "test-token" return config - + @pytest.mark.asyncio async def test_player_retrieval_with_team_lookup(self, mock_config): """Test realistic scenario: get player with team data.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: # Mock player data response player_data = { @@ -287,43 +276,41 @@ class TestIntegrationScenarios: "season": 12, "team_id": 5, "image": "https://example.com/player1.jpg", - "pos_1": "C" + "pos_1": "C", } m.get( - "https://api.example.com/v3/players/1", + "https://api.example.com/v3/players/1/", payload=player_data, - status=200 + status=200, ) - + # Mock team data response team_data = { "id": 5, "abbrev": "TST", "sname": "Test Team", "lname": "Test Team Full Name", - "season": 12 + "season": 12, } m.get( - "https://api.example.com/v3/teams/5", - payload=team_data, - status=200 + "https://api.example.com/v3/teams/5/", payload=team_data, status=200 ) - + client = APIClient() - + # Get player player = await client.get("players", object_id=1) assert player["name"] == "Test Player" assert player["team_id"] == 5 - + # Get team for player team = await client.get("teams", object_id=player["team_id"]) assert team["sname"] == "Test Team" - + @pytest.mark.asyncio async def test_api_response_format_handling(self, mock_config): """Test handling of the API's count + list format.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: # Mock API response with count format api_response = { @@ -336,7 +323,7 @@ class TestIntegrationScenarios: "season": 12, "team_id": 5, "image": "https://example.com/player1.jpg", - "pos_1": "C" + "pos_1": "C", }, { "id": 2, @@ -345,93 +332,93 @@ class TestIntegrationScenarios: "season": 12, "team_id": 6, "image": "https://example.com/player2.jpg", - "pos_1": "1B" - } - ] + "pos_1": "1B", + }, + ], } - + m.get( - "https://api.example.com/v3/players?team_id=5", + "https://api.example.com/v3/players/?team_id=5", payload=api_response, - status=200 + status=200, ) - + client = APIClient() result = await client.get("players", params=[("team_id", "5")]) - + assert result["count"] == 25 assert len(result["players"]) == 2 assert result["players"][0]["name"] == "Player 1" - + @pytest.mark.asyncio async def test_error_recovery_scenarios(self, mock_config): """Test error handling and recovery.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: # First request fails with 500 m.get( - "https://api.example.com/v3/players/1", + "https://api.example.com/v3/players/1/", status=500, - body="Internal Server Error" + body="Internal Server Error", ) - + # Second request succeeds m.get( - "https://api.example.com/v3/players/2", + "https://api.example.com/v3/players/2/", payload={"id": 2, "name": "Working Player"}, - status=200 + status=200, ) - + client = APIClient() - + # First request should raise exception with pytest.raises(APIException, match="API request failed"): await client.get("players", object_id=1) - + # Second request should work fine result = await client.get("players", object_id=2) assert result["name"] == "Working Player" - + # Client should still be functional await client.close() - + @pytest.mark.asyncio async def test_concurrent_requests(self, mock_config): """Test multiple concurrent requests.""" import asyncio - - with patch('api.client.get_config', return_value=mock_config): + + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: # Mock multiple endpoints for i in range(1, 4): m.get( - f"https://api.example.com/v3/players/{i}", + f"https://api.example.com/v3/players/{i}/", payload={"id": i, "name": f"Player {i}"}, - status=200 + status=200, ) - + client = APIClient() - + # Make concurrent requests tasks = [ client.get("players", object_id=1), client.get("players", object_id=2), - client.get("players", object_id=3) + client.get("players", object_id=3), ] - + results = await asyncio.gather(*tasks) - + assert len(results) == 3 assert results[0]["name"] == "Player 1" assert results[1]["name"] == "Player 2" assert results[2]["name"] == "Player 3" - + await client.close() class TestAPIClientCoverageExtras: """Additional coverage tests for API client edge cases.""" - + @pytest.fixture def mock_config(self): """Mock configuration for testing.""" @@ -439,98 +426,104 @@ class TestAPIClientCoverageExtras: config.db_url = "https://api.example.com" config.api_token = "test-token" return config - + @pytest.mark.asyncio async def test_global_client_cleanup_when_none(self): """Test cleanup when no global client exists.""" # Ensure no global client exists await cleanup_global_client() - + # Should not raise error await cleanup_global_client() - + @pytest.mark.asyncio async def test_url_building_edge_cases(self, mock_config): """Test URL building with various edge cases.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): client = APIClient() - + # Test trailing slash handling client.base_url = "https://api.example.com/" url = client._build_url("players") - assert url == "https://api.example.com/v3/players" - assert "//" not in url.replace("https://", "") - + assert url == "https://api.example.com/v3/players/" + assert "//" not in url.replace("https://", "").replace("//", "") + @pytest.mark.asyncio async def test_parameter_handling_edge_cases(self, mock_config): """Test parameter handling with various scenarios.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): client = APIClient() - + # Test with existing query string - url = client._add_params("https://example.com/api?existing=true", [("new", "param")]) + url = client._add_params( + "https://example.com/api?existing=true", [("new", "param")] + ) assert url == "https://example.com/api?existing=true&new=param" - + # Test with no parameters url = client._add_params("https://example.com/api") assert url == "https://example.com/api" - + @pytest.mark.asyncio async def test_timeout_error_handling(self, mock_config): """Test timeout error handling using aioresponses.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): client = APIClient() - + # Test timeout using aioresponses exception parameter with aioresponses() as m: m.get( - "https://api.example.com/v3/players", - exception=asyncio.TimeoutError("Request timed out") + "https://api.example.com/v3/players/", + exception=asyncio.TimeoutError("Request timed out"), ) - - with pytest.raises(APIException, match="API call failed.*Request timed out"): + + with pytest.raises( + APIException, match="API call failed.*Request timed out" + ): await client.get("players") - + await client.close() - + @pytest.mark.asyncio async def test_generic_exception_handling(self, mock_config): """Test generic exception handling.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): client = APIClient() - + # Test generic exception with aioresponses() as m: m.get( - "https://api.example.com/v3/players", - exception=Exception("Generic error") + "https://api.example.com/v3/players/", + exception=Exception("Generic error"), ) - - with pytest.raises(APIException, match="API call failed.*Generic error"): + + with pytest.raises( + APIException, match="API call failed.*Generic error" + ): await client.get("players") - + await client.close() - + @pytest.mark.asyncio async def test_session_closed_handling(self, mock_config): """Test handling of closed session.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): # Test that the client recreates session when needed with aioresponses() as m: m.get( - "https://api.example.com/v3/players", + "https://api.example.com/v3/players/", payload={"success": True}, - status=200 + status=200, ) - + client = APIClient() - + # Close the session manually await client._ensure_session() await client._session.close() - + # Client should recreate session and work fine result = await client.get("players") assert result == {"success": True} - - await client.close() \ No newline at end of file + + await client.close() From f6a25aa16d046bbca9df45187e3fce3ca9cd9fee Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Mon, 9 Mar 2026 17:50:58 -0500 Subject: [PATCH 5/7] fix: use targeted trailing slashes instead of universal (hotfix) Reverts universal trailing slash in _build_url which broke custom_commands endpoints (401 on /execute/). Instead, add trailing slashes only to the two batch POST endpoints (plays/, decisions/) that need them to avoid 307 redirects dropping request bodies. Co-Authored-By: Claude Opus 4.6 --- api/client.py | 4 --- services/decision_service.py | 39 ++++++++++++++-------------- services/play_service.py | 37 ++++++++++++-------------- tests/test_api_client.py | 50 ++++++++++++++++++------------------ 4 files changed, 61 insertions(+), 69 deletions(-) diff --git a/api/client.py b/api/client.py index b262fb9..4ce5425 100644 --- a/api/client.py +++ b/api/client.py @@ -88,10 +88,6 @@ class APIClient: encoded_id = quote(str(object_id), safe="") path += f"/{encoded_id}" - # Ensure trailing slash to prevent 307 redirects that drop POST bodies - if not path.endswith("/"): - path += "/" - return urljoin(self.base_url.rstrip("/") + "/", path) def _add_params(self, url: str, params: Optional[List[tuple]] = None) -> str: diff --git a/services/decision_service.py b/services/decision_service.py index e101d0e..11a36b5 100644 --- a/services/decision_service.py +++ b/services/decision_service.py @@ -3,6 +3,7 @@ Decision Service Manages pitching decision operations for game submission. """ + from typing import List, Dict, Any, Optional, Tuple from utils.logging import get_contextual_logger @@ -16,17 +17,14 @@ class DecisionService: def __init__(self): """Initialize decision service.""" - self.logger = get_contextual_logger(f'{__name__}.DecisionService') + self.logger = get_contextual_logger(f"{__name__}.DecisionService") self._get_client = get_global_client async def get_client(self): """Get the API client.""" return await self._get_client() - async def create_decisions_batch( - self, - decisions: List[Dict[str, Any]] - ) -> bool: + async def create_decisions_batch(self, decisions: List[Dict[str, Any]]) -> bool: """ POST batch of decisions to /decisions endpoint. @@ -42,8 +40,10 @@ class DecisionService: try: client = await self.get_client() - payload = {'decisions': decisions} - await client.post('decisions', payload) + payload = {"decisions": decisions} + # Trailing slash required: without it, the server returns a 307 redirect + # and aiohttp drops the POST body when following the redirect + await client.post("decisions/", payload) self.logger.info(f"Created {len(decisions)} decisions") return True @@ -70,7 +70,7 @@ class DecisionService: """ try: client = await self.get_client() - await client.delete(f'decisions/game/{game_id}') + await client.delete(f"decisions/game/{game_id}") self.logger.info(f"Deleted decisions for game {game_id}") return True @@ -80,9 +80,10 @@ class DecisionService: raise APIException(f"Failed to delete decisions: {e}") async def find_winning_losing_pitchers( - self, - decisions_data: List[Dict[str, Any]] - ) -> Tuple[Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]]: + self, decisions_data: List[Dict[str, Any]] + ) -> Tuple[ + Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player] + ]: """ Extract WP, LP, SV, Holds, Blown Saves from decisions list and fetch Player objects. @@ -110,17 +111,17 @@ class DecisionService: # First pass: Extract IDs for decision in decisions_data: - pitcher_id = int(decision.get('pitcher_id', 0)) + pitcher_id = int(decision.get("pitcher_id", 0)) - if int(decision.get('win', 0)) == 1: + if int(decision.get("win", 0)) == 1: wp_id = pitcher_id - if int(decision.get('loss', 0)) == 1: + if int(decision.get("loss", 0)) == 1: lp_id = pitcher_id - if int(decision.get('is_save', 0)) == 1: + if int(decision.get("is_save", 0)) == 1: sv_id = pitcher_id - if int(decision.get('hold', 0)) == 1: + if int(decision.get("hold", 0)) == 1: hold_ids.append(pitcher_id) - if int(decision.get('b_save', 0)) == 1: + if int(decision.get("b_save", 0)) == 1: bsv_ids.append(pitcher_id) # Second pass: Fetch Player objects @@ -154,9 +155,9 @@ class DecisionService: """ error_str = str(error) - if 'Player ID' in error_str and 'not found' in error_str: + if "Player ID" in error_str and "not found" in error_str: return "Invalid pitcher ID in decision data." - elif 'Game ID' in error_str and 'not found' in error_str: + elif "Game ID" in error_str and "not found" in error_str: return "Game not found for decisions." else: return f"Error submitting decisions: {error_str}" diff --git a/services/play_service.py b/services/play_service.py index 7b08bf6..cdaf293 100644 --- a/services/play_service.py +++ b/services/play_service.py @@ -3,6 +3,7 @@ Play Service Manages play-by-play data operations for game submission. """ + from typing import List, Dict, Any from utils.logging import get_contextual_logger @@ -16,7 +17,7 @@ class PlayService: def __init__(self): """Initialize play service.""" - self.logger = get_contextual_logger(f'{__name__}.PlayService') + self.logger = get_contextual_logger(f"{__name__}.PlayService") self._get_client = get_global_client async def get_client(self): @@ -39,8 +40,10 @@ class PlayService: try: client = await self.get_client() - payload = {'plays': plays} - response = await client.post('plays', payload) + payload = {"plays": plays} + # Trailing slash required: without it, the server returns a 307 redirect + # and aiohttp drops the POST body when following the redirect + response = await client.post("plays/", payload) self.logger.info(f"Created {len(plays)} plays") return True @@ -68,7 +71,7 @@ class PlayService: """ try: client = await self.get_client() - response = await client.delete(f'plays/game/{game_id}') + response = await client.delete(f"plays/game/{game_id}") self.logger.info(f"Deleted plays for game {game_id}") return True @@ -77,11 +80,7 @@ class PlayService: self.logger.error(f"Failed to delete plays for game {game_id}: {e}") raise APIException(f"Failed to delete plays: {e}") - async def get_top_plays_by_wpa( - self, - game_id: int, - limit: int = 3 - ) -> List[Play]: + async def get_top_plays_by_wpa(self, game_id: int, limit: int = 3) -> List[Play]: """ Get top plays by WPA (absolute value) for key plays display. @@ -95,19 +94,15 @@ class PlayService: try: client = await self.get_client() - params = [ - ('game_id', game_id), - ('sort', 'wpa-desc'), - ('limit', limit) - ] + params = [("game_id", game_id), ("sort", "wpa-desc"), ("limit", limit)] - response = await client.get('plays', params=params) + response = await client.get("plays", params=params) - if not response or 'plays' not in response: - self.logger.info(f'No plays found for game ID {game_id}') + if not response or "plays" not in response: + self.logger.info(f"No plays found for game ID {game_id}") return [] - plays = [Play.from_api_data(p) for p in response['plays']] + plays = [Play.from_api_data(p) for p in response["plays"]] self.logger.debug(f"Retrieved {len(plays)} top plays for game {game_id}") return plays @@ -129,11 +124,11 @@ class PlayService: error_str = str(error) # Common error patterns - if 'Player ID' in error_str and 'not found' in error_str: + if "Player ID" in error_str and "not found" in error_str: return "Invalid player ID in scorecard data. Please check player IDs." - elif 'Game ID' in error_str and 'not found' in error_str: + elif "Game ID" in error_str and "not found" in error_str: return "Game not found in database. Please contact an admin." - elif 'validation' in error_str.lower(): + elif "validation" in error_str.lower(): return f"Data validation error: {error_str}" else: return f"Error submitting plays: {error_str}" diff --git a/tests/test_api_client.py b/tests/test_api_client.py index b0b16cc..c2ca5b2 100644 --- a/tests/test_api_client.py +++ b/tests/test_api_client.py @@ -40,7 +40,7 @@ class TestAPIClientWithAioresponses: with aioresponses() as m: m.get( - "https://api.example.com/v3/players/1/", + "https://api.example.com/v3/players/1", payload=expected_data, status=200, ) @@ -53,7 +53,7 @@ class TestAPIClientWithAioresponses: async def test_get_request_404(self, api_client): """Test GET request returning 404.""" with aioresponses() as m: - m.get("https://api.example.com/v3/players/999/", status=404) + m.get("https://api.example.com/v3/players/999", status=404) result = await api_client.get("players", object_id=999) @@ -63,7 +63,7 @@ class TestAPIClientWithAioresponses: async def test_get_request_401_auth_error(self, api_client): """Test GET request with authentication error.""" with aioresponses() as m: - m.get("https://api.example.com/v3/players/", status=401) + m.get("https://api.example.com/v3/players", status=401) with pytest.raises(APIException, match="Authentication failed"): await api_client.get("players") @@ -72,7 +72,7 @@ class TestAPIClientWithAioresponses: async def test_get_request_403_forbidden(self, api_client): """Test GET request with forbidden error.""" with aioresponses() as m: - m.get("https://api.example.com/v3/players/", status=403) + m.get("https://api.example.com/v3/players", status=403) with pytest.raises(APIException, match="Access forbidden"): await api_client.get("players") @@ -82,7 +82,7 @@ class TestAPIClientWithAioresponses: """Test GET request with server error.""" with aioresponses() as m: m.get( - "https://api.example.com/v3/players/", + "https://api.example.com/v3/players", status=500, body="Internal Server Error", ) @@ -99,7 +99,7 @@ class TestAPIClientWithAioresponses: with aioresponses() as m: m.get( - "https://api.example.com/v3/players/?team_id=5&season=12", + "https://api.example.com/v3/players?team_id=5&season=12", payload=expected_data, status=200, ) @@ -118,7 +118,7 @@ class TestAPIClientWithAioresponses: with aioresponses() as m: m.post( - "https://api.example.com/v3/players/", + "https://api.example.com/v3/players", payload=expected_response, status=201, ) @@ -134,7 +134,7 @@ class TestAPIClientWithAioresponses: with aioresponses() as m: m.post( - "https://api.example.com/v3/players/", status=400, body="Invalid data" + "https://api.example.com/v3/players", status=400, body="Invalid data" ) with pytest.raises( @@ -150,7 +150,7 @@ class TestAPIClientWithAioresponses: with aioresponses() as m: m.put( - "https://api.example.com/v3/players/1/", + "https://api.example.com/v3/players/1", payload=expected_response, status=200, ) @@ -165,7 +165,7 @@ class TestAPIClientWithAioresponses: update_data = {"name": "Updated Player"} with aioresponses() as m: - m.put("https://api.example.com/v3/players/999/", status=404) + m.put("https://api.example.com/v3/players/999", status=404) result = await api_client.put("players", update_data, object_id=999) @@ -175,7 +175,7 @@ class TestAPIClientWithAioresponses: async def test_delete_request_success(self, api_client): """Test successful DELETE request.""" with aioresponses() as m: - m.delete("https://api.example.com/v3/players/1/", status=204) + m.delete("https://api.example.com/v3/players/1", status=204) result = await api_client.delete("players", object_id=1) @@ -185,7 +185,7 @@ class TestAPIClientWithAioresponses: async def test_delete_request_404(self, api_client): """Test DELETE request with 404.""" with aioresponses() as m: - m.delete("https://api.example.com/v3/players/999/", status=404) + m.delete("https://api.example.com/v3/players/999", status=404) result = await api_client.delete("players", object_id=999) @@ -195,7 +195,7 @@ class TestAPIClientWithAioresponses: async def test_delete_request_200_success(self, api_client): """Test DELETE request with 200 success.""" with aioresponses() as m: - m.delete("https://api.example.com/v3/players/1/", status=200) + m.delete("https://api.example.com/v3/players/1", status=200) result = await api_client.delete("players", object_id=1) @@ -219,7 +219,7 @@ class TestAPIClientHelpers: with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: m.get( - "https://api.example.com/v3/test/", + "https://api.example.com/v3/test", payload={"success": True}, status=200, ) @@ -279,7 +279,7 @@ class TestIntegrationScenarios: "pos_1": "C", } m.get( - "https://api.example.com/v3/players/1/", + "https://api.example.com/v3/players/1", payload=player_data, status=200, ) @@ -293,7 +293,7 @@ class TestIntegrationScenarios: "season": 12, } m.get( - "https://api.example.com/v3/teams/5/", payload=team_data, status=200 + "https://api.example.com/v3/teams/5", payload=team_data, status=200 ) client = APIClient() @@ -338,7 +338,7 @@ class TestIntegrationScenarios: } m.get( - "https://api.example.com/v3/players/?team_id=5", + "https://api.example.com/v3/players?team_id=5", payload=api_response, status=200, ) @@ -357,14 +357,14 @@ class TestIntegrationScenarios: with aioresponses() as m: # First request fails with 500 m.get( - "https://api.example.com/v3/players/1/", + "https://api.example.com/v3/players/1", status=500, body="Internal Server Error", ) # Second request succeeds m.get( - "https://api.example.com/v3/players/2/", + "https://api.example.com/v3/players/2", payload={"id": 2, "name": "Working Player"}, status=200, ) @@ -392,7 +392,7 @@ class TestIntegrationScenarios: # Mock multiple endpoints for i in range(1, 4): m.get( - f"https://api.example.com/v3/players/{i}/", + f"https://api.example.com/v3/players/{i}", payload={"id": i, "name": f"Player {i}"}, status=200, ) @@ -445,8 +445,8 @@ class TestAPIClientCoverageExtras: # Test trailing slash handling client.base_url = "https://api.example.com/" url = client._build_url("players") - assert url == "https://api.example.com/v3/players/" - assert "//" not in url.replace("https://", "").replace("//", "") + assert url == "https://api.example.com/v3/players" + assert "//" not in url.replace("https://", "") @pytest.mark.asyncio async def test_parameter_handling_edge_cases(self, mock_config): @@ -473,7 +473,7 @@ class TestAPIClientCoverageExtras: # Test timeout using aioresponses exception parameter with aioresponses() as m: m.get( - "https://api.example.com/v3/players/", + "https://api.example.com/v3/players", exception=asyncio.TimeoutError("Request timed out"), ) @@ -493,7 +493,7 @@ class TestAPIClientCoverageExtras: # Test generic exception with aioresponses() as m: m.get( - "https://api.example.com/v3/players/", + "https://api.example.com/v3/players", exception=Exception("Generic error"), ) @@ -511,7 +511,7 @@ class TestAPIClientCoverageExtras: # Test that the client recreates session when needed with aioresponses() as m: m.get( - "https://api.example.com/v3/players/", + "https://api.example.com/v3/players", payload={"success": True}, status=200, ) From ba55ed3109bc0dd2f4c16b8fd5a13f43e0679fda Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Mon, 9 Mar 2026 19:34:36 -0500 Subject: [PATCH 6/7] fix: add trailing slashes to all collection POST calls Ensures all client.post() calls to collection endpoints include trailing slashes, matching the standardized database API routes. Covers BaseService.create(), TransactionService, InjuryService, and DraftListService POST calls. Co-Authored-By: Claude Opus 4.6 --- services/base_service.py | 2 +- services/draft_list_service.py | 222 ++++++++++++++-------------- services/injury_service.py | 83 ++++++----- services/transaction_service.py | 2 +- tests/test_services_base_service.py | 197 ++++++++++++------------ 5 files changed, 252 insertions(+), 254 deletions(-) diff --git a/services/base_service.py b/services/base_service.py index faf0dba..7046efb 100644 --- a/services/base_service.py +++ b/services/base_service.py @@ -245,7 +245,7 @@ class BaseService(Generic[T]): """ try: client = await self.get_client() - response = await client.post(self.endpoint, model_data) + response = await client.post(f"{self.endpoint}/", model_data) if not response: logger.warning(f"No response from {self.model_class.__name__} creation") diff --git a/services/draft_list_service.py b/services/draft_list_service.py index 2a7017d..934a44f 100644 --- a/services/draft_list_service.py +++ b/services/draft_list_service.py @@ -3,13 +3,14 @@ Draft list service for Discord Bot v2.0 Handles team draft list (auto-draft queue) operations. NO CACHING - lists change frequently. """ + import logging from typing import Optional, List from services.base_service import BaseService from models.draft_list import DraftList -logger = logging.getLogger(f'{__name__}.DraftListService') +logger = logging.getLogger(f"{__name__}.DraftListService") class DraftListService(BaseService[DraftList]): @@ -32,7 +33,7 @@ class DraftListService(BaseService[DraftList]): def __init__(self): """Initialize draft list service.""" - super().__init__(DraftList, 'draftlist') + super().__init__(DraftList, "draftlist") logger.debug("DraftListService initialized") def _extract_items_and_count_from_response(self, data): @@ -54,20 +55,16 @@ class DraftListService(BaseService[DraftList]): return [], 0 # Get count - count = data.get('count', 0) + count = data.get("count", 0) # API returns items under 'picks' key (not 'draftlist') - if 'picks' in data and isinstance(data['picks'], list): - return data['picks'], count or len(data['picks']) + if "picks" in data and isinstance(data["picks"], list): + return data["picks"], count or len(data["picks"]) # Fallback to standard extraction return super()._extract_items_and_count_from_response(data) - async def get_team_list( - self, - season: int, - team_id: int - ) -> List[DraftList]: + async def get_team_list(self, season: int, team_id: int) -> List[DraftList]: """ Get team's draft list ordered by rank. @@ -82,8 +79,8 @@ class DraftListService(BaseService[DraftList]): """ try: params = [ - ('season', str(season)), - ('team_id', str(team_id)) + ("season", str(season)), + ("team_id", str(team_id)), # NOTE: API does not support 'sort' param - results must be sorted client-side ] @@ -100,11 +97,7 @@ class DraftListService(BaseService[DraftList]): return [] async def add_to_list( - self, - season: int, - team_id: int, - player_id: int, - rank: Optional[int] = None + self, season: int, team_id: int, player_id: int, rank: Optional[int] = None ) -> Optional[List[DraftList]]: """ Add player to team's draft list. @@ -133,10 +126,10 @@ class DraftListService(BaseService[DraftList]): # Create new entry data new_entry_data = { - 'season': season, - 'team_id': team_id, - 'player_id': player_id, - 'rank': rank + "season": season, + "team_id": team_id, + "player_id": player_id, + "rank": rank, } # Build complete list for bulk replacement @@ -146,36 +139,42 @@ class DraftListService(BaseService[DraftList]): for entry in current_list: if entry.rank >= rank: # Shift down entries at or after insertion point - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': entry.rank + 1 - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": entry.rank + 1, + } + ) else: # Keep existing rank for entries before insertion point - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': entry.rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": entry.rank, + } + ) # Add new entry draft_list_entries.append(new_entry_data) # Sort by rank for consistency - draft_list_entries.sort(key=lambda x: x['rank']) + draft_list_entries.sort(key=lambda x: x["rank"]) # POST entire list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - logger.debug(f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries") - response = await client.post(self.endpoint, payload) + logger.debug( + f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries" + ) + response = await client.post(f"{self.endpoint}/", payload) logger.debug(f"POST response: {response}") # Verify by fetching the list back (API returns full objects) @@ -184,20 +183,21 @@ class DraftListService(BaseService[DraftList]): # Verify the player was added if not any(entry.player_id == player_id for entry in verification): - logger.error(f"Player {player_id} not found in list after POST - operation may have failed") + logger.error( + f"Player {player_id} not found in list after POST - operation may have failed" + ) return None - logger.info(f"Added player {player_id} to team {team_id} draft list at rank {rank}") + logger.info( + f"Added player {player_id} to team {team_id} draft list at rank {rank}" + ) return verification # Return full updated list except Exception as e: logger.error(f"Error adding player {player_id} to draft list: {e}") return None - async def remove_from_list( - self, - entry_id: int - ) -> bool: + async def remove_from_list(self, entry_id: int) -> bool: """ Remove entry from draft list by ID. @@ -209,14 +209,13 @@ class DraftListService(BaseService[DraftList]): Returns: True if deletion succeeded """ - logger.warning("remove_from_list() called with entry_id - use remove_player_from_list() instead") + logger.warning( + "remove_from_list() called with entry_id - use remove_player_from_list() instead" + ) return False async def remove_player_from_list( - self, - season: int, - team_id: int, - player_id: int + self, season: int, team_id: int, player_id: int ) -> bool: """ Remove specific player from team's draft list. @@ -238,7 +237,9 @@ class DraftListService(BaseService[DraftList]): # Check if player is in list player_found = any(entry.player_id == player_id for entry in current_list) if not player_found: - logger.warning(f"Player {player_id} not found in team {team_id} draft list") + logger.warning( + f"Player {player_id} not found in team {team_id} draft list" + ) return False # Build new list without the player, adjusting ranks @@ -246,22 +247,24 @@ class DraftListService(BaseService[DraftList]): new_rank = 1 for entry in current_list: if entry.player_id != player_id: - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': new_rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": new_rank, + } + ) new_rank += 1 # POST updated list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - await client.post(self.endpoint, payload) + await client.post(f"{self.endpoint}/", payload) logger.info(f"Removed player {player_id} from team {team_id} draft list") return True @@ -270,11 +273,7 @@ class DraftListService(BaseService[DraftList]): logger.error(f"Error removing player {player_id} from draft list: {e}") return False - async def clear_list( - self, - season: int, - team_id: int - ) -> bool: + async def clear_list(self, season: int, team_id: int) -> bool: """ Clear entire draft list for team. @@ -309,10 +308,7 @@ class DraftListService(BaseService[DraftList]): return False async def reorder_list( - self, - season: int, - team_id: int, - new_order: List[int] + self, season: int, team_id: int, new_order: List[int] ) -> bool: """ Reorder team's draft list. @@ -342,21 +338,23 @@ class DraftListService(BaseService[DraftList]): continue entry = entry_map[player_id] - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': new_rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": new_rank, + } + ) # POST reordered list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - await client.post(self.endpoint, payload) + await client.post(f"{self.endpoint}/", payload) logger.info(f"Reordered draft list for team {team_id}") return True @@ -365,12 +363,7 @@ class DraftListService(BaseService[DraftList]): logger.error(f"Error reordering draft list for team {team_id}: {e}") return False - async def move_entry_up( - self, - season: int, - team_id: int, - player_id: int - ) -> bool: + async def move_entry_up(self, season: int, team_id: int, player_id: int) -> bool: """ Move player up one position in draft list (higher priority). @@ -403,7 +396,9 @@ class DraftListService(BaseService[DraftList]): return False # Find entry above (rank - 1) - above_entry = next((e for e in entries if e.rank == current_entry.rank - 1), None) + above_entry = next( + (e for e in entries if e.rank == current_entry.rank - 1), None + ) if not above_entry: logger.error(f"Could not find entry above rank {current_entry.rank}") return False @@ -421,24 +416,26 @@ class DraftListService(BaseService[DraftList]): # Keep existing rank new_rank = entry.rank - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': new_rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": new_rank, + } + ) # Sort by rank - draft_list_entries.sort(key=lambda x: x['rank']) + draft_list_entries.sort(key=lambda x: x["rank"]) # POST updated list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - await client.post(self.endpoint, payload) + await client.post(f"{self.endpoint}/", payload) logger.info(f"Moved player {player_id} up to rank {current_entry.rank - 1}") return True @@ -447,12 +444,7 @@ class DraftListService(BaseService[DraftList]): logger.error(f"Error moving player {player_id} up in draft list: {e}") return False - async def move_entry_down( - self, - season: int, - team_id: int, - player_id: int - ) -> bool: + async def move_entry_down(self, season: int, team_id: int, player_id: int) -> bool: """ Move player down one position in draft list (lower priority). @@ -485,7 +477,9 @@ class DraftListService(BaseService[DraftList]): return False # Find entry below (rank + 1) - below_entry = next((e for e in entries if e.rank == current_entry.rank + 1), None) + below_entry = next( + (e for e in entries if e.rank == current_entry.rank + 1), None + ) if not below_entry: logger.error(f"Could not find entry below rank {current_entry.rank}") return False @@ -503,25 +497,29 @@ class DraftListService(BaseService[DraftList]): # Keep existing rank new_rank = entry.rank - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': new_rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": new_rank, + } + ) # Sort by rank - draft_list_entries.sort(key=lambda x: x['rank']) + draft_list_entries.sort(key=lambda x: x["rank"]) # POST updated list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - await client.post(self.endpoint, payload) - logger.info(f"Moved player {player_id} down to rank {current_entry.rank + 1}") + await client.post(f"{self.endpoint}/", payload) + logger.info( + f"Moved player {player_id} down to rank {current_entry.rank + 1}" + ) return True diff --git a/services/injury_service.py b/services/injury_service.py index 555a7a5..0164447 100644 --- a/services/injury_service.py +++ b/services/injury_service.py @@ -3,13 +3,14 @@ Injury service for Discord Bot v2.0 Handles injury-related operations including checking, creating, and clearing injuries. """ + import logging from typing import Optional, List from services.base_service import BaseService from models.injury import Injury -logger = logging.getLogger(f'{__name__}.InjuryService') +logger = logging.getLogger(f"{__name__}.InjuryService") class InjuryService(BaseService[Injury]): @@ -25,7 +26,7 @@ class InjuryService(BaseService[Injury]): def __init__(self): """Initialize injury service.""" - super().__init__(Injury, 'injuries') + super().__init__(Injury, "injuries") logger.debug("InjuryService initialized") async def get_active_injury(self, player_id: int, season: int) -> Optional[Injury]: @@ -41,25 +42,31 @@ class InjuryService(BaseService[Injury]): """ try: params = [ - ('player_id', str(player_id)), - ('season', str(season)), - ('is_active', 'true') + ("player_id", str(player_id)), + ("season", str(season)), + ("is_active", "true"), ] injuries = await self.get_all_items(params=params) if injuries: - logger.debug(f"Found active injury for player {player_id} in season {season}") + logger.debug( + f"Found active injury for player {player_id} in season {season}" + ) return injuries[0] - logger.debug(f"No active injury found for player {player_id} in season {season}") + logger.debug( + f"No active injury found for player {player_id} in season {season}" + ) return None except Exception as e: logger.error(f"Error getting active injury for player {player_id}: {e}") return None - async def get_injuries_by_player(self, player_id: int, season: int, active_only: bool = False) -> List[Injury]: + async def get_injuries_by_player( + self, player_id: int, season: int, active_only: bool = False + ) -> List[Injury]: """ Get all injuries for a player in a specific season. @@ -72,13 +79,10 @@ class InjuryService(BaseService[Injury]): List of injuries for the player """ try: - params = [ - ('player_id', str(player_id)), - ('season', str(season)) - ] + params = [("player_id", str(player_id)), ("season", str(season))] if active_only: - params.append(('is_active', 'true')) + params.append(("is_active", "true")) injuries = await self.get_all_items(params=params) logger.debug(f"Retrieved {len(injuries)} injuries for player {player_id}") @@ -88,7 +92,9 @@ class InjuryService(BaseService[Injury]): logger.error(f"Error getting injuries for player {player_id}: {e}") return [] - async def get_injuries_by_team(self, team_id: int, season: int, active_only: bool = True) -> List[Injury]: + async def get_injuries_by_team( + self, team_id: int, season: int, active_only: bool = True + ) -> List[Injury]: """ Get all injuries for a team in a specific season. @@ -101,13 +107,10 @@ class InjuryService(BaseService[Injury]): List of injuries for the team """ try: - params = [ - ('team_id', str(team_id)), - ('season', str(season)) - ] + params = [("team_id", str(team_id)), ("season", str(season))] if active_only: - params.append(('is_active', 'true')) + params.append(("is_active", "true")) injuries = await self.get_all_items(params=params) logger.debug(f"Retrieved {len(injuries)} injuries for team {team_id}") @@ -125,7 +128,7 @@ class InjuryService(BaseService[Injury]): start_week: int, start_game: int, end_week: int, - end_game: int + end_game: int, ) -> Optional[Injury]: """ Create a new injury record. @@ -144,22 +147,24 @@ class InjuryService(BaseService[Injury]): """ try: injury_data = { - 'season': season, - 'player_id': player_id, - 'total_games': total_games, - 'start_week': start_week, - 'start_game': start_game, - 'end_week': end_week, - 'end_game': end_game, - 'is_active': True + "season": season, + "player_id": player_id, + "total_games": total_games, + "start_week": start_week, + "start_game": start_game, + "end_week": end_week, + "end_game": end_game, + "is_active": True, } # Call the API to create the injury client = await self.get_client() - response = await client.post(self.endpoint, injury_data) + response = await client.post(f"{self.endpoint}/", injury_data) if not response: - logger.error(f"Failed to create injury for player {player_id}: No response from API") + logger.error( + f"Failed to create injury for player {player_id}: No response from API" + ) return None # Merge the request data with the response to ensure all required fields are present @@ -187,7 +192,9 @@ class InjuryService(BaseService[Injury]): """ try: # Note: API expects is_active as query parameter, not JSON body - updated_injury = await self.patch(injury_id, {'is_active': False}, use_query_params=True) + updated_injury = await self.patch( + injury_id, {"is_active": False}, use_query_params=True + ) if updated_injury: logger.info(f"Cleared injury {injury_id}") @@ -216,16 +223,18 @@ class InjuryService(BaseService[Injury]): try: client = await self.get_client() params = [ - ('season', str(season)), - ('is_active', 'true'), - ('sort', 'return-asc') + ("season", str(season)), + ("is_active", "true"), + ("sort", "return-asc"), ] response = await client.get(self.endpoint, params=params) - if response and 'injuries' in response: - logger.debug(f"Retrieved {len(response['injuries'])} active injuries for season {season}") - return response['injuries'] + if response and "injuries" in response: + logger.debug( + f"Retrieved {len(response['injuries'])} active injuries for season {season}" + ) + return response["injuries"] logger.debug(f"No active injuries found for season {season}") return [] diff --git a/services/transaction_service.py b/services/transaction_service.py index 80ce0e5..57c9c90 100644 --- a/services/transaction_service.py +++ b/services/transaction_service.py @@ -248,7 +248,7 @@ class TransactionService(BaseService[Transaction]): # POST batch to API client = await self.get_client() - response = await client.post(self.endpoint, data=batch_data) + response = await client.post(f"{self.endpoint}/", data=batch_data) # API returns a string like "2 transactions have been added" # We need to return the original Transaction objects (they won't have IDs assigned by API) diff --git a/tests/test_services_base_service.py b/tests/test_services_base_service.py index 4ef6602..19a3d42 100644 --- a/tests/test_services_base_service.py +++ b/tests/test_services_base_service.py @@ -1,6 +1,7 @@ """ Tests for BaseService functionality """ + import pytest from unittest.mock import AsyncMock @@ -10,6 +11,7 @@ from models.base import SBABaseModel class MockModel(SBABaseModel): """Mock model for testing BaseService.""" + id: int name: str value: int = 100 @@ -17,240 +19,229 @@ class MockModel(SBABaseModel): class TestBaseService: """Test BaseService functionality.""" - + @pytest.fixture def mock_client(self): """Mock API client.""" client = AsyncMock() return client - + @pytest.fixture def base_service(self, mock_client): """Create BaseService instance for testing.""" - service = BaseService(MockModel, 'mocks', client=mock_client) + service = BaseService(MockModel, "mocks", client=mock_client) return service - + @pytest.mark.asyncio async def test_init(self): """Test service initialization.""" - service = BaseService(MockModel, 'test_endpoint') + service = BaseService(MockModel, "test_endpoint") assert service.model_class == MockModel - assert service.endpoint == 'test_endpoint' + assert service.endpoint == "test_endpoint" assert service._client is None - + @pytest.mark.asyncio async def test_get_by_id_success(self, base_service, mock_client): """Test successful get_by_id.""" - mock_data = {'id': 1, 'name': 'Test', 'value': 200} + mock_data = {"id": 1, "name": "Test", "value": 200} mock_client.get.return_value = mock_data - + result = await base_service.get_by_id(1) - + assert isinstance(result, MockModel) assert result.id == 1 - assert result.name == 'Test' + assert result.name == "Test" assert result.value == 200 - mock_client.get.assert_called_once_with('mocks', object_id=1) - + mock_client.get.assert_called_once_with("mocks", object_id=1) + @pytest.mark.asyncio async def test_get_by_id_not_found(self, base_service, mock_client): """Test get_by_id when object not found.""" mock_client.get.return_value = None - + result = await base_service.get_by_id(999) - + assert result is None - mock_client.get.assert_called_once_with('mocks', object_id=999) - + mock_client.get.assert_called_once_with("mocks", object_id=999) + @pytest.mark.asyncio async def test_get_all_with_count(self, base_service, mock_client): """Test get_all with count response format.""" mock_data = { - 'count': 2, - 'mocks': [ - {'id': 1, 'name': 'Test1', 'value': 100}, - {'id': 2, 'name': 'Test2', 'value': 200} - ] + "count": 2, + "mocks": [ + {"id": 1, "name": "Test1", "value": 100}, + {"id": 2, "name": "Test2", "value": 200}, + ], } mock_client.get.return_value = mock_data - + result, count = await base_service.get_all() - + assert len(result) == 2 assert count == 2 assert all(isinstance(item, MockModel) for item in result) - mock_client.get.assert_called_once_with('mocks', params=None) - + mock_client.get.assert_called_once_with("mocks", params=None) + @pytest.mark.asyncio async def test_get_all_items_convenience(self, base_service, mock_client): """Test get_all_items convenience method.""" - mock_data = { - 'count': 1, - 'mocks': [{'id': 1, 'name': 'Test', 'value': 100}] - } + mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]} mock_client.get.return_value = mock_data - + result = await base_service.get_all_items() - + assert len(result) == 1 assert isinstance(result[0], MockModel) - + @pytest.mark.asyncio async def test_create_success(self, base_service, mock_client): """Test successful object creation.""" - input_data = {'name': 'New Item', 'value': 300} - response_data = {'id': 3, 'name': 'New Item', 'value': 300} + input_data = {"name": "New Item", "value": 300} + response_data = {"id": 3, "name": "New Item", "value": 300} mock_client.post.return_value = response_data - + result = await base_service.create(input_data) - + assert isinstance(result, MockModel) assert result.id == 3 - assert result.name == 'New Item' - mock_client.post.assert_called_once_with('mocks', input_data) - + assert result.name == "New Item" + mock_client.post.assert_called_once_with("mocks/", input_data) + @pytest.mark.asyncio async def test_update_success(self, base_service, mock_client): """Test successful object update.""" - update_data = {'name': 'Updated'} - response_data = {'id': 1, 'name': 'Updated', 'value': 100} + update_data = {"name": "Updated"} + response_data = {"id": 1, "name": "Updated", "value": 100} mock_client.put.return_value = response_data - + result = await base_service.update(1, update_data) - + assert isinstance(result, MockModel) - assert result.name == 'Updated' - mock_client.put.assert_called_once_with('mocks', update_data, object_id=1) - + assert result.name == "Updated" + mock_client.put.assert_called_once_with("mocks", update_data, object_id=1) + @pytest.mark.asyncio async def test_delete_success(self, base_service, mock_client): """Test successful object deletion.""" mock_client.delete.return_value = True - + result = await base_service.delete(1) - + assert result is True - mock_client.delete.assert_called_once_with('mocks', object_id=1) - - + mock_client.delete.assert_called_once_with("mocks", object_id=1) + @pytest.mark.asyncio async def test_get_by_field(self, base_service, mock_client): """Test get_by_field functionality.""" - mock_data = { - 'count': 1, - 'mocks': [{'id': 1, 'name': 'Test', 'value': 100}] - } + mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]} mock_client.get.return_value = mock_data - - result = await base_service.get_by_field('name', 'Test') - + + result = await base_service.get_by_field("name", "Test") + assert len(result) == 1 - mock_client.get.assert_called_once_with('mocks', params=[('name', 'Test')]) - + mock_client.get.assert_called_once_with("mocks", params=[("name", "Test")]) + def test_extract_items_and_count_standard_format(self, base_service): """Test response parsing for standard format.""" data = { - 'count': 3, - 'mocks': [ - {'id': 1, 'name': 'Test1'}, - {'id': 2, 'name': 'Test2'}, - {'id': 3, 'name': 'Test3'} - ] + "count": 3, + "mocks": [ + {"id": 1, "name": "Test1"}, + {"id": 2, "name": "Test2"}, + {"id": 3, "name": "Test3"}, + ], } - + items, count = base_service._extract_items_and_count_from_response(data) - + assert len(items) == 3 assert count == 3 - assert items[0]['name'] == 'Test1' - + assert items[0]["name"] == "Test1" + def test_extract_items_and_count_single_object(self, base_service): """Test response parsing for single object.""" - data = {'id': 1, 'name': 'Single'} - + data = {"id": 1, "name": "Single"} + items, count = base_service._extract_items_and_count_from_response(data) - + assert len(items) == 1 assert count == 1 assert items[0] == data - + def test_extract_items_and_count_direct_list(self, base_service): """Test response parsing for direct list.""" - data = [ - {'id': 1, 'name': 'Test1'}, - {'id': 2, 'name': 'Test2'} - ] - + data = [{"id": 1, "name": "Test1"}, {"id": 2, "name": "Test2"}] + items, count = base_service._extract_items_and_count_from_response(data) - + assert len(items) == 2 assert count == 2 class TestBaseServiceExtras: """Additional coverage tests for BaseService edge cases.""" - + @pytest.mark.asyncio async def test_base_service_additional_methods(self): """Test additional BaseService methods for coverage.""" from services.base_service import BaseService from models.base import SBABaseModel - + class TestModel(SBABaseModel): name: str value: int = 100 - + mock_client = AsyncMock() - service = BaseService(TestModel, 'test', client=mock_client) - - + service = BaseService(TestModel, "test", client=mock_client) + # Test count method mock_client.reset_mock() - mock_client.get.return_value = {'count': 42, 'test': []} - count = await service.count(params=[('active', 'true')]) + mock_client.get.return_value = {"count": 42, "test": []} + count = await service.count(params=[("active", "true")]) assert count == 42 - + # Test update_from_model with ID mock_client.reset_mock() model = TestModel(id=1, name="Updated", value=300) mock_client.put.return_value = {"id": 1, "name": "Updated", "value": 300} result = await service.update_from_model(model) assert result.name == "Updated" - + # Test update_from_model without ID model_no_id = TestModel(name="Test") with pytest.raises(ValueError, match="Cannot update TestModel without ID"): await service.update_from_model(model_no_id) - + def test_base_service_response_parsing_edge_cases(self): """Test edge cases in response parsing.""" from services.base_service import BaseService from models.base import SBABaseModel - + class TestModel(SBABaseModel): name: str - - service = BaseService(TestModel, 'test') - + + service = BaseService(TestModel, "test") + # Test with 'items' field - data = {'count': 2, 'items': [{'name': 'Item1'}, {'name': 'Item2'}]} + data = {"count": 2, "items": [{"name": "Item1"}, {"name": "Item2"}]} items, count = service._extract_items_and_count_from_response(data) assert len(items) == 2 assert count == 2 - + # Test with 'data' field - data = {'count': 1, 'data': [{'name': 'DataItem'}]} + data = {"count": 1, "data": [{"name": "DataItem"}]} items, count = service._extract_items_and_count_from_response(data) assert len(items) == 1 assert count == 1 - + # Test with count but no recognizable list field - data = {'count': 5, 'unknown_field': [{'name': 'Item'}]} + data = {"count": 5, "unknown_field": [{"name": "Item"}]} items, count = service._extract_items_and_count_from_response(data) assert len(items) == 0 assert count == 5 - + # Test with unexpected data type items, count = service._extract_items_and_count_from_response("unexpected") assert len(items) == 0 - assert count == 0 \ No newline at end of file + assert count == 0 From 88edd1fa1065f2893f7a45e4b49183b26b8e25ab Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Tue, 10 Mar 2026 01:03:27 -0500 Subject: [PATCH 7/7] chore: pin all Python dependency versions in requirements.txt (#76) - Pin redis==7.3.0 and move to requirements.txt (production) - Create requirements-dev.txt with all dev/test deps pinned to exact versions (pytest-mock==3.15.1, black==26.1.0, ruff==0.15.0) - Remove dev/test tools from requirements.txt (not needed in Docker image) - Document pinning policy and requirements-dev.txt usage in CLAUDE.md Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 17 +++++++++++++++++ requirements-dev.txt | 9 +++++++++ requirements.txt | 12 ++---------- 3 files changed, 28 insertions(+), 10 deletions(-) create mode 100644 requirements-dev.txt diff --git a/CLAUDE.md b/CLAUDE.md index 3e1a819..1f1a090 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -81,6 +81,23 @@ class MyCog(commands.Cog): - API errors → verify `DB_URL` points to correct database API and `API_TOKEN` matches - Redis errors are non-fatal (graceful fallback when `REDIS_URL` is empty) +## Dependencies + +### Pinning Policy +All dependencies are pinned to exact versions (`==`). This ensures every Docker build +produces an identical image — a `git revert` actually rolls back to the previous working state. + +- **`requirements.txt`** — production runtime deps only (used by Dockerfile) +- **`requirements-dev.txt`** — includes `-r requirements.txt` plus dev/test tools + +When installing for local development or running tests: +```bash +pip install -r requirements-dev.txt +``` + +When upgrading a dependency, update BOTH the `==` pin and (if applicable) the comment in +the file. Test before committing. Never use `>=` or `~=` constraints. + ## API Reference - OpenAPI spec: https://sba.manticorum.com/api/openapi.json (use WebFetch for current endpoints) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..afd48d9 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,9 @@ +-r requirements.txt + +# Development & Testing +pytest==8.4.1 +pytest-asyncio==1.0.0 +pytest-mock==3.15.1 +aioresponses==0.7.8 +black==26.1.0 +ruff==0.15.0 diff --git a/requirements.txt b/requirements.txt index b31b28c..0616fc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,15 +6,7 @@ aiohttp==3.12.13 # Utilities python-dotenv==1.1.1 -redis>=5.0.0 # For optional API response caching (not currently installed) - -# Development & Testing -pytest==8.4.1 -pytest-asyncio==1.0.0 -pytest-mock>=3.10.0 # Not currently installed -aioresponses==0.7.8 -black>=23.0.0 # Not currently installed -ruff>=0.1.0 # Not currently installed +redis==7.3.0 # Optional Dependencies -pygsheets==2.0.6 # For Google Sheets integration (scorecard submission) \ No newline at end of file +pygsheets==2.0.6 # For Google Sheets integration (scorecard submission)