diff --git a/CLAUDE.md b/CLAUDE.md index a559bcf..3beb3d3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -86,6 +86,23 @@ class MyCog(commands.Cog): - API errors → verify `DB_URL` points to correct database API and `API_TOKEN` matches - Redis errors are non-fatal (graceful fallback when `REDIS_URL` is empty) +## Dependencies + +### Pinning Policy +All dependencies are pinned to exact versions (`==`). This ensures every Docker build +produces an identical image — a `git revert` actually rolls back to the previous working state. + +- **`requirements.txt`** — production runtime deps only (used by Dockerfile) +- **`requirements-dev.txt`** — includes `-r requirements.txt` plus dev/test tools + +When installing for local development or running tests: +```bash +pip install -r requirements-dev.txt +``` + +When upgrading a dependency, update BOTH the `==` pin and (if applicable) the comment in +the file. Test before committing. Never use `>=` or `~=` constraints. + ## API Reference - OpenAPI spec: https://sba.manticorum.com/api/openapi.json (use WebFetch for current endpoints) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..afd48d9 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,9 @@ +-r requirements.txt + +# Development & Testing +pytest==8.4.1 +pytest-asyncio==1.0.0 +pytest-mock==3.15.1 +aioresponses==0.7.8 +black==26.1.0 +ruff==0.15.0 diff --git a/requirements.txt b/requirements.txt index b31b28c..0616fc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,15 +6,7 @@ aiohttp==3.12.13 # Utilities python-dotenv==1.1.1 -redis>=5.0.0 # For optional API response caching (not currently installed) - -# Development & Testing -pytest==8.4.1 -pytest-asyncio==1.0.0 -pytest-mock>=3.10.0 # Not currently installed -aioresponses==0.7.8 -black>=23.0.0 # Not currently installed -ruff>=0.1.0 # Not currently installed +redis==7.3.0 # Optional Dependencies -pygsheets==2.0.6 # For Google Sheets integration (scorecard submission) \ No newline at end of file +pygsheets==2.0.6 # For Google Sheets integration (scorecard submission) diff --git a/services/base_service.py b/services/base_service.py index e919e6b..7046efb 100644 --- a/services/base_service.py +++ b/services/base_service.py @@ -3,6 +3,7 @@ Base service class for Discord Bot v2.0 Provides common CRUD operations and error handling for all data services. """ + import logging import hashlib from typing import Optional, Type, TypeVar, Generic, Dict, Any, List, Tuple @@ -12,15 +13,15 @@ from models.base import SBABaseModel from exceptions import APIException from utils.cache import CacheManager -logger = logging.getLogger(f'{__name__}.BaseService') +logger = logging.getLogger(f"{__name__}.BaseService") -T = TypeVar('T', bound=SBABaseModel) +T = TypeVar("T", bound=SBABaseModel) class BaseService(Generic[T]): """ Base service class providing common CRUD operations for SBA models. - + Features: - Generic type support for any SBABaseModel subclass - Automatic model validation and conversion @@ -28,15 +29,17 @@ class BaseService(Generic[T]): - API response format handling (count + list format) - Connection management via global client """ - - def __init__(self, - model_class: Type[T], - endpoint: str, - client: Optional[APIClient] = None, - cache_manager: Optional[CacheManager] = None): + + def __init__( + self, + model_class: Type[T], + endpoint: str, + client: Optional[APIClient] = None, + cache_manager: Optional[CacheManager] = None, + ): """ Initialize base service. - + Args: model_class: Pydantic model class for this service endpoint: API endpoint path (e.g., 'players', 'teams') @@ -48,40 +51,44 @@ class BaseService(Generic[T]): self._client = client self._cached_client: Optional[APIClient] = None self.cache = cache_manager or CacheManager() - - logger.debug(f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'") - - def _generate_cache_key(self, method: str, params: Optional[List[Tuple[str, Any]]] = None) -> str: + + logger.debug( + f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'" + ) + + def _generate_cache_key( + self, method: str, params: Optional[List[Tuple[str, Any]]] = None + ) -> str: """ Generate consistent cache key for API calls. - + Args: method: API method name params: Query parameters as list of tuples - + Returns: SHA256-hashed cache key """ key_parts = [self.endpoint, method] - + if params: # Sort parameters for consistent key generation sorted_params = sorted(params, key=lambda x: str(x[0])) param_str = "&".join([f"{k}={v}" for k, v in sorted_params]) key_parts.append(param_str) - + key_data = ":".join(key_parts) key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16] # First 16 chars - + return self.cache.cache_key("sba", f"{self.endpoint}_{key_hash}") - + async def _get_cached_items(self, cache_key: str) -> Optional[List[T]]: """ Get cached list of model items. - + Args: cache_key: Cache key to lookup - + Returns: List of model instances or None if not cached """ @@ -91,13 +98,15 @@ class BaseService(Generic[T]): return [self.model_class.from_api_data(item) for item in cached_data] except Exception as e: logger.warning(f"Error deserializing cached data for {cache_key}: {e}") - + return None - - async def _cache_items(self, cache_key: str, items: List[T], ttl: Optional[int] = None) -> None: + + async def _cache_items( + self, cache_key: str, items: List[T], ttl: Optional[int] = None + ) -> None: """ Cache list of model items. - + Args: cache_key: Cache key to store under items: List of model instances to cache @@ -105,40 +114,40 @@ class BaseService(Generic[T]): """ if not items: return - + try: # Convert to JSON-serializable format cache_data = [item.model_dump() for item in items] await self.cache.set(cache_key, cache_data, ttl) except Exception as e: logger.warning(f"Error caching items for {cache_key}: {e}") - + async def get_client(self) -> APIClient: """ Get API client instance with caching to reduce async overhead. - + Returns: APIClient instance (cached after first access) """ if self._client: return self._client - + # Cache the global client to avoid repeated async calls if self._cached_client is None: self._cached_client = await get_global_client() - + return self._cached_client - + async def get_by_id(self, object_id: int) -> Optional[T]: """ Get single object by ID. - + Args: object_id: Unique identifier for the object - + Returns: Model instance or None if not found - + Raises: APIException: For API errors ValueError: For invalid data @@ -146,167 +155,181 @@ class BaseService(Generic[T]): try: client = await self.get_client() data = await client.get(self.endpoint, object_id=object_id) - + if not data: logger.debug(f"{self.model_class.__name__} {object_id} not found") return None - + model = self.model_class.from_api_data(data) logger.debug(f"Retrieved {self.model_class.__name__} {object_id}: {model}") return model - + except APIException: - logger.error(f"API error retrieving {self.model_class.__name__} {object_id}") + logger.error( + f"API error retrieving {self.model_class.__name__} {object_id}" + ) raise except Exception as e: - logger.error(f"Error retrieving {self.model_class.__name__} {object_id}: {e}") + logger.error( + f"Error retrieving {self.model_class.__name__} {object_id}: {e}" + ) raise APIException(f"Failed to retrieve {self.model_class.__name__}: {e}") - - async def get_all(self, params: Optional[List[tuple]] = None) -> Tuple[List[T], int]: + + async def get_all( + self, params: Optional[List[tuple]] = None + ) -> Tuple[List[T], int]: """ Get all objects with optional query parameters. - + Args: params: Query parameters as list of (key, value) tuples - + Returns: Tuple of (list of model instances, total count) - + Raises: APIException: For API errors """ try: client = await self.get_client() data = await client.get(self.endpoint, params=params) - + if not data: logger.debug(f"No {self.model_class.__name__} objects found") return [], 0 - + # Handle API response format: {'count': int, '': [...]} items, count = self._extract_items_and_count_from_response(data) - + models = [self.model_class.from_api_data(item) for item in items] - logger.debug(f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects") + logger.debug( + f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects" + ) return models, count - + except APIException: logger.error(f"API error retrieving {self.model_class.__name__} list") raise except Exception as e: logger.error(f"Error retrieving {self.model_class.__name__} list: {e}") - raise APIException(f"Failed to retrieve {self.model_class.__name__} list: {e}") - + raise APIException( + f"Failed to retrieve {self.model_class.__name__} list: {e}" + ) + async def get_all_items(self, params: Optional[List[tuple]] = None) -> List[T]: """ Get all objects (convenience method that only returns the list). - + Args: params: Query parameters as list of (key, value) tuples - + Returns: List of model instances """ items, _ = await self.get_all(params=params) return items - + async def create(self, model_data: Dict[str, Any]) -> Optional[T]: """ Create new object from data dictionary. - + Args: model_data: Dictionary of model fields - + Returns: Created model instance or None - + Raises: APIException: For API errors ValueError: For invalid data """ try: client = await self.get_client() - response = await client.post(self.endpoint, model_data) - + response = await client.post(f"{self.endpoint}/", model_data) + if not response: logger.warning(f"No response from {self.model_class.__name__} creation") return None - + model = self.model_class.from_api_data(response) logger.debug(f"Created {self.model_class.__name__}: {model}") return model - + except APIException: logger.error(f"API error creating {self.model_class.__name__}") raise except Exception as e: logger.error(f"Error creating {self.model_class.__name__}: {e}") raise APIException(f"Failed to create {self.model_class.__name__}: {e}") - + async def create_from_model(self, model: T) -> Optional[T]: """ Create new object from model instance. - + Args: model: Model instance to create - + Returns: Created model instance or None """ return await self.create(model.to_dict(exclude_none=True)) - + async def update(self, object_id: int, model_data: Dict[str, Any]) -> Optional[T]: """ Update existing object. - + Args: object_id: ID of object to update model_data: Dictionary of fields to update - + Returns: Updated model instance or None if not found - + Raises: APIException: For API errors """ try: client = await self.get_client() response = await client.put(self.endpoint, model_data, object_id=object_id) - + if not response: - logger.debug(f"{self.model_class.__name__} {object_id} not found for update") + logger.debug( + f"{self.model_class.__name__} {object_id} not found for update" + ) return None - + model = self.model_class.from_api_data(response) logger.debug(f"Updated {self.model_class.__name__} {object_id}: {model}") return model - + except APIException: logger.error(f"API error updating {self.model_class.__name__} {object_id}") raise except Exception as e: logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}") raise APIException(f"Failed to update {self.model_class.__name__}: {e}") - + async def update_from_model(self, model: T) -> Optional[T]: """ Update object from model instance. - + Args: model: Model instance to update (must have ID) - + Returns: Updated model instance or None - + Raises: ValueError: If model has no ID """ if not model.id: raise ValueError(f"Cannot update {self.model_class.__name__} without ID") - + return await self.update(model.id, model.to_dict(exclude_none=True)) - - async def patch(self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False) -> Optional[T]: + + async def patch( + self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False + ) -> Optional[T]: """ Update existing object with HTTP PATCH. @@ -323,10 +346,14 @@ class BaseService(Generic[T]): """ try: client = await self.get_client() - response = await client.patch(self.endpoint, model_data, object_id, use_query_params=use_query_params) + response = await client.patch( + self.endpoint, model_data, object_id, use_query_params=use_query_params + ) if not response: - logger.debug(f"{self.model_class.__name__} {object_id} not found for update") + logger.debug( + f"{self.model_class.__name__} {object_id} not found for update" + ) return None model = self.model_class.from_api_data(response) @@ -340,134 +367,142 @@ class BaseService(Generic[T]): logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}") raise APIException(f"Failed to update {self.model_class.__name__}: {e}") - async def delete(self, object_id: int) -> bool: """ Delete object by ID. - + Args: object_id: ID of object to delete - + Returns: True if deleted, False if not found - + Raises: APIException: For API errors """ try: client = await self.get_client() success = await client.delete(self.endpoint, object_id=object_id) - + if success: logger.debug(f"Deleted {self.model_class.__name__} {object_id}") else: - logger.debug(f"{self.model_class.__name__} {object_id} not found for deletion") - + logger.debug( + f"{self.model_class.__name__} {object_id} not found for deletion" + ) + return success - + except APIException: logger.error(f"API error deleting {self.model_class.__name__} {object_id}") raise except Exception as e: logger.error(f"Error deleting {self.model_class.__name__} {object_id}: {e}") raise APIException(f"Failed to delete {self.model_class.__name__}: {e}") - - + async def get_by_field(self, field: str, value: Any) -> List[T]: """ Get objects by specific field value. - + Args: field: Field name to search value: Field value to match - + Returns: List of matching model instances """ params = [(field, str(value))] return await self.get_all_items(params=params) - + async def count(self, params: Optional[List[tuple]] = None) -> int: """ Get count of objects matching parameters. - + Args: params: Query parameters - + Returns: Number of matching objects (from API count field) """ _, count = await self.get_all(params=params) return count - - def _extract_items_and_count_from_response(self, data: Any) -> Tuple[List[Dict[str, Any]], int]: + + def _extract_items_and_count_from_response( + self, data: Any + ) -> Tuple[List[Dict[str, Any]], int]: """ Extract items list and count from API response with optimized parsing. - + Expected format: {'count': int, '': [...]} Single object format: {'id': 1, 'name': '...'} - + Args: data: API response data - + Returns: Tuple of (items list, total count) """ if isinstance(data, list): return data, len(data) - + if not isinstance(data, dict): - logger.warning(f"Unexpected response format for {self.model_class.__name__}: {type(data)}") + logger.warning( + f"Unexpected response format for {self.model_class.__name__}: {type(data)}" + ) return [], 0 - + # Single pass through the response dict - get count first - count = data.get('count', 0) - + count = data.get("count", 0) + # Priority order for finding items list (most common first) - field_candidates = [self.endpoint, 'items', 'data', 'results'] + field_candidates = [self.endpoint, "items", "data", "results"] for field_name in field_candidates: if field_name in data and isinstance(data[field_name], list): return data[field_name], count or len(data[field_name]) - + # Single object response (check for common identifying fields) - if any(key in data for key in ['id', 'name', 'abbrev']): + if any(key in data for key in ["id", "name", "abbrev"]): return [data], 1 - + return [], count - - async def get_items_with_params(self, params: Optional[List[tuple]] = None) -> List[T]: + + async def get_items_with_params( + self, params: Optional[List[tuple]] = None + ) -> List[T]: """ Get all items with parameters (alias for get_all_items for compatibility). - + Args: params: Query parameters as list of (key, value) tuples - + Returns: List of model instances """ return await self.get_all_items(params=params) - + async def create_item(self, model_data: Dict[str, Any]) -> Optional[T]: """ Create item (alias for create for compatibility). - + Args: model_data: Dictionary of model fields - + Returns: Created model instance or None """ return await self.create(model_data) - - async def update_item_by_field(self, field: str, value: Any, update_data: Dict[str, Any]) -> Optional[T]: + + async def update_item_by_field( + self, field: str, value: Any, update_data: Dict[str, Any] + ) -> Optional[T]: """ Update item by field value. - + Args: field: Field name to search by value: Field value to match update_data: Data to update - + Returns: Updated model instance or None if not found """ @@ -475,22 +510,22 @@ class BaseService(Generic[T]): items = await self.get_by_field(field, value) if not items: return None - + # Update the first matching item item = items[0] if not item.id: return None - + return await self.update(item.id, update_data) - + async def delete_item_by_field(self, field: str, value: Any) -> bool: """ Delete item by field value. - + Args: field: Field name to search by value: Field value to match - + Returns: True if deleted, False if not found """ @@ -498,62 +533,41 @@ class BaseService(Generic[T]): items = await self.get_by_field(field, value) if not items: return False - + # Delete the first matching item item = items[0] if not item.id: return False - + return await self.delete(item.id) - - async def create_item_in_table(self, table_name: str, item_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """ - Create item in a specific table (simplified for custom commands service). - This is a placeholder - real implementation would need table-specific endpoints. - - Args: - table_name: Name of the table - item_data: Data to create - - Returns: - Created item data or None - """ - # For now, use the main endpoint - this would need proper implementation - # for different tables like 'custom_command_creators' - try: - client = await self.get_client() - # Use table name as endpoint for now - response = await client.post(table_name, item_data) - return response - except Exception as e: - logger.error(f"Error creating item in table {table_name}: {e}") - return None - - async def get_items_from_table_with_params(self, table_name: str, params: List[tuple]) -> List[Dict[str, Any]]: + + async def get_items_from_table_with_params( + self, table_name: str, params: List[tuple] + ) -> List[Dict[str, Any]]: """ Get items from a specific table with parameters. - + Args: table_name: Name of the table params: Query parameters - + Returns: List of item dictionaries """ try: client = await self.get_client() data = await client.get(table_name, params=params) - + if not data: return [] - + # Handle response format items, _ = self._extract_items_and_count_from_response(data) return items - + except Exception as e: logger.error(f"Error getting items from table {table_name}: {e}") return [] def __repr__(self) -> str: - return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')" \ No newline at end of file + return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')" diff --git a/services/custom_commands_service.py b/services/custom_commands_service.py index 72838bb..03445bc 100644 --- a/services/custom_commands_service.py +++ b/services/custom_commands_service.py @@ -552,9 +552,8 @@ class CustomCommandsService(BaseService[CustomCommand]): "active_commands": 0, } - result = await self.create_item_in_table( - "custom_commands/creators", creator_data - ) + client = await self.get_client() + result = await client.post("custom_commands/creators", creator_data) if not result: raise BotException("Failed to create command creator") diff --git a/services/decision_service.py b/services/decision_service.py index e101d0e..11a36b5 100644 --- a/services/decision_service.py +++ b/services/decision_service.py @@ -3,6 +3,7 @@ Decision Service Manages pitching decision operations for game submission. """ + from typing import List, Dict, Any, Optional, Tuple from utils.logging import get_contextual_logger @@ -16,17 +17,14 @@ class DecisionService: def __init__(self): """Initialize decision service.""" - self.logger = get_contextual_logger(f'{__name__}.DecisionService') + self.logger = get_contextual_logger(f"{__name__}.DecisionService") self._get_client = get_global_client async def get_client(self): """Get the API client.""" return await self._get_client() - async def create_decisions_batch( - self, - decisions: List[Dict[str, Any]] - ) -> bool: + async def create_decisions_batch(self, decisions: List[Dict[str, Any]]) -> bool: """ POST batch of decisions to /decisions endpoint. @@ -42,8 +40,10 @@ class DecisionService: try: client = await self.get_client() - payload = {'decisions': decisions} - await client.post('decisions', payload) + payload = {"decisions": decisions} + # Trailing slash required: without it, the server returns a 307 redirect + # and aiohttp drops the POST body when following the redirect + await client.post("decisions/", payload) self.logger.info(f"Created {len(decisions)} decisions") return True @@ -70,7 +70,7 @@ class DecisionService: """ try: client = await self.get_client() - await client.delete(f'decisions/game/{game_id}') + await client.delete(f"decisions/game/{game_id}") self.logger.info(f"Deleted decisions for game {game_id}") return True @@ -80,9 +80,10 @@ class DecisionService: raise APIException(f"Failed to delete decisions: {e}") async def find_winning_losing_pitchers( - self, - decisions_data: List[Dict[str, Any]] - ) -> Tuple[Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]]: + self, decisions_data: List[Dict[str, Any]] + ) -> Tuple[ + Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player] + ]: """ Extract WP, LP, SV, Holds, Blown Saves from decisions list and fetch Player objects. @@ -110,17 +111,17 @@ class DecisionService: # First pass: Extract IDs for decision in decisions_data: - pitcher_id = int(decision.get('pitcher_id', 0)) + pitcher_id = int(decision.get("pitcher_id", 0)) - if int(decision.get('win', 0)) == 1: + if int(decision.get("win", 0)) == 1: wp_id = pitcher_id - if int(decision.get('loss', 0)) == 1: + if int(decision.get("loss", 0)) == 1: lp_id = pitcher_id - if int(decision.get('is_save', 0)) == 1: + if int(decision.get("is_save", 0)) == 1: sv_id = pitcher_id - if int(decision.get('hold', 0)) == 1: + if int(decision.get("hold", 0)) == 1: hold_ids.append(pitcher_id) - if int(decision.get('b_save', 0)) == 1: + if int(decision.get("b_save", 0)) == 1: bsv_ids.append(pitcher_id) # Second pass: Fetch Player objects @@ -154,9 +155,9 @@ class DecisionService: """ error_str = str(error) - if 'Player ID' in error_str and 'not found' in error_str: + if "Player ID" in error_str and "not found" in error_str: return "Invalid pitcher ID in decision data." - elif 'Game ID' in error_str and 'not found' in error_str: + elif "Game ID" in error_str and "not found" in error_str: return "Game not found for decisions." else: return f"Error submitting decisions: {error_str}" diff --git a/services/draft_list_service.py b/services/draft_list_service.py index 2a7017d..934a44f 100644 --- a/services/draft_list_service.py +++ b/services/draft_list_service.py @@ -3,13 +3,14 @@ Draft list service for Discord Bot v2.0 Handles team draft list (auto-draft queue) operations. NO CACHING - lists change frequently. """ + import logging from typing import Optional, List from services.base_service import BaseService from models.draft_list import DraftList -logger = logging.getLogger(f'{__name__}.DraftListService') +logger = logging.getLogger(f"{__name__}.DraftListService") class DraftListService(BaseService[DraftList]): @@ -32,7 +33,7 @@ class DraftListService(BaseService[DraftList]): def __init__(self): """Initialize draft list service.""" - super().__init__(DraftList, 'draftlist') + super().__init__(DraftList, "draftlist") logger.debug("DraftListService initialized") def _extract_items_and_count_from_response(self, data): @@ -54,20 +55,16 @@ class DraftListService(BaseService[DraftList]): return [], 0 # Get count - count = data.get('count', 0) + count = data.get("count", 0) # API returns items under 'picks' key (not 'draftlist') - if 'picks' in data and isinstance(data['picks'], list): - return data['picks'], count or len(data['picks']) + if "picks" in data and isinstance(data["picks"], list): + return data["picks"], count or len(data["picks"]) # Fallback to standard extraction return super()._extract_items_and_count_from_response(data) - async def get_team_list( - self, - season: int, - team_id: int - ) -> List[DraftList]: + async def get_team_list(self, season: int, team_id: int) -> List[DraftList]: """ Get team's draft list ordered by rank. @@ -82,8 +79,8 @@ class DraftListService(BaseService[DraftList]): """ try: params = [ - ('season', str(season)), - ('team_id', str(team_id)) + ("season", str(season)), + ("team_id", str(team_id)), # NOTE: API does not support 'sort' param - results must be sorted client-side ] @@ -100,11 +97,7 @@ class DraftListService(BaseService[DraftList]): return [] async def add_to_list( - self, - season: int, - team_id: int, - player_id: int, - rank: Optional[int] = None + self, season: int, team_id: int, player_id: int, rank: Optional[int] = None ) -> Optional[List[DraftList]]: """ Add player to team's draft list. @@ -133,10 +126,10 @@ class DraftListService(BaseService[DraftList]): # Create new entry data new_entry_data = { - 'season': season, - 'team_id': team_id, - 'player_id': player_id, - 'rank': rank + "season": season, + "team_id": team_id, + "player_id": player_id, + "rank": rank, } # Build complete list for bulk replacement @@ -146,36 +139,42 @@ class DraftListService(BaseService[DraftList]): for entry in current_list: if entry.rank >= rank: # Shift down entries at or after insertion point - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': entry.rank + 1 - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": entry.rank + 1, + } + ) else: # Keep existing rank for entries before insertion point - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': entry.rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": entry.rank, + } + ) # Add new entry draft_list_entries.append(new_entry_data) # Sort by rank for consistency - draft_list_entries.sort(key=lambda x: x['rank']) + draft_list_entries.sort(key=lambda x: x["rank"]) # POST entire list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - logger.debug(f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries") - response = await client.post(self.endpoint, payload) + logger.debug( + f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries" + ) + response = await client.post(f"{self.endpoint}/", payload) logger.debug(f"POST response: {response}") # Verify by fetching the list back (API returns full objects) @@ -184,20 +183,21 @@ class DraftListService(BaseService[DraftList]): # Verify the player was added if not any(entry.player_id == player_id for entry in verification): - logger.error(f"Player {player_id} not found in list after POST - operation may have failed") + logger.error( + f"Player {player_id} not found in list after POST - operation may have failed" + ) return None - logger.info(f"Added player {player_id} to team {team_id} draft list at rank {rank}") + logger.info( + f"Added player {player_id} to team {team_id} draft list at rank {rank}" + ) return verification # Return full updated list except Exception as e: logger.error(f"Error adding player {player_id} to draft list: {e}") return None - async def remove_from_list( - self, - entry_id: int - ) -> bool: + async def remove_from_list(self, entry_id: int) -> bool: """ Remove entry from draft list by ID. @@ -209,14 +209,13 @@ class DraftListService(BaseService[DraftList]): Returns: True if deletion succeeded """ - logger.warning("remove_from_list() called with entry_id - use remove_player_from_list() instead") + logger.warning( + "remove_from_list() called with entry_id - use remove_player_from_list() instead" + ) return False async def remove_player_from_list( - self, - season: int, - team_id: int, - player_id: int + self, season: int, team_id: int, player_id: int ) -> bool: """ Remove specific player from team's draft list. @@ -238,7 +237,9 @@ class DraftListService(BaseService[DraftList]): # Check if player is in list player_found = any(entry.player_id == player_id for entry in current_list) if not player_found: - logger.warning(f"Player {player_id} not found in team {team_id} draft list") + logger.warning( + f"Player {player_id} not found in team {team_id} draft list" + ) return False # Build new list without the player, adjusting ranks @@ -246,22 +247,24 @@ class DraftListService(BaseService[DraftList]): new_rank = 1 for entry in current_list: if entry.player_id != player_id: - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': new_rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": new_rank, + } + ) new_rank += 1 # POST updated list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - await client.post(self.endpoint, payload) + await client.post(f"{self.endpoint}/", payload) logger.info(f"Removed player {player_id} from team {team_id} draft list") return True @@ -270,11 +273,7 @@ class DraftListService(BaseService[DraftList]): logger.error(f"Error removing player {player_id} from draft list: {e}") return False - async def clear_list( - self, - season: int, - team_id: int - ) -> bool: + async def clear_list(self, season: int, team_id: int) -> bool: """ Clear entire draft list for team. @@ -309,10 +308,7 @@ class DraftListService(BaseService[DraftList]): return False async def reorder_list( - self, - season: int, - team_id: int, - new_order: List[int] + self, season: int, team_id: int, new_order: List[int] ) -> bool: """ Reorder team's draft list. @@ -342,21 +338,23 @@ class DraftListService(BaseService[DraftList]): continue entry = entry_map[player_id] - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': new_rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": new_rank, + } + ) # POST reordered list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - await client.post(self.endpoint, payload) + await client.post(f"{self.endpoint}/", payload) logger.info(f"Reordered draft list for team {team_id}") return True @@ -365,12 +363,7 @@ class DraftListService(BaseService[DraftList]): logger.error(f"Error reordering draft list for team {team_id}: {e}") return False - async def move_entry_up( - self, - season: int, - team_id: int, - player_id: int - ) -> bool: + async def move_entry_up(self, season: int, team_id: int, player_id: int) -> bool: """ Move player up one position in draft list (higher priority). @@ -403,7 +396,9 @@ class DraftListService(BaseService[DraftList]): return False # Find entry above (rank - 1) - above_entry = next((e for e in entries if e.rank == current_entry.rank - 1), None) + above_entry = next( + (e for e in entries if e.rank == current_entry.rank - 1), None + ) if not above_entry: logger.error(f"Could not find entry above rank {current_entry.rank}") return False @@ -421,24 +416,26 @@ class DraftListService(BaseService[DraftList]): # Keep existing rank new_rank = entry.rank - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': new_rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": new_rank, + } + ) # Sort by rank - draft_list_entries.sort(key=lambda x: x['rank']) + draft_list_entries.sort(key=lambda x: x["rank"]) # POST updated list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - await client.post(self.endpoint, payload) + await client.post(f"{self.endpoint}/", payload) logger.info(f"Moved player {player_id} up to rank {current_entry.rank - 1}") return True @@ -447,12 +444,7 @@ class DraftListService(BaseService[DraftList]): logger.error(f"Error moving player {player_id} up in draft list: {e}") return False - async def move_entry_down( - self, - season: int, - team_id: int, - player_id: int - ) -> bool: + async def move_entry_down(self, season: int, team_id: int, player_id: int) -> bool: """ Move player down one position in draft list (lower priority). @@ -485,7 +477,9 @@ class DraftListService(BaseService[DraftList]): return False # Find entry below (rank + 1) - below_entry = next((e for e in entries if e.rank == current_entry.rank + 1), None) + below_entry = next( + (e for e in entries if e.rank == current_entry.rank + 1), None + ) if not below_entry: logger.error(f"Could not find entry below rank {current_entry.rank}") return False @@ -503,25 +497,29 @@ class DraftListService(BaseService[DraftList]): # Keep existing rank new_rank = entry.rank - draft_list_entries.append({ - 'season': entry.season, - 'team_id': entry.team_id, - 'player_id': entry.player_id, - 'rank': new_rank - }) + draft_list_entries.append( + { + "season": entry.season, + "team_id": entry.team_id, + "player_id": entry.player_id, + "rank": new_rank, + } + ) # Sort by rank - draft_list_entries.sort(key=lambda x: x['rank']) + draft_list_entries.sort(key=lambda x: x["rank"]) # POST updated list (bulk replacement) client = await self.get_client() payload = { - 'count': len(draft_list_entries), - 'draft_list': draft_list_entries + "count": len(draft_list_entries), + "draft_list": draft_list_entries, } - await client.post(self.endpoint, payload) - logger.info(f"Moved player {player_id} down to rank {current_entry.rank + 1}") + await client.post(f"{self.endpoint}/", payload) + logger.info( + f"Moved player {player_id} down to rank {current_entry.rank + 1}" + ) return True diff --git a/services/injury_service.py b/services/injury_service.py index 555a7a5..0164447 100644 --- a/services/injury_service.py +++ b/services/injury_service.py @@ -3,13 +3,14 @@ Injury service for Discord Bot v2.0 Handles injury-related operations including checking, creating, and clearing injuries. """ + import logging from typing import Optional, List from services.base_service import BaseService from models.injury import Injury -logger = logging.getLogger(f'{__name__}.InjuryService') +logger = logging.getLogger(f"{__name__}.InjuryService") class InjuryService(BaseService[Injury]): @@ -25,7 +26,7 @@ class InjuryService(BaseService[Injury]): def __init__(self): """Initialize injury service.""" - super().__init__(Injury, 'injuries') + super().__init__(Injury, "injuries") logger.debug("InjuryService initialized") async def get_active_injury(self, player_id: int, season: int) -> Optional[Injury]: @@ -41,25 +42,31 @@ class InjuryService(BaseService[Injury]): """ try: params = [ - ('player_id', str(player_id)), - ('season', str(season)), - ('is_active', 'true') + ("player_id", str(player_id)), + ("season", str(season)), + ("is_active", "true"), ] injuries = await self.get_all_items(params=params) if injuries: - logger.debug(f"Found active injury for player {player_id} in season {season}") + logger.debug( + f"Found active injury for player {player_id} in season {season}" + ) return injuries[0] - logger.debug(f"No active injury found for player {player_id} in season {season}") + logger.debug( + f"No active injury found for player {player_id} in season {season}" + ) return None except Exception as e: logger.error(f"Error getting active injury for player {player_id}: {e}") return None - async def get_injuries_by_player(self, player_id: int, season: int, active_only: bool = False) -> List[Injury]: + async def get_injuries_by_player( + self, player_id: int, season: int, active_only: bool = False + ) -> List[Injury]: """ Get all injuries for a player in a specific season. @@ -72,13 +79,10 @@ class InjuryService(BaseService[Injury]): List of injuries for the player """ try: - params = [ - ('player_id', str(player_id)), - ('season', str(season)) - ] + params = [("player_id", str(player_id)), ("season", str(season))] if active_only: - params.append(('is_active', 'true')) + params.append(("is_active", "true")) injuries = await self.get_all_items(params=params) logger.debug(f"Retrieved {len(injuries)} injuries for player {player_id}") @@ -88,7 +92,9 @@ class InjuryService(BaseService[Injury]): logger.error(f"Error getting injuries for player {player_id}: {e}") return [] - async def get_injuries_by_team(self, team_id: int, season: int, active_only: bool = True) -> List[Injury]: + async def get_injuries_by_team( + self, team_id: int, season: int, active_only: bool = True + ) -> List[Injury]: """ Get all injuries for a team in a specific season. @@ -101,13 +107,10 @@ class InjuryService(BaseService[Injury]): List of injuries for the team """ try: - params = [ - ('team_id', str(team_id)), - ('season', str(season)) - ] + params = [("team_id", str(team_id)), ("season", str(season))] if active_only: - params.append(('is_active', 'true')) + params.append(("is_active", "true")) injuries = await self.get_all_items(params=params) logger.debug(f"Retrieved {len(injuries)} injuries for team {team_id}") @@ -125,7 +128,7 @@ class InjuryService(BaseService[Injury]): start_week: int, start_game: int, end_week: int, - end_game: int + end_game: int, ) -> Optional[Injury]: """ Create a new injury record. @@ -144,22 +147,24 @@ class InjuryService(BaseService[Injury]): """ try: injury_data = { - 'season': season, - 'player_id': player_id, - 'total_games': total_games, - 'start_week': start_week, - 'start_game': start_game, - 'end_week': end_week, - 'end_game': end_game, - 'is_active': True + "season": season, + "player_id": player_id, + "total_games": total_games, + "start_week": start_week, + "start_game": start_game, + "end_week": end_week, + "end_game": end_game, + "is_active": True, } # Call the API to create the injury client = await self.get_client() - response = await client.post(self.endpoint, injury_data) + response = await client.post(f"{self.endpoint}/", injury_data) if not response: - logger.error(f"Failed to create injury for player {player_id}: No response from API") + logger.error( + f"Failed to create injury for player {player_id}: No response from API" + ) return None # Merge the request data with the response to ensure all required fields are present @@ -187,7 +192,9 @@ class InjuryService(BaseService[Injury]): """ try: # Note: API expects is_active as query parameter, not JSON body - updated_injury = await self.patch(injury_id, {'is_active': False}, use_query_params=True) + updated_injury = await self.patch( + injury_id, {"is_active": False}, use_query_params=True + ) if updated_injury: logger.info(f"Cleared injury {injury_id}") @@ -216,16 +223,18 @@ class InjuryService(BaseService[Injury]): try: client = await self.get_client() params = [ - ('season', str(season)), - ('is_active', 'true'), - ('sort', 'return-asc') + ("season", str(season)), + ("is_active", "true"), + ("sort", "return-asc"), ] response = await client.get(self.endpoint, params=params) - if response and 'injuries' in response: - logger.debug(f"Retrieved {len(response['injuries'])} active injuries for season {season}") - return response['injuries'] + if response and "injuries" in response: + logger.debug( + f"Retrieved {len(response['injuries'])} active injuries for season {season}" + ) + return response["injuries"] logger.debug(f"No active injuries found for season {season}") return [] diff --git a/services/play_service.py b/services/play_service.py index 7b08bf6..cdaf293 100644 --- a/services/play_service.py +++ b/services/play_service.py @@ -3,6 +3,7 @@ Play Service Manages play-by-play data operations for game submission. """ + from typing import List, Dict, Any from utils.logging import get_contextual_logger @@ -16,7 +17,7 @@ class PlayService: def __init__(self): """Initialize play service.""" - self.logger = get_contextual_logger(f'{__name__}.PlayService') + self.logger = get_contextual_logger(f"{__name__}.PlayService") self._get_client = get_global_client async def get_client(self): @@ -39,8 +40,10 @@ class PlayService: try: client = await self.get_client() - payload = {'plays': plays} - response = await client.post('plays', payload) + payload = {"plays": plays} + # Trailing slash required: without it, the server returns a 307 redirect + # and aiohttp drops the POST body when following the redirect + response = await client.post("plays/", payload) self.logger.info(f"Created {len(plays)} plays") return True @@ -68,7 +71,7 @@ class PlayService: """ try: client = await self.get_client() - response = await client.delete(f'plays/game/{game_id}') + response = await client.delete(f"plays/game/{game_id}") self.logger.info(f"Deleted plays for game {game_id}") return True @@ -77,11 +80,7 @@ class PlayService: self.logger.error(f"Failed to delete plays for game {game_id}: {e}") raise APIException(f"Failed to delete plays: {e}") - async def get_top_plays_by_wpa( - self, - game_id: int, - limit: int = 3 - ) -> List[Play]: + async def get_top_plays_by_wpa(self, game_id: int, limit: int = 3) -> List[Play]: """ Get top plays by WPA (absolute value) for key plays display. @@ -95,19 +94,15 @@ class PlayService: try: client = await self.get_client() - params = [ - ('game_id', game_id), - ('sort', 'wpa-desc'), - ('limit', limit) - ] + params = [("game_id", game_id), ("sort", "wpa-desc"), ("limit", limit)] - response = await client.get('plays', params=params) + response = await client.get("plays", params=params) - if not response or 'plays' not in response: - self.logger.info(f'No plays found for game ID {game_id}') + if not response or "plays" not in response: + self.logger.info(f"No plays found for game ID {game_id}") return [] - plays = [Play.from_api_data(p) for p in response['plays']] + plays = [Play.from_api_data(p) for p in response["plays"]] self.logger.debug(f"Retrieved {len(plays)} top plays for game {game_id}") return plays @@ -129,11 +124,11 @@ class PlayService: error_str = str(error) # Common error patterns - if 'Player ID' in error_str and 'not found' in error_str: + if "Player ID" in error_str and "not found" in error_str: return "Invalid player ID in scorecard data. Please check player IDs." - elif 'Game ID' in error_str and 'not found' in error_str: + elif "Game ID" in error_str and "not found" in error_str: return "Game not found in database. Please contact an admin." - elif 'validation' in error_str.lower(): + elif "validation" in error_str.lower(): return f"Data validation error: {error_str}" else: return f"Error submitting plays: {error_str}" diff --git a/services/transaction_service.py b/services/transaction_service.py index 80ce0e5..57c9c90 100644 --- a/services/transaction_service.py +++ b/services/transaction_service.py @@ -248,7 +248,7 @@ class TransactionService(BaseService[Transaction]): # POST batch to API client = await self.get_client() - response = await client.post(self.endpoint, data=batch_data) + response = await client.post(f"{self.endpoint}/", data=batch_data) # API returns a string like "2 transactions have been added" # We need to return the original Transaction objects (they won't have IDs assigned by API) diff --git a/tests/test_api_client.py b/tests/test_api_client.py index 14edf0f..c2ca5b2 100644 --- a/tests/test_api_client.py +++ b/tests/test_api_client.py @@ -1,18 +1,24 @@ """ API client tests using aioresponses for clean HTTP mocking """ + import pytest import asyncio from unittest.mock import MagicMock, patch from aioresponses import aioresponses -from api.client import APIClient, get_api_client, get_global_client, cleanup_global_client +from api.client import ( + APIClient, + get_api_client, + get_global_client, + cleanup_global_client, +) from exceptions import APIException class TestAPIClientWithAioresponses: """Test API client with aioresponses for HTTP mocking.""" - + @pytest.fixture def mock_config(self): """Mock configuration for testing.""" @@ -20,66 +26,57 @@ class TestAPIClientWithAioresponses: config.db_url = "https://api.example.com" config.api_token = "test-token" return config - + @pytest.fixture def api_client(self, mock_config): """Create API client with mocked config.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): return APIClient() - + @pytest.mark.asyncio async def test_get_request_success(self, api_client): """Test successful GET request.""" expected_data = {"id": 1, "name": "Test Player"} - + with aioresponses() as m: m.get( "https://api.example.com/v3/players/1", payload=expected_data, - status=200 + status=200, ) - + result = await api_client.get("players", object_id=1) - + assert result == expected_data - + @pytest.mark.asyncio async def test_get_request_404(self, api_client): """Test GET request returning 404.""" with aioresponses() as m: - m.get( - "https://api.example.com/v3/players/999", - status=404 - ) - + m.get("https://api.example.com/v3/players/999", status=404) + result = await api_client.get("players", object_id=999) - + assert result is None - + @pytest.mark.asyncio async def test_get_request_401_auth_error(self, api_client): """Test GET request with authentication error.""" with aioresponses() as m: - m.get( - "https://api.example.com/v3/players", - status=401 - ) - + m.get("https://api.example.com/v3/players", status=401) + with pytest.raises(APIException, match="Authentication failed"): await api_client.get("players") - + @pytest.mark.asyncio async def test_get_request_403_forbidden(self, api_client): """Test GET request with forbidden error.""" with aioresponses() as m: - m.get( - "https://api.example.com/v3/players", - status=403 - ) - + m.get("https://api.example.com/v3/players", status=403) + with pytest.raises(APIException, match="Access forbidden"): await api_client.get("players") - + @pytest.mark.asyncio async def test_get_request_500_server_error(self, api_client): """Test GET request with server error.""" @@ -87,135 +84,127 @@ class TestAPIClientWithAioresponses: m.get( "https://api.example.com/v3/players", status=500, - body="Internal Server Error" + body="Internal Server Error", ) - - with pytest.raises(APIException, match="API request failed with status 500"): + + with pytest.raises( + APIException, match="API request failed with status 500" + ): await api_client.get("players") - + @pytest.mark.asyncio async def test_get_request_with_params(self, api_client): """Test GET request with query parameters.""" expected_data = {"count": 2, "players": [{"id": 1}, {"id": 2}]} - + with aioresponses() as m: m.get( "https://api.example.com/v3/players?team_id=5&season=12", payload=expected_data, - status=200 + status=200, ) - - result = await api_client.get("players", params=[("team_id", "5"), ("season", "12")]) - + + result = await api_client.get( + "players", params=[("team_id", "5"), ("season", "12")] + ) + assert result == expected_data - + @pytest.mark.asyncio async def test_post_request_success(self, api_client): """Test successful POST request.""" input_data = {"name": "New Player", "position": "C"} expected_response = {"id": 1, "name": "New Player", "position": "C"} - + with aioresponses() as m: m.post( "https://api.example.com/v3/players", payload=expected_response, - status=201 + status=201, ) - + result = await api_client.post("players", input_data) - + assert result == expected_response - + @pytest.mark.asyncio async def test_post_request_400_error(self, api_client): """Test POST request with validation error.""" input_data = {"invalid": "data"} - + with aioresponses() as m: m.post( - "https://api.example.com/v3/players", - status=400, - body="Invalid data" + "https://api.example.com/v3/players", status=400, body="Invalid data" ) - - with pytest.raises(APIException, match="POST request failed with status 400"): + + with pytest.raises( + APIException, match="POST request failed with status 400" + ): await api_client.post("players", input_data) - + @pytest.mark.asyncio async def test_put_request_success(self, api_client): """Test successful PUT request.""" update_data = {"name": "Updated Player"} expected_response = {"id": 1, "name": "Updated Player"} - + with aioresponses() as m: m.put( "https://api.example.com/v3/players/1", payload=expected_response, - status=200 + status=200, ) - + result = await api_client.put("players", update_data, object_id=1) - + assert result == expected_response - + @pytest.mark.asyncio async def test_put_request_404(self, api_client): """Test PUT request with 404.""" update_data = {"name": "Updated Player"} - + with aioresponses() as m: - m.put( - "https://api.example.com/v3/players/999", - status=404 - ) - + m.put("https://api.example.com/v3/players/999", status=404) + result = await api_client.put("players", update_data, object_id=999) - + assert result is None - + @pytest.mark.asyncio async def test_delete_request_success(self, api_client): """Test successful DELETE request.""" with aioresponses() as m: - m.delete( - "https://api.example.com/v3/players/1", - status=204 - ) - + m.delete("https://api.example.com/v3/players/1", status=204) + result = await api_client.delete("players", object_id=1) - + assert result is True - + @pytest.mark.asyncio async def test_delete_request_404(self, api_client): """Test DELETE request with 404.""" with aioresponses() as m: - m.delete( - "https://api.example.com/v3/players/999", - status=404 - ) - + m.delete("https://api.example.com/v3/players/999", status=404) + result = await api_client.delete("players", object_id=999) - + assert result is False - + @pytest.mark.asyncio async def test_delete_request_200_success(self, api_client): """Test DELETE request with 200 success.""" with aioresponses() as m: - m.delete( - "https://api.example.com/v3/players/1", - status=200 - ) - + m.delete("https://api.example.com/v3/players/1", status=200) + result = await api_client.delete("players", object_id=1) - + assert result is True class TestAPIClientHelpers: """Test API client helper functions.""" - + @pytest.fixture def mock_config(self): """Mock configuration for testing.""" @@ -223,49 +212,49 @@ class TestAPIClientHelpers: config.db_url = "https://api.example.com" config.api_token = "test-token" return config - + @pytest.mark.asyncio async def test_get_api_client_context_manager(self, mock_config): """Test get_api_client context manager.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: m.get( "https://api.example.com/v3/test", payload={"success": True}, - status=200 + status=200, ) - + async with get_api_client() as client: assert isinstance(client, APIClient) result = await client.get("test") assert result == {"success": True} - + @pytest.mark.asyncio async def test_global_client_management(self, mock_config): """Test global client getter and cleanup.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): # Get global client client1 = await get_global_client() client2 = await get_global_client() - + # Should return same instance assert client1 is client2 assert isinstance(client1, APIClient) - + # Test cleanup await cleanup_global_client() - + # New client should be different instance client3 = await get_global_client() assert client3 is not client1 - + # Clean up for other tests await cleanup_global_client() class TestIntegrationScenarios: """Test realistic integration scenarios.""" - + @pytest.fixture def mock_config(self): """Mock configuration for testing.""" @@ -273,11 +262,11 @@ class TestIntegrationScenarios: config.db_url = "https://api.example.com" config.api_token = "test-token" return config - + @pytest.mark.asyncio async def test_player_retrieval_with_team_lookup(self, mock_config): """Test realistic scenario: get player with team data.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: # Mock player data response player_data = { @@ -287,43 +276,41 @@ class TestIntegrationScenarios: "season": 12, "team_id": 5, "image": "https://example.com/player1.jpg", - "pos_1": "C" + "pos_1": "C", } m.get( "https://api.example.com/v3/players/1", payload=player_data, - status=200 + status=200, ) - + # Mock team data response team_data = { "id": 5, "abbrev": "TST", "sname": "Test Team", "lname": "Test Team Full Name", - "season": 12 + "season": 12, } m.get( - "https://api.example.com/v3/teams/5", - payload=team_data, - status=200 + "https://api.example.com/v3/teams/5", payload=team_data, status=200 ) - + client = APIClient() - + # Get player player = await client.get("players", object_id=1) assert player["name"] == "Test Player" assert player["team_id"] == 5 - + # Get team for player team = await client.get("teams", object_id=player["team_id"]) assert team["sname"] == "Test Team" - + @pytest.mark.asyncio async def test_api_response_format_handling(self, mock_config): """Test handling of the API's count + list format.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: # Mock API response with count format api_response = { @@ -336,7 +323,7 @@ class TestIntegrationScenarios: "season": 12, "team_id": 5, "image": "https://example.com/player1.jpg", - "pos_1": "C" + "pos_1": "C", }, { "id": 2, @@ -345,93 +332,93 @@ class TestIntegrationScenarios: "season": 12, "team_id": 6, "image": "https://example.com/player2.jpg", - "pos_1": "1B" - } - ] + "pos_1": "1B", + }, + ], } - + m.get( "https://api.example.com/v3/players?team_id=5", payload=api_response, - status=200 + status=200, ) - + client = APIClient() result = await client.get("players", params=[("team_id", "5")]) - + assert result["count"] == 25 assert len(result["players"]) == 2 assert result["players"][0]["name"] == "Player 1" - + @pytest.mark.asyncio async def test_error_recovery_scenarios(self, mock_config): """Test error handling and recovery.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: # First request fails with 500 m.get( "https://api.example.com/v3/players/1", status=500, - body="Internal Server Error" + body="Internal Server Error", ) - + # Second request succeeds m.get( "https://api.example.com/v3/players/2", payload={"id": 2, "name": "Working Player"}, - status=200 + status=200, ) - + client = APIClient() - + # First request should raise exception with pytest.raises(APIException, match="API request failed"): await client.get("players", object_id=1) - + # Second request should work fine result = await client.get("players", object_id=2) assert result["name"] == "Working Player" - + # Client should still be functional await client.close() - + @pytest.mark.asyncio async def test_concurrent_requests(self, mock_config): """Test multiple concurrent requests.""" import asyncio - - with patch('api.client.get_config', return_value=mock_config): + + with patch("api.client.get_config", return_value=mock_config): with aioresponses() as m: # Mock multiple endpoints for i in range(1, 4): m.get( f"https://api.example.com/v3/players/{i}", payload={"id": i, "name": f"Player {i}"}, - status=200 + status=200, ) - + client = APIClient() - + # Make concurrent requests tasks = [ client.get("players", object_id=1), client.get("players", object_id=2), - client.get("players", object_id=3) + client.get("players", object_id=3), ] - + results = await asyncio.gather(*tasks) - + assert len(results) == 3 assert results[0]["name"] == "Player 1" assert results[1]["name"] == "Player 2" assert results[2]["name"] == "Player 3" - + await client.close() class TestAPIClientCoverageExtras: """Additional coverage tests for API client edge cases.""" - + @pytest.fixture def mock_config(self): """Mock configuration for testing.""" @@ -439,98 +426,104 @@ class TestAPIClientCoverageExtras: config.db_url = "https://api.example.com" config.api_token = "test-token" return config - + @pytest.mark.asyncio async def test_global_client_cleanup_when_none(self): """Test cleanup when no global client exists.""" # Ensure no global client exists await cleanup_global_client() - + # Should not raise error await cleanup_global_client() - + @pytest.mark.asyncio async def test_url_building_edge_cases(self, mock_config): """Test URL building with various edge cases.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): client = APIClient() - + # Test trailing slash handling client.base_url = "https://api.example.com/" url = client._build_url("players") assert url == "https://api.example.com/v3/players" assert "//" not in url.replace("https://", "") - + @pytest.mark.asyncio async def test_parameter_handling_edge_cases(self, mock_config): """Test parameter handling with various scenarios.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): client = APIClient() - + # Test with existing query string - url = client._add_params("https://example.com/api?existing=true", [("new", "param")]) + url = client._add_params( + "https://example.com/api?existing=true", [("new", "param")] + ) assert url == "https://example.com/api?existing=true&new=param" - + # Test with no parameters url = client._add_params("https://example.com/api") assert url == "https://example.com/api" - + @pytest.mark.asyncio async def test_timeout_error_handling(self, mock_config): """Test timeout error handling using aioresponses.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): client = APIClient() - + # Test timeout using aioresponses exception parameter with aioresponses() as m: m.get( "https://api.example.com/v3/players", - exception=asyncio.TimeoutError("Request timed out") + exception=asyncio.TimeoutError("Request timed out"), ) - - with pytest.raises(APIException, match="API call failed.*Request timed out"): + + with pytest.raises( + APIException, match="API call failed.*Request timed out" + ): await client.get("players") - + await client.close() - + @pytest.mark.asyncio async def test_generic_exception_handling(self, mock_config): """Test generic exception handling.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): client = APIClient() - + # Test generic exception with aioresponses() as m: m.get( "https://api.example.com/v3/players", - exception=Exception("Generic error") + exception=Exception("Generic error"), ) - - with pytest.raises(APIException, match="API call failed.*Generic error"): + + with pytest.raises( + APIException, match="API call failed.*Generic error" + ): await client.get("players") - + await client.close() - + @pytest.mark.asyncio async def test_session_closed_handling(self, mock_config): """Test handling of closed session.""" - with patch('api.client.get_config', return_value=mock_config): + with patch("api.client.get_config", return_value=mock_config): # Test that the client recreates session when needed with aioresponses() as m: m.get( "https://api.example.com/v3/players", payload={"success": True}, - status=200 + status=200, ) - + client = APIClient() - + # Close the session manually await client._ensure_session() await client._session.close() - + # Client should recreate session and work fine result = await client.get("players") assert result == {"success": True} - - await client.close() \ No newline at end of file + + await client.close() diff --git a/tests/test_services_base_service.py b/tests/test_services_base_service.py index 4ef6602..19a3d42 100644 --- a/tests/test_services_base_service.py +++ b/tests/test_services_base_service.py @@ -1,6 +1,7 @@ """ Tests for BaseService functionality """ + import pytest from unittest.mock import AsyncMock @@ -10,6 +11,7 @@ from models.base import SBABaseModel class MockModel(SBABaseModel): """Mock model for testing BaseService.""" + id: int name: str value: int = 100 @@ -17,240 +19,229 @@ class MockModel(SBABaseModel): class TestBaseService: """Test BaseService functionality.""" - + @pytest.fixture def mock_client(self): """Mock API client.""" client = AsyncMock() return client - + @pytest.fixture def base_service(self, mock_client): """Create BaseService instance for testing.""" - service = BaseService(MockModel, 'mocks', client=mock_client) + service = BaseService(MockModel, "mocks", client=mock_client) return service - + @pytest.mark.asyncio async def test_init(self): """Test service initialization.""" - service = BaseService(MockModel, 'test_endpoint') + service = BaseService(MockModel, "test_endpoint") assert service.model_class == MockModel - assert service.endpoint == 'test_endpoint' + assert service.endpoint == "test_endpoint" assert service._client is None - + @pytest.mark.asyncio async def test_get_by_id_success(self, base_service, mock_client): """Test successful get_by_id.""" - mock_data = {'id': 1, 'name': 'Test', 'value': 200} + mock_data = {"id": 1, "name": "Test", "value": 200} mock_client.get.return_value = mock_data - + result = await base_service.get_by_id(1) - + assert isinstance(result, MockModel) assert result.id == 1 - assert result.name == 'Test' + assert result.name == "Test" assert result.value == 200 - mock_client.get.assert_called_once_with('mocks', object_id=1) - + mock_client.get.assert_called_once_with("mocks", object_id=1) + @pytest.mark.asyncio async def test_get_by_id_not_found(self, base_service, mock_client): """Test get_by_id when object not found.""" mock_client.get.return_value = None - + result = await base_service.get_by_id(999) - + assert result is None - mock_client.get.assert_called_once_with('mocks', object_id=999) - + mock_client.get.assert_called_once_with("mocks", object_id=999) + @pytest.mark.asyncio async def test_get_all_with_count(self, base_service, mock_client): """Test get_all with count response format.""" mock_data = { - 'count': 2, - 'mocks': [ - {'id': 1, 'name': 'Test1', 'value': 100}, - {'id': 2, 'name': 'Test2', 'value': 200} - ] + "count": 2, + "mocks": [ + {"id": 1, "name": "Test1", "value": 100}, + {"id": 2, "name": "Test2", "value": 200}, + ], } mock_client.get.return_value = mock_data - + result, count = await base_service.get_all() - + assert len(result) == 2 assert count == 2 assert all(isinstance(item, MockModel) for item in result) - mock_client.get.assert_called_once_with('mocks', params=None) - + mock_client.get.assert_called_once_with("mocks", params=None) + @pytest.mark.asyncio async def test_get_all_items_convenience(self, base_service, mock_client): """Test get_all_items convenience method.""" - mock_data = { - 'count': 1, - 'mocks': [{'id': 1, 'name': 'Test', 'value': 100}] - } + mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]} mock_client.get.return_value = mock_data - + result = await base_service.get_all_items() - + assert len(result) == 1 assert isinstance(result[0], MockModel) - + @pytest.mark.asyncio async def test_create_success(self, base_service, mock_client): """Test successful object creation.""" - input_data = {'name': 'New Item', 'value': 300} - response_data = {'id': 3, 'name': 'New Item', 'value': 300} + input_data = {"name": "New Item", "value": 300} + response_data = {"id": 3, "name": "New Item", "value": 300} mock_client.post.return_value = response_data - + result = await base_service.create(input_data) - + assert isinstance(result, MockModel) assert result.id == 3 - assert result.name == 'New Item' - mock_client.post.assert_called_once_with('mocks', input_data) - + assert result.name == "New Item" + mock_client.post.assert_called_once_with("mocks/", input_data) + @pytest.mark.asyncio async def test_update_success(self, base_service, mock_client): """Test successful object update.""" - update_data = {'name': 'Updated'} - response_data = {'id': 1, 'name': 'Updated', 'value': 100} + update_data = {"name": "Updated"} + response_data = {"id": 1, "name": "Updated", "value": 100} mock_client.put.return_value = response_data - + result = await base_service.update(1, update_data) - + assert isinstance(result, MockModel) - assert result.name == 'Updated' - mock_client.put.assert_called_once_with('mocks', update_data, object_id=1) - + assert result.name == "Updated" + mock_client.put.assert_called_once_with("mocks", update_data, object_id=1) + @pytest.mark.asyncio async def test_delete_success(self, base_service, mock_client): """Test successful object deletion.""" mock_client.delete.return_value = True - + result = await base_service.delete(1) - + assert result is True - mock_client.delete.assert_called_once_with('mocks', object_id=1) - - + mock_client.delete.assert_called_once_with("mocks", object_id=1) + @pytest.mark.asyncio async def test_get_by_field(self, base_service, mock_client): """Test get_by_field functionality.""" - mock_data = { - 'count': 1, - 'mocks': [{'id': 1, 'name': 'Test', 'value': 100}] - } + mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]} mock_client.get.return_value = mock_data - - result = await base_service.get_by_field('name', 'Test') - + + result = await base_service.get_by_field("name", "Test") + assert len(result) == 1 - mock_client.get.assert_called_once_with('mocks', params=[('name', 'Test')]) - + mock_client.get.assert_called_once_with("mocks", params=[("name", "Test")]) + def test_extract_items_and_count_standard_format(self, base_service): """Test response parsing for standard format.""" data = { - 'count': 3, - 'mocks': [ - {'id': 1, 'name': 'Test1'}, - {'id': 2, 'name': 'Test2'}, - {'id': 3, 'name': 'Test3'} - ] + "count": 3, + "mocks": [ + {"id": 1, "name": "Test1"}, + {"id": 2, "name": "Test2"}, + {"id": 3, "name": "Test3"}, + ], } - + items, count = base_service._extract_items_and_count_from_response(data) - + assert len(items) == 3 assert count == 3 - assert items[0]['name'] == 'Test1' - + assert items[0]["name"] == "Test1" + def test_extract_items_and_count_single_object(self, base_service): """Test response parsing for single object.""" - data = {'id': 1, 'name': 'Single'} - + data = {"id": 1, "name": "Single"} + items, count = base_service._extract_items_and_count_from_response(data) - + assert len(items) == 1 assert count == 1 assert items[0] == data - + def test_extract_items_and_count_direct_list(self, base_service): """Test response parsing for direct list.""" - data = [ - {'id': 1, 'name': 'Test1'}, - {'id': 2, 'name': 'Test2'} - ] - + data = [{"id": 1, "name": "Test1"}, {"id": 2, "name": "Test2"}] + items, count = base_service._extract_items_and_count_from_response(data) - + assert len(items) == 2 assert count == 2 class TestBaseServiceExtras: """Additional coverage tests for BaseService edge cases.""" - + @pytest.mark.asyncio async def test_base_service_additional_methods(self): """Test additional BaseService methods for coverage.""" from services.base_service import BaseService from models.base import SBABaseModel - + class TestModel(SBABaseModel): name: str value: int = 100 - + mock_client = AsyncMock() - service = BaseService(TestModel, 'test', client=mock_client) - - + service = BaseService(TestModel, "test", client=mock_client) + # Test count method mock_client.reset_mock() - mock_client.get.return_value = {'count': 42, 'test': []} - count = await service.count(params=[('active', 'true')]) + mock_client.get.return_value = {"count": 42, "test": []} + count = await service.count(params=[("active", "true")]) assert count == 42 - + # Test update_from_model with ID mock_client.reset_mock() model = TestModel(id=1, name="Updated", value=300) mock_client.put.return_value = {"id": 1, "name": "Updated", "value": 300} result = await service.update_from_model(model) assert result.name == "Updated" - + # Test update_from_model without ID model_no_id = TestModel(name="Test") with pytest.raises(ValueError, match="Cannot update TestModel without ID"): await service.update_from_model(model_no_id) - + def test_base_service_response_parsing_edge_cases(self): """Test edge cases in response parsing.""" from services.base_service import BaseService from models.base import SBABaseModel - + class TestModel(SBABaseModel): name: str - - service = BaseService(TestModel, 'test') - + + service = BaseService(TestModel, "test") + # Test with 'items' field - data = {'count': 2, 'items': [{'name': 'Item1'}, {'name': 'Item2'}]} + data = {"count": 2, "items": [{"name": "Item1"}, {"name": "Item2"}]} items, count = service._extract_items_and_count_from_response(data) assert len(items) == 2 assert count == 2 - + # Test with 'data' field - data = {'count': 1, 'data': [{'name': 'DataItem'}]} + data = {"count": 1, "data": [{"name": "DataItem"}]} items, count = service._extract_items_and_count_from_response(data) assert len(items) == 1 assert count == 1 - + # Test with count but no recognizable list field - data = {'count': 5, 'unknown_field': [{'name': 'Item'}]} + data = {"count": 5, "unknown_field": [{"name": "Item"}]} items, count = service._extract_items_and_count_from_response(data) assert len(items) == 0 assert count == 5 - + # Test with unexpected data type items, count = service._extract_items_and_count_from_response("unexpected") assert len(items) == 0 - assert count == 0 \ No newline at end of file + assert count == 0 diff --git a/views/trade_embed.py b/views/trade_embed.py index 3b4406c..d3f7329 100644 --- a/views/trade_embed.py +++ b/views/trade_embed.py @@ -39,7 +39,7 @@ class TradeEmbedView(discord.ui.View): """Check if user has permission to interact with this view.""" if interaction.user.id != self.user_id: await interaction.response.send_message( - "❌ You don't have permission to use this trade builder.", + "You don't have permission to use this trade builder.", ephemeral=True, ) return False @@ -47,57 +47,48 @@ class TradeEmbedView(discord.ui.View): async def on_timeout(self) -> None: """Handle view timeout.""" - # Disable all buttons when timeout occurs for item in self.children: if isinstance(item, discord.ui.Button): item.disabled = True - @discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red, emoji="➖") + @discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red) async def remove_move_button( self, interaction: discord.Interaction, button: discord.ui.Button ): """Handle remove move button click.""" if self.builder.is_empty: await interaction.response.send_message( - "❌ No moves to remove. Add some moves first!", ephemeral=True + "No moves to remove. Add some moves first!", ephemeral=True ) return - # Create select menu for move removal select_view = RemoveTradeMovesView(self.builder, self.user_id) embed = await create_trade_embed(self.builder) await interaction.response.edit_message(embed=embed, view=select_view) - @discord.ui.button( - label="Validate Trade", style=discord.ButtonStyle.secondary, emoji="🔍" - ) + @discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary) async def validate_button( self, interaction: discord.Interaction, button: discord.ui.Button ): """Handle validate trade button click.""" await interaction.response.defer(ephemeral=True) - # Perform detailed validation validation = await self.builder.validate_trade() - # Create validation report if validation.is_legal: - status_emoji = "✅" status_text = "**Trade is LEGAL**" color = EmbedColors.SUCCESS else: - status_emoji = "❌" status_text = "**Trade has ERRORS**" color = EmbedColors.ERROR embed = EmbedTemplate.create_base_embed( - title=f"{status_emoji} Trade Validation Report", + title="Trade Validation Report", description=status_text, color=color, ) - # Add team-by-team validation for participant in self.builder.trade.participants: team_validation = validation.get_participant_validation(participant.team.id) if team_validation: @@ -111,59 +102,52 @@ class TradeEmbedView(discord.ui.View): team_status.append(team_validation.pre_existing_transactions_note) embed.add_field( - name=f"🏟️ {participant.team.abbrev} - {participant.team.sname}", + name=f"{participant.team.abbrev} - {participant.team.sname}", value="\n".join(team_status), inline=False, ) - # Add overall errors and suggestions if validation.all_errors: - error_text = "\n".join([f"• {error}" for error in validation.all_errors]) - embed.add_field(name="❌ Errors", value=error_text, inline=False) + error_text = "\n".join([f"- {error}" for error in validation.all_errors]) + embed.add_field(name="Errors", value=error_text, inline=False) if validation.all_suggestions: suggestion_text = "\n".join( - [f"💡 {suggestion}" for suggestion in validation.all_suggestions] + [f"- {suggestion}" for suggestion in validation.all_suggestions] ) - embed.add_field(name="💡 Suggestions", value=suggestion_text, inline=False) + embed.add_field(name="Suggestions", value=suggestion_text, inline=False) await interaction.followup.send(embed=embed, ephemeral=True) - @discord.ui.button( - label="Submit Trade", style=discord.ButtonStyle.primary, emoji="📤" - ) + @discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary) async def submit_button( self, interaction: discord.Interaction, button: discord.ui.Button ): """Handle submit trade button click.""" if self.builder.is_empty: await interaction.response.send_message( - "❌ Cannot submit empty trade. Add some moves first!", ephemeral=True + "Cannot submit empty trade. Add some moves first!", ephemeral=True ) return - # Validate before submission validation = await self.builder.validate_trade() if not validation.is_legal: - error_msg = "❌ **Cannot submit illegal trade:**\n" - error_msg += "\n".join([f"• {error}" for error in validation.all_errors]) + error_msg = "**Cannot submit illegal trade:**\n" + error_msg += "\n".join([f"- {error}" for error in validation.all_errors]) if validation.all_suggestions: error_msg += "\n\n**Suggestions:**\n" error_msg += "\n".join( - [f"💡 {suggestion}" for suggestion in validation.all_suggestions] + [f"- {suggestion}" for suggestion in validation.all_suggestions] ) await interaction.response.send_message(error_msg, ephemeral=True) return - # Show confirmation modal modal = SubmitTradeConfirmationModal(self.builder) await interaction.response.send_modal(modal) - @discord.ui.button( - label="Cancel Trade", style=discord.ButtonStyle.secondary, emoji="❌" - ) + @discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary) async def cancel_button( self, interaction: discord.Interaction, button: discord.ui.Button ): @@ -171,13 +155,12 @@ class TradeEmbedView(discord.ui.View): self.builder.clear_trade() embed = await create_trade_embed(self.builder) - # Disable all buttons after cancellation for item in self.children: if isinstance(item, discord.ui.Button): item.disabled = True await interaction.response.edit_message( - content="❌ **Trade cancelled and cleared.**", embed=embed, view=self + content="**Trade cancelled and cleared.**", embed=embed, view=self ) self.stop() @@ -190,13 +173,11 @@ class RemoveTradeMovesView(discord.ui.View): self.builder = builder self.user_id = user_id - # Create select menu with current moves if not builder.is_empty: self.add_item(RemoveTradeMovesSelect(builder)) - # Add back button back_button = discord.ui.Button( - label="Back", style=discord.ButtonStyle.secondary, emoji="⬅️" + label="Back", style=discord.ButtonStyle.secondary ) back_button.callback = self.back_callback self.add_item(back_button) @@ -218,25 +199,21 @@ class RemoveTradeMovesSelect(discord.ui.Select): def __init__(self, builder: TradeBuilder): self.builder = builder - # Create options from all moves (cross-team and supplementary) options = [] move_count = 0 - # Add cross-team moves for move in builder.trade.cross_team_moves[ :20 ]: # Limit to avoid Discord's 25 option limit options.append( discord.SelectOption( label=f"{move.player.name}", - description=move.description[:100], # Discord description limit + description=move.description[:100], value=str(move.player.id), - emoji="🔄", ) ) move_count += 1 - # Add supplementary moves if there's room remaining_slots = 25 - move_count for move in builder.trade.supplementary_moves[:remaining_slots]: options.append( @@ -244,7 +221,6 @@ class RemoveTradeMovesSelect(discord.ui.Select): label=f"{move.player.name}", description=move.description[:100], value=str(move.player.id), - emoji="⚙️", ) ) @@ -263,18 +239,16 @@ class RemoveTradeMovesSelect(discord.ui.Select): if success: await interaction.response.send_message( - f"✅ Removed move for player ID {player_id}", ephemeral=True + f"Removed move for player ID {player_id}", ephemeral=True ) - # Update the embed main_view = TradeEmbedView(self.builder, interaction.user.id) embed = await create_trade_embed(self.builder) - # Edit the original message await interaction.edit_original_response(embed=embed, view=main_view) else: await interaction.response.send_message( - f"❌ Could not remove move: {error_msg}", ephemeral=True + f"Could not remove move: {error_msg}", ephemeral=True ) @@ -301,7 +275,7 @@ class SubmitTradeConfirmationModal(discord.ui.Modal): """Handle confirmation submission - posts acceptance view to trade channel.""" if self.confirmation.value.upper() != "CONFIRM": await interaction.response.send_message( - "❌ Trade not submitted. You must type 'CONFIRM' exactly.", + "Trade not submitted. You must type 'CONFIRM' exactly.", ephemeral=True, ) return @@ -309,18 +283,13 @@ class SubmitTradeConfirmationModal(discord.ui.Modal): await interaction.response.defer(ephemeral=True) try: - # Update trade status to PROPOSED self.builder.trade.status = TradeStatus.PROPOSED - # Create acceptance embed and view acceptance_embed = await create_trade_acceptance_embed(self.builder) acceptance_view = TradeAcceptanceView(self.builder) - # Find the trade channel to post to channel = self.trade_channel if not channel: - # Try to find trade channel by name pattern - trade_channel_name = f"trade-{'-'.join(t.abbrev.lower() for t in self.builder.participating_teams)}" for ch in interaction.guild.text_channels: # type: ignore if ( ch.name.startswith("trade-") @@ -330,28 +299,26 @@ class SubmitTradeConfirmationModal(discord.ui.Modal): break if channel: - # Post acceptance request to trade channel await channel.send( - content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.", + content="**Trade submitted for approval.** All teams must accept to complete the trade.", embed=acceptance_embed, view=acceptance_view, ) await interaction.followup.send( - f"✅ Trade submitted for approval!\n\nThe acceptance request has been posted to {channel.mention}.\n" + f"Trade submitted for approval.\n\nThe acceptance request has been posted to {channel.mention}.\n" f"All participating teams must click **Accept Trade** to finalize.", ephemeral=True, ) else: - # No trade channel found, post in current channel await interaction.followup.send( - content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.", + content="**Trade submitted for approval.** All teams must accept to complete the trade.", embed=acceptance_embed, view=acceptance_view, ) except Exception as e: await interaction.followup.send( - f"❌ Error submitting trade: {str(e)}", ephemeral=True + f"Error submitting trade: {str(e)}", ephemeral=True ) @@ -375,15 +342,14 @@ class TradeAcceptanceView(discord.ui.View): if not user_team: await interaction.response.send_message( - "❌ You don't own a team in the league.", ephemeral=True + "You don't own a team in the league.", ephemeral=True ) return False - # Check if their team (or organization) is participating participant = self.builder.trade.get_participant_by_organization(user_team) if not participant: await interaction.response.send_message( - "❌ Your team is not part of this trade.", ephemeral=True + "Your team is not part of this trade.", ephemeral=True ) return False @@ -395,9 +361,7 @@ class TradeAcceptanceView(discord.ui.View): if isinstance(item, discord.ui.Button): item.disabled = True - @discord.ui.button( - label="Accept Trade", style=discord.ButtonStyle.success, emoji="✅" - ) + @discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success) async def accept_button( self, interaction: discord.Interaction, button: discord.ui.Button ): @@ -406,41 +370,33 @@ class TradeAcceptanceView(discord.ui.View): if not user_team: return - # Find the participating team (could be org affiliate) participant = self.builder.trade.get_participant_by_organization(user_team) if not participant: return team_id = participant.team.id - # Check if already accepted if self.builder.has_team_accepted(team_id): await interaction.response.send_message( - f"✅ {participant.team.abbrev} has already accepted this trade.", + f"{participant.team.abbrev} has already accepted this trade.", ephemeral=True, ) return - # Record acceptance all_accepted = self.builder.accept_trade(team_id) if all_accepted: - # All teams accepted - finalize the trade await self._finalize_trade(interaction) else: - # Update embed to show new acceptance status embed = await create_trade_acceptance_embed(self.builder) await interaction.response.edit_message(embed=embed, view=self) - # Send confirmation to channel await interaction.followup.send( - f"✅ **{participant.team.abbrev}** has accepted the trade! " + f"**{participant.team.abbrev}** has accepted the trade. " f"({len(self.builder.accepted_teams)}/{self.builder.team_count} teams)" ) - @discord.ui.button( - label="Reject Trade", style=discord.ButtonStyle.danger, emoji="❌" - ) + @discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger) async def reject_button( self, interaction: discord.Interaction, button: discord.ui.Button ): @@ -453,20 +409,16 @@ class TradeAcceptanceView(discord.ui.View): if not participant: return - # Reject the trade self.builder.reject_trade() - # Disable buttons self.accept_button.disabled = True self.reject_button.disabled = True - # Update embed to show rejection embed = await create_trade_rejection_embed(self.builder, participant.team) await interaction.response.edit_message(embed=embed, view=self) - # Notify the channel await interaction.followup.send( - f"❌ **{participant.team.abbrev}** has rejected the trade.\n\n" + f"**{participant.team.abbrev}** has rejected the trade.\n\n" f"The trade has been moved back to **DRAFT** status. " f"Teams can continue negotiating using `/trade` commands." ) @@ -480,11 +432,9 @@ class TradeAcceptanceView(discord.ui.View): config = get_config() - # Get next week for transactions current = await league_service.get_current_state() next_week = current.week + 1 if current else 1 - # Create FA team for reference fa_team = Team( id=config.free_agent_team_id, abbrev="FA", @@ -493,13 +443,10 @@ class TradeAcceptanceView(discord.ui.View): season=self.builder.trade.season, ) # type: ignore - # Create transactions from all moves transactions: List[Transaction] = [] move_id = f"Trade-{self.builder.trade_id}-{int(datetime.now(timezone.utc).timestamp())}" - # Process cross-team moves for move in self.builder.trade.cross_team_moves: - # Get actual team affiliates for from/to based on roster type if move.from_roster == RosterType.MAJOR_LEAGUE: old_team = move.source_team elif move.from_roster == RosterType.MINOR_LEAGUE: @@ -544,11 +491,10 @@ class TradeAcceptanceView(discord.ui.View): oldteam=old_team, newteam=new_team, cancelled=False, - frozen=False, # Trades are NOT frozen - immediately effective + frozen=False, ) transactions.append(transaction) - # Process supplementary moves for move in self.builder.trade.supplementary_moves: if move.from_roster == RosterType.MAJOR_LEAGUE: old_team = move.source_team @@ -598,11 +544,10 @@ class TradeAcceptanceView(discord.ui.View): oldteam=old_team, newteam=new_team, cancelled=False, - frozen=False, # Trades are NOT frozen - immediately effective + frozen=False, ) transactions.append(transaction) - # POST transactions to database if transactions: created_transactions = ( await transaction_service.create_transaction_batch(transactions) @@ -610,7 +555,6 @@ class TradeAcceptanceView(discord.ui.View): else: created_transactions = [] - # Post to #transaction-log channel if created_transactions and interaction.client: await post_trade_to_log( bot=interaction.client, @@ -619,28 +563,23 @@ class TradeAcceptanceView(discord.ui.View): effective_week=next_week, ) - # Update trade status self.builder.trade.status = TradeStatus.ACCEPTED - # Disable buttons self.accept_button.disabled = True self.reject_button.disabled = True - # Update embed to show completion embed = await create_trade_complete_embed( self.builder, len(created_transactions), next_week ) await interaction.edit_original_response(embed=embed, view=self) - # Send completion message await interaction.followup.send( - f"🎉 **Trade Complete!**\n\n" + f"**Trade Complete!**\n\n" f"All {self.builder.team_count} teams have accepted the trade.\n" f"**{len(created_transactions)} transactions** have been created for **Week {next_week}**.\n\n" f"Trade ID: `{self.builder.trade_id}`" ) - # Clear the trade builder for team in self.builder.participating_teams: clear_trade_builder_by_team(team.id) @@ -648,69 +587,64 @@ class TradeAcceptanceView(discord.ui.View): except Exception as e: await interaction.followup.send( - f"❌ Error finalizing trade: {str(e)}", ephemeral=True + f"Error finalizing trade: {str(e)}", ephemeral=True ) async def create_trade_acceptance_embed(builder: TradeBuilder) -> discord.Embed: """Create embed showing trade details and acceptance status.""" embed = EmbedTemplate.create_base_embed( - title=f"📋 Trade Pending Acceptance - {builder.trade.get_trade_summary()}", + title=f"Trade Pending Acceptance - {builder.trade.get_trade_summary()}", description="All participating teams must accept to complete the trade.", color=EmbedColors.WARNING, ) - # Show participating teams team_list = [ - f"• {team.abbrev} - {team.sname}" for team in builder.participating_teams + f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams ] embed.add_field( - name=f"🏟️ Participating Teams ({builder.team_count})", + name=f"Participating Teams ({builder.team_count})", value="\n".join(team_list), inline=False, ) - # Show cross-team moves if builder.trade.cross_team_moves: moves_text = "" for move in builder.trade.cross_team_moves[:10]: - moves_text += f"• {move.description}\n" + moves_text += f"- {move.description}\n" if len(builder.trade.cross_team_moves) > 10: moves_text += f"... and {len(builder.trade.cross_team_moves) - 10} more" embed.add_field( - name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})", + name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})", value=moves_text, inline=False, ) - # Show supplementary moves if any if builder.trade.supplementary_moves: supp_text = "" for move in builder.trade.supplementary_moves[:5]: - supp_text += f"• {move.description}\n" + supp_text += f"- {move.description}\n" if len(builder.trade.supplementary_moves) > 5: supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more" embed.add_field( - name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})", + name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})", value=supp_text, inline=False, ) - # Show acceptance status status_lines = [] for team in builder.participating_teams: if team.id in builder.accepted_teams: - status_lines.append(f"✅ **{team.abbrev}** - Accepted") + status_lines.append(f"**{team.abbrev}** - Accepted") else: - status_lines.append(f"⏳ **{team.abbrev}** - Pending") + status_lines.append(f"**{team.abbrev}** - Pending") embed.add_field( - name="📊 Acceptance Status", value="\n".join(status_lines), inline=False + name="Acceptance Status", value="\n".join(status_lines), inline=False ) - # Add footer embed.set_footer( - text=f"Trade ID: {builder.trade_id} • {len(builder.accepted_teams)}/{builder.team_count} teams accepted" + text=f"Trade ID: {builder.trade_id} | {len(builder.accepted_teams)}/{builder.team_count} teams accepted" ) return embed @@ -721,7 +655,7 @@ async def create_trade_rejection_embed( ) -> discord.Embed: """Create embed showing trade was rejected.""" embed = EmbedTemplate.create_base_embed( - title=f"❌ Trade Rejected - {builder.trade.get_trade_summary()}", + title=f"Trade Rejected - {builder.trade.get_trade_summary()}", description=f"**{rejecting_team.abbrev}** has rejected the trade.\n\n" f"The trade has been moved back to **DRAFT** status.\n" f"Teams can continue negotiating using `/trade` commands.", @@ -738,29 +672,27 @@ async def create_trade_complete_embed( ) -> discord.Embed: """Create embed showing trade was completed.""" embed = EmbedTemplate.create_base_embed( - title=f"🎉 Trade Complete! - {builder.trade.get_trade_summary()}", - description=f"All {builder.team_count} teams have accepted the trade!\n\n" + title=f"Trade Complete - {builder.trade.get_trade_summary()}", + description=f"All {builder.team_count} teams have accepted the trade.\n\n" f"**{transaction_count} transactions** created for **Week {effective_week}**.", color=EmbedColors.SUCCESS, ) - # Show final acceptance status (all green) status_lines = [ - f"✅ **{team.abbrev}** - Accepted" for team in builder.participating_teams + f"**{team.abbrev}** - Accepted" for team in builder.participating_teams ] - embed.add_field(name="📊 Final Status", value="\n".join(status_lines), inline=False) + embed.add_field(name="Final Status", value="\n".join(status_lines), inline=False) - # Show cross-team moves if builder.trade.cross_team_moves: moves_text = "" for move in builder.trade.cross_team_moves[:8]: - moves_text += f"• {move.description}\n" + moves_text += f"- {move.description}\n" if len(builder.trade.cross_team_moves) > 8: moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more" - embed.add_field(name=f"🔄 Player Exchanges", value=moves_text, inline=False) + embed.add_field(name="Player Exchanges", value=moves_text, inline=False) embed.set_footer( - text=f"Trade ID: {builder.trade_id} • Effective: Week {effective_week}" + text=f"Trade ID: {builder.trade_id} | Effective: Week {effective_week}" ) return embed @@ -776,7 +708,6 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed: Returns: Discord embed with current trade state """ - # Determine embed color based on trade status if builder.is_empty: color = EmbedColors.SECONDARY else: @@ -784,22 +715,20 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed: color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING embed = EmbedTemplate.create_base_embed( - title=f"📋 Trade Builder - {builder.trade.get_trade_summary()}", - description=f"Build your multi-team trade", + title=f"Trade Builder - {builder.trade.get_trade_summary()}", + description="Build your multi-team trade", color=color, ) - # Add participating teams section team_list = [ - f"• {team.abbrev} - {team.sname}" for team in builder.participating_teams + f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams ] embed.add_field( - name=f"🏟️ Participating Teams ({builder.team_count})", + name=f"Participating Teams ({builder.team_count})", value="\n".join(team_list) if team_list else "*No teams yet*", inline=False, ) - # Add current moves section if builder.is_empty: embed.add_field( name="Current Moves", @@ -807,29 +736,23 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed: inline=False, ) else: - # Show cross-team moves if builder.trade.cross_team_moves: moves_text = "" - for i, move in enumerate( - builder.trade.cross_team_moves[:8], 1 - ): # Limit display + for i, move in enumerate(builder.trade.cross_team_moves[:8], 1): moves_text += f"{i}. {move.description}\n" if len(builder.trade.cross_team_moves) > 8: moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more" embed.add_field( - name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})", + name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})", value=moves_text, inline=False, ) - # Show supplementary moves if builder.trade.supplementary_moves: supp_text = "" - for i, move in enumerate( - builder.trade.supplementary_moves[:5], 1 - ): # Limit display + for i, move in enumerate(builder.trade.supplementary_moves[:5], 1): supp_text += f"{i}. {move.description}\n" if len(builder.trade.supplementary_moves) > 5: @@ -838,31 +761,33 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed: ) embed.add_field( - name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})", + name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})", value=supp_text, inline=False, ) - # Add quick validation summary validation = await builder.validate_trade() if validation.is_legal: - status_text = "✅ Trade appears legal" + status_text = "Trade appears legal" else: error_count = len(validation.all_errors) - status_text = f"❌ {error_count} error{'s' if error_count != 1 else ''} found" + status_text = f"{error_count} error{'s' if error_count != 1 else ''} found\n" + status_text += "\n".join(f"- {error}" for error in validation.all_errors) + if validation.all_suggestions: + status_text += "\n" + "\n".join( + f"- {s}" for s in validation.all_suggestions + ) - embed.add_field(name="🔍 Quick Status", value=status_text, inline=False) + embed.add_field(name="Quick Status", value=status_text, inline=False) - # Add instructions for adding more moves embed.add_field( - name="➕ Build Your Trade", - value="• `/trade add-player` - Add player exchanges\n• `/trade supplementary` - Add internal moves\n• `/trade add-team` - Add more teams", + name="Build Your Trade", + value="- `/trade add-player` - Add player exchanges\n- `/trade supplementary` - Add internal moves\n- `/trade add-team` - Add more teams", inline=False, ) - # Add footer with trade ID and timestamp embed.set_footer( - text=f"Trade ID: {builder.trade_id} • Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}" + text=f"Trade ID: {builder.trade_id} | Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}" ) return embed