fix: prevent partial DB writes on scorecard submission failure #79

Closed
cal wants to merge 12 commits from fix/scorecard-submission-resilience into next-release
13 changed files with 1006 additions and 922 deletions

View File

@ -210,17 +210,41 @@ class SubmitScorecardCommands(commands.Cog):
game_id = scheduled_game.id game_id = scheduled_game.id
# Phase 6: Read Scorecard Data # Phase 6: Read ALL Scorecard Data (before any DB writes)
# Reading everything first prevents partial commits if the
# spreadsheet has formula errors (e.g. #N/A in pitching decisions)
await interaction.edit_original_response( await interaction.edit_original_response(
content="📊 Reading play-by-play data..." content="📊 Reading scorecard data..."
) )
plays_data = await self.sheets_service.read_playtable_data(scorecard) plays_data = await self.sheets_service.read_playtable_data(scorecard)
box_score = await self.sheets_service.read_box_score(scorecard)
decisions_data = await self.sheets_service.read_pitching_decisions(
scorecard
)
# Add game_id to each play # Add game_id to each play
for play in plays_data: for play in plays_data:
play["game_id"] = game_id play["game_id"] = game_id
# Add game metadata to each decision
for decision in decisions_data:
decision["game_id"] = game_id
decision["season"] = current.season
decision["week"] = setup_data["week"]
decision["game_num"] = setup_data["game_num"]
# Validate WP and LP exist and fetch Player objects
wp, lp, sv, holders, _blown_saves = (
await decision_service.find_winning_losing_pitchers(decisions_data)
)
if wp is None or lp is None:
await interaction.edit_original_response(
content="❌ Your card is missing either a Winning Pitcher or Losing Pitcher"
)
return
# Phase 7: POST Plays # Phase 7: POST Plays
await interaction.edit_original_response( await interaction.edit_original_response(
content="💾 Submitting plays to database..." content="💾 Submitting plays to database..."
@ -244,10 +268,7 @@ class SubmitScorecardCommands(commands.Cog):
) )
return return
# Phase 8: Read Box Score # Phase 8: PATCH Game
box_score = await self.sheets_service.read_box_score(scorecard)
# Phase 9: PATCH Game
await interaction.edit_original_response( await interaction.edit_original_response(
content="⚾ Updating game result..." content="⚾ Updating game result..."
) )
@ -275,33 +296,7 @@ class SubmitScorecardCommands(commands.Cog):
) )
return return
# Phase 10: Read Pitching Decisions # Phase 9: POST Decisions
decisions_data = await self.sheets_service.read_pitching_decisions(
scorecard
)
# Add game metadata to each decision
for decision in decisions_data:
decision["game_id"] = game_id
decision["season"] = current.season
decision["week"] = setup_data["week"]
decision["game_num"] = setup_data["game_num"]
# Validate WP and LP exist and fetch Player objects
wp, lp, sv, holders, _blown_saves = (
await decision_service.find_winning_losing_pitchers(decisions_data)
)
if wp is None or lp is None:
# Rollback
await game_service.wipe_game_data(game_id)
await play_service.delete_plays_for_game(game_id)
await interaction.edit_original_response(
content="❌ Your card is missing either a Winning Pitcher or Losing Pitcher"
)
return
# Phase 11: POST Decisions
await interaction.edit_original_response( await interaction.edit_original_response(
content="🎯 Submitting pitching decisions..." content="🎯 Submitting pitching decisions..."
) )
@ -361,6 +356,30 @@ class SubmitScorecardCommands(commands.Cog):
# Success! # Success!
await interaction.edit_original_response(content="✅ You are all set!") await interaction.edit_original_response(content="✅ You are all set!")
except SheetsException as e:
# Spreadsheet reading error - show the detailed message to the user
self.logger.error(
f"Spreadsheet error in scorecard submission: {e}", error=e
)
if rollback_state and game_id:
try:
if rollback_state == "GAME_PATCHED":
await game_service.wipe_game_data(game_id)
await play_service.delete_plays_for_game(game_id)
elif rollback_state == "PLAYS_POSTED":
await play_service.delete_plays_for_game(game_id)
except Exception:
pass # Best effort rollback
await interaction.edit_original_response(
content=(
f"❌ There's a problem with your scorecard:\n\n"
f"{str(e)}\n\n"
f"Please fix the issue in your spreadsheet and resubmit."
)
)
except Exception as e: except Exception as e:
# Unexpected error - attempt rollback # Unexpected error - attempt rollback
self.logger.error(f"Unexpected error in scorecard submission: {e}", error=e) self.logger.error(f"Unexpected error in scorecard submission: {e}", error=e)

View File

@ -3,6 +3,7 @@ Base service class for Discord Bot v2.0
Provides common CRUD operations and error handling for all data services. Provides common CRUD operations and error handling for all data services.
""" """
import logging import logging
import hashlib import hashlib
from typing import Optional, Type, TypeVar, Generic, Dict, Any, List, Tuple from typing import Optional, Type, TypeVar, Generic, Dict, Any, List, Tuple
@ -12,15 +13,15 @@ from models.base import SBABaseModel
from exceptions import APIException from exceptions import APIException
from utils.cache import CacheManager from utils.cache import CacheManager
logger = logging.getLogger(f'{__name__}.BaseService') logger = logging.getLogger(f"{__name__}.BaseService")
T = TypeVar('T', bound=SBABaseModel) T = TypeVar("T", bound=SBABaseModel)
class BaseService(Generic[T]): class BaseService(Generic[T]):
""" """
Base service class providing common CRUD operations for SBA models. Base service class providing common CRUD operations for SBA models.
Features: Features:
- Generic type support for any SBABaseModel subclass - Generic type support for any SBABaseModel subclass
- Automatic model validation and conversion - Automatic model validation and conversion
@ -28,15 +29,17 @@ class BaseService(Generic[T]):
- API response format handling (count + list format) - API response format handling (count + list format)
- Connection management via global client - Connection management via global client
""" """
def __init__(self, def __init__(
model_class: Type[T], self,
endpoint: str, model_class: Type[T],
client: Optional[APIClient] = None, endpoint: str,
cache_manager: Optional[CacheManager] = None): client: Optional[APIClient] = None,
cache_manager: Optional[CacheManager] = None,
):
""" """
Initialize base service. Initialize base service.
Args: Args:
model_class: Pydantic model class for this service model_class: Pydantic model class for this service
endpoint: API endpoint path (e.g., 'players', 'teams') endpoint: API endpoint path (e.g., 'players', 'teams')
@ -48,40 +51,44 @@ class BaseService(Generic[T]):
self._client = client self._client = client
self._cached_client: Optional[APIClient] = None self._cached_client: Optional[APIClient] = None
self.cache = cache_manager or CacheManager() self.cache = cache_manager or CacheManager()
logger.debug(f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'") logger.debug(
f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'"
def _generate_cache_key(self, method: str, params: Optional[List[Tuple[str, Any]]] = None) -> str: )
def _generate_cache_key(
self, method: str, params: Optional[List[Tuple[str, Any]]] = None
) -> str:
""" """
Generate consistent cache key for API calls. Generate consistent cache key for API calls.
Args: Args:
method: API method name method: API method name
params: Query parameters as list of tuples params: Query parameters as list of tuples
Returns: Returns:
SHA256-hashed cache key SHA256-hashed cache key
""" """
key_parts = [self.endpoint, method] key_parts = [self.endpoint, method]
if params: if params:
# Sort parameters for consistent key generation # Sort parameters for consistent key generation
sorted_params = sorted(params, key=lambda x: str(x[0])) sorted_params = sorted(params, key=lambda x: str(x[0]))
param_str = "&".join([f"{k}={v}" for k, v in sorted_params]) param_str = "&".join([f"{k}={v}" for k, v in sorted_params])
key_parts.append(param_str) key_parts.append(param_str)
key_data = ":".join(key_parts) key_data = ":".join(key_parts)
key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16] # First 16 chars key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16] # First 16 chars
return self.cache.cache_key("sba", f"{self.endpoint}_{key_hash}") return self.cache.cache_key("sba", f"{self.endpoint}_{key_hash}")
async def _get_cached_items(self, cache_key: str) -> Optional[List[T]]: async def _get_cached_items(self, cache_key: str) -> Optional[List[T]]:
""" """
Get cached list of model items. Get cached list of model items.
Args: Args:
cache_key: Cache key to lookup cache_key: Cache key to lookup
Returns: Returns:
List of model instances or None if not cached List of model instances or None if not cached
""" """
@ -91,13 +98,15 @@ class BaseService(Generic[T]):
return [self.model_class.from_api_data(item) for item in cached_data] return [self.model_class.from_api_data(item) for item in cached_data]
except Exception as e: except Exception as e:
logger.warning(f"Error deserializing cached data for {cache_key}: {e}") logger.warning(f"Error deserializing cached data for {cache_key}: {e}")
return None return None
async def _cache_items(self, cache_key: str, items: List[T], ttl: Optional[int] = None) -> None: async def _cache_items(
self, cache_key: str, items: List[T], ttl: Optional[int] = None
) -> None:
""" """
Cache list of model items. Cache list of model items.
Args: Args:
cache_key: Cache key to store under cache_key: Cache key to store under
items: List of model instances to cache items: List of model instances to cache
@ -105,40 +114,40 @@ class BaseService(Generic[T]):
""" """
if not items: if not items:
return return
try: try:
# Convert to JSON-serializable format # Convert to JSON-serializable format
cache_data = [item.model_dump() for item in items] cache_data = [item.model_dump() for item in items]
await self.cache.set(cache_key, cache_data, ttl) await self.cache.set(cache_key, cache_data, ttl)
except Exception as e: except Exception as e:
logger.warning(f"Error caching items for {cache_key}: {e}") logger.warning(f"Error caching items for {cache_key}: {e}")
async def get_client(self) -> APIClient: async def get_client(self) -> APIClient:
""" """
Get API client instance with caching to reduce async overhead. Get API client instance with caching to reduce async overhead.
Returns: Returns:
APIClient instance (cached after first access) APIClient instance (cached after first access)
""" """
if self._client: if self._client:
return self._client return self._client
# Cache the global client to avoid repeated async calls # Cache the global client to avoid repeated async calls
if self._cached_client is None: if self._cached_client is None:
self._cached_client = await get_global_client() self._cached_client = await get_global_client()
return self._cached_client return self._cached_client
async def get_by_id(self, object_id: int) -> Optional[T]: async def get_by_id(self, object_id: int) -> Optional[T]:
""" """
Get single object by ID. Get single object by ID.
Args: Args:
object_id: Unique identifier for the object object_id: Unique identifier for the object
Returns: Returns:
Model instance or None if not found Model instance or None if not found
Raises: Raises:
APIException: For API errors APIException: For API errors
ValueError: For invalid data ValueError: For invalid data
@ -146,167 +155,181 @@ class BaseService(Generic[T]):
try: try:
client = await self.get_client() client = await self.get_client()
data = await client.get(self.endpoint, object_id=object_id) data = await client.get(self.endpoint, object_id=object_id)
if not data: if not data:
logger.debug(f"{self.model_class.__name__} {object_id} not found") logger.debug(f"{self.model_class.__name__} {object_id} not found")
return None return None
model = self.model_class.from_api_data(data) model = self.model_class.from_api_data(data)
logger.debug(f"Retrieved {self.model_class.__name__} {object_id}: {model}") logger.debug(f"Retrieved {self.model_class.__name__} {object_id}: {model}")
return model return model
except APIException: except APIException:
logger.error(f"API error retrieving {self.model_class.__name__} {object_id}") logger.error(
f"API error retrieving {self.model_class.__name__} {object_id}"
)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error retrieving {self.model_class.__name__} {object_id}: {e}") logger.error(
f"Error retrieving {self.model_class.__name__} {object_id}: {e}"
)
raise APIException(f"Failed to retrieve {self.model_class.__name__}: {e}") raise APIException(f"Failed to retrieve {self.model_class.__name__}: {e}")
async def get_all(self, params: Optional[List[tuple]] = None) -> Tuple[List[T], int]: async def get_all(
self, params: Optional[List[tuple]] = None
) -> Tuple[List[T], int]:
""" """
Get all objects with optional query parameters. Get all objects with optional query parameters.
Args: Args:
params: Query parameters as list of (key, value) tuples params: Query parameters as list of (key, value) tuples
Returns: Returns:
Tuple of (list of model instances, total count) Tuple of (list of model instances, total count)
Raises: Raises:
APIException: For API errors APIException: For API errors
""" """
try: try:
client = await self.get_client() client = await self.get_client()
data = await client.get(self.endpoint, params=params) data = await client.get(self.endpoint, params=params)
if not data: if not data:
logger.debug(f"No {self.model_class.__name__} objects found") logger.debug(f"No {self.model_class.__name__} objects found")
return [], 0 return [], 0
# Handle API response format: {'count': int, '<endpoint>': [...]} # Handle API response format: {'count': int, '<endpoint>': [...]}
items, count = self._extract_items_and_count_from_response(data) items, count = self._extract_items_and_count_from_response(data)
models = [self.model_class.from_api_data(item) for item in items] models = [self.model_class.from_api_data(item) for item in items]
logger.debug(f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects") logger.debug(
f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects"
)
return models, count return models, count
except APIException: except APIException:
logger.error(f"API error retrieving {self.model_class.__name__} list") logger.error(f"API error retrieving {self.model_class.__name__} list")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error retrieving {self.model_class.__name__} list: {e}") logger.error(f"Error retrieving {self.model_class.__name__} list: {e}")
raise APIException(f"Failed to retrieve {self.model_class.__name__} list: {e}") raise APIException(
f"Failed to retrieve {self.model_class.__name__} list: {e}"
)
async def get_all_items(self, params: Optional[List[tuple]] = None) -> List[T]: async def get_all_items(self, params: Optional[List[tuple]] = None) -> List[T]:
""" """
Get all objects (convenience method that only returns the list). Get all objects (convenience method that only returns the list).
Args: Args:
params: Query parameters as list of (key, value) tuples params: Query parameters as list of (key, value) tuples
Returns: Returns:
List of model instances List of model instances
""" """
items, _ = await self.get_all(params=params) items, _ = await self.get_all(params=params)
return items return items
async def create(self, model_data: Dict[str, Any]) -> Optional[T]: async def create(self, model_data: Dict[str, Any]) -> Optional[T]:
""" """
Create new object from data dictionary. Create new object from data dictionary.
Args: Args:
model_data: Dictionary of model fields model_data: Dictionary of model fields
Returns: Returns:
Created model instance or None Created model instance or None
Raises: Raises:
APIException: For API errors APIException: For API errors
ValueError: For invalid data ValueError: For invalid data
""" """
try: try:
client = await self.get_client() client = await self.get_client()
response = await client.post(self.endpoint, model_data) response = await client.post(f"{self.endpoint}/", model_data)
if not response: if not response:
logger.warning(f"No response from {self.model_class.__name__} creation") logger.warning(f"No response from {self.model_class.__name__} creation")
return None return None
model = self.model_class.from_api_data(response) model = self.model_class.from_api_data(response)
logger.debug(f"Created {self.model_class.__name__}: {model}") logger.debug(f"Created {self.model_class.__name__}: {model}")
return model return model
except APIException: except APIException:
logger.error(f"API error creating {self.model_class.__name__}") logger.error(f"API error creating {self.model_class.__name__}")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error creating {self.model_class.__name__}: {e}") logger.error(f"Error creating {self.model_class.__name__}: {e}")
raise APIException(f"Failed to create {self.model_class.__name__}: {e}") raise APIException(f"Failed to create {self.model_class.__name__}: {e}")
async def create_from_model(self, model: T) -> Optional[T]: async def create_from_model(self, model: T) -> Optional[T]:
""" """
Create new object from model instance. Create new object from model instance.
Args: Args:
model: Model instance to create model: Model instance to create
Returns: Returns:
Created model instance or None Created model instance or None
""" """
return await self.create(model.to_dict(exclude_none=True)) return await self.create(model.to_dict(exclude_none=True))
async def update(self, object_id: int, model_data: Dict[str, Any]) -> Optional[T]: async def update(self, object_id: int, model_data: Dict[str, Any]) -> Optional[T]:
""" """
Update existing object. Update existing object.
Args: Args:
object_id: ID of object to update object_id: ID of object to update
model_data: Dictionary of fields to update model_data: Dictionary of fields to update
Returns: Returns:
Updated model instance or None if not found Updated model instance or None if not found
Raises: Raises:
APIException: For API errors APIException: For API errors
""" """
try: try:
client = await self.get_client() client = await self.get_client()
response = await client.put(self.endpoint, model_data, object_id=object_id) response = await client.put(self.endpoint, model_data, object_id=object_id)
if not response: if not response:
logger.debug(f"{self.model_class.__name__} {object_id} not found for update") logger.debug(
f"{self.model_class.__name__} {object_id} not found for update"
)
return None return None
model = self.model_class.from_api_data(response) model = self.model_class.from_api_data(response)
logger.debug(f"Updated {self.model_class.__name__} {object_id}: {model}") logger.debug(f"Updated {self.model_class.__name__} {object_id}: {model}")
return model return model
except APIException: except APIException:
logger.error(f"API error updating {self.model_class.__name__} {object_id}") logger.error(f"API error updating {self.model_class.__name__} {object_id}")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}") logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to update {self.model_class.__name__}: {e}") raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
async def update_from_model(self, model: T) -> Optional[T]: async def update_from_model(self, model: T) -> Optional[T]:
""" """
Update object from model instance. Update object from model instance.
Args: Args:
model: Model instance to update (must have ID) model: Model instance to update (must have ID)
Returns: Returns:
Updated model instance or None Updated model instance or None
Raises: Raises:
ValueError: If model has no ID ValueError: If model has no ID
""" """
if not model.id: if not model.id:
raise ValueError(f"Cannot update {self.model_class.__name__} without ID") raise ValueError(f"Cannot update {self.model_class.__name__} without ID")
return await self.update(model.id, model.to_dict(exclude_none=True)) return await self.update(model.id, model.to_dict(exclude_none=True))
async def patch(self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False) -> Optional[T]: async def patch(
self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False
) -> Optional[T]:
""" """
Update existing object with HTTP PATCH. Update existing object with HTTP PATCH.
@ -323,10 +346,14 @@ class BaseService(Generic[T]):
""" """
try: try:
client = await self.get_client() client = await self.get_client()
response = await client.patch(self.endpoint, model_data, object_id, use_query_params=use_query_params) response = await client.patch(
self.endpoint, model_data, object_id, use_query_params=use_query_params
)
if not response: if not response:
logger.debug(f"{self.model_class.__name__} {object_id} not found for update") logger.debug(
f"{self.model_class.__name__} {object_id} not found for update"
)
return None return None
model = self.model_class.from_api_data(response) model = self.model_class.from_api_data(response)
@ -340,134 +367,142 @@ class BaseService(Generic[T]):
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}") logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to update {self.model_class.__name__}: {e}") raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
async def delete(self, object_id: int) -> bool: async def delete(self, object_id: int) -> bool:
""" """
Delete object by ID. Delete object by ID.
Args: Args:
object_id: ID of object to delete object_id: ID of object to delete
Returns: Returns:
True if deleted, False if not found True if deleted, False if not found
Raises: Raises:
APIException: For API errors APIException: For API errors
""" """
try: try:
client = await self.get_client() client = await self.get_client()
success = await client.delete(self.endpoint, object_id=object_id) success = await client.delete(self.endpoint, object_id=object_id)
if success: if success:
logger.debug(f"Deleted {self.model_class.__name__} {object_id}") logger.debug(f"Deleted {self.model_class.__name__} {object_id}")
else: else:
logger.debug(f"{self.model_class.__name__} {object_id} not found for deletion") logger.debug(
f"{self.model_class.__name__} {object_id} not found for deletion"
)
return success return success
except APIException: except APIException:
logger.error(f"API error deleting {self.model_class.__name__} {object_id}") logger.error(f"API error deleting {self.model_class.__name__} {object_id}")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error deleting {self.model_class.__name__} {object_id}: {e}") logger.error(f"Error deleting {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to delete {self.model_class.__name__}: {e}") raise APIException(f"Failed to delete {self.model_class.__name__}: {e}")
async def get_by_field(self, field: str, value: Any) -> List[T]: async def get_by_field(self, field: str, value: Any) -> List[T]:
""" """
Get objects by specific field value. Get objects by specific field value.
Args: Args:
field: Field name to search field: Field name to search
value: Field value to match value: Field value to match
Returns: Returns:
List of matching model instances List of matching model instances
""" """
params = [(field, str(value))] params = [(field, str(value))]
return await self.get_all_items(params=params) return await self.get_all_items(params=params)
async def count(self, params: Optional[List[tuple]] = None) -> int: async def count(self, params: Optional[List[tuple]] = None) -> int:
""" """
Get count of objects matching parameters. Get count of objects matching parameters.
Args: Args:
params: Query parameters params: Query parameters
Returns: Returns:
Number of matching objects (from API count field) Number of matching objects (from API count field)
""" """
_, count = await self.get_all(params=params) _, count = await self.get_all(params=params)
return count return count
def _extract_items_and_count_from_response(self, data: Any) -> Tuple[List[Dict[str, Any]], int]: def _extract_items_and_count_from_response(
self, data: Any
) -> Tuple[List[Dict[str, Any]], int]:
""" """
Extract items list and count from API response with optimized parsing. Extract items list and count from API response with optimized parsing.
Expected format: {'count': int, '<endpoint>': [...]} Expected format: {'count': int, '<endpoint>': [...]}
Single object format: {'id': 1, 'name': '...'} Single object format: {'id': 1, 'name': '...'}
Args: Args:
data: API response data data: API response data
Returns: Returns:
Tuple of (items list, total count) Tuple of (items list, total count)
""" """
if isinstance(data, list): if isinstance(data, list):
return data, len(data) return data, len(data)
if not isinstance(data, dict): if not isinstance(data, dict):
logger.warning(f"Unexpected response format for {self.model_class.__name__}: {type(data)}") logger.warning(
f"Unexpected response format for {self.model_class.__name__}: {type(data)}"
)
return [], 0 return [], 0
# Single pass through the response dict - get count first # Single pass through the response dict - get count first
count = data.get('count', 0) count = data.get("count", 0)
# Priority order for finding items list (most common first) # Priority order for finding items list (most common first)
field_candidates = [self.endpoint, 'items', 'data', 'results'] field_candidates = [self.endpoint, "items", "data", "results"]
for field_name in field_candidates: for field_name in field_candidates:
if field_name in data and isinstance(data[field_name], list): if field_name in data and isinstance(data[field_name], list):
return data[field_name], count or len(data[field_name]) return data[field_name], count or len(data[field_name])
# Single object response (check for common identifying fields) # Single object response (check for common identifying fields)
if any(key in data for key in ['id', 'name', 'abbrev']): if any(key in data for key in ["id", "name", "abbrev"]):
return [data], 1 return [data], 1
return [], count return [], count
async def get_items_with_params(self, params: Optional[List[tuple]] = None) -> List[T]: async def get_items_with_params(
self, params: Optional[List[tuple]] = None
) -> List[T]:
""" """
Get all items with parameters (alias for get_all_items for compatibility). Get all items with parameters (alias for get_all_items for compatibility).
Args: Args:
params: Query parameters as list of (key, value) tuples params: Query parameters as list of (key, value) tuples
Returns: Returns:
List of model instances List of model instances
""" """
return await self.get_all_items(params=params) return await self.get_all_items(params=params)
async def create_item(self, model_data: Dict[str, Any]) -> Optional[T]: async def create_item(self, model_data: Dict[str, Any]) -> Optional[T]:
""" """
Create item (alias for create for compatibility). Create item (alias for create for compatibility).
Args: Args:
model_data: Dictionary of model fields model_data: Dictionary of model fields
Returns: Returns:
Created model instance or None Created model instance or None
""" """
return await self.create(model_data) return await self.create(model_data)
async def update_item_by_field(self, field: str, value: Any, update_data: Dict[str, Any]) -> Optional[T]: async def update_item_by_field(
self, field: str, value: Any, update_data: Dict[str, Any]
) -> Optional[T]:
""" """
Update item by field value. Update item by field value.
Args: Args:
field: Field name to search by field: Field name to search by
value: Field value to match value: Field value to match
update_data: Data to update update_data: Data to update
Returns: Returns:
Updated model instance or None if not found Updated model instance or None if not found
""" """
@ -475,22 +510,22 @@ class BaseService(Generic[T]):
items = await self.get_by_field(field, value) items = await self.get_by_field(field, value)
if not items: if not items:
return None return None
# Update the first matching item # Update the first matching item
item = items[0] item = items[0]
if not item.id: if not item.id:
return None return None
return await self.update(item.id, update_data) return await self.update(item.id, update_data)
async def delete_item_by_field(self, field: str, value: Any) -> bool: async def delete_item_by_field(self, field: str, value: Any) -> bool:
""" """
Delete item by field value. Delete item by field value.
Args: Args:
field: Field name to search by field: Field name to search by
value: Field value to match value: Field value to match
Returns: Returns:
True if deleted, False if not found True if deleted, False if not found
""" """
@ -498,62 +533,41 @@ class BaseService(Generic[T]):
items = await self.get_by_field(field, value) items = await self.get_by_field(field, value)
if not items: if not items:
return False return False
# Delete the first matching item # Delete the first matching item
item = items[0] item = items[0]
if not item.id: if not item.id:
return False return False
return await self.delete(item.id) return await self.delete(item.id)
async def create_item_in_table(self, table_name: str, item_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: async def get_items_from_table_with_params(
""" self, table_name: str, params: List[tuple]
Create item in a specific table (simplified for custom commands service). ) -> List[Dict[str, Any]]:
This is a placeholder - real implementation would need table-specific endpoints.
Args:
table_name: Name of the table
item_data: Data to create
Returns:
Created item data or None
"""
# For now, use the main endpoint - this would need proper implementation
# for different tables like 'custom_command_creators'
try:
client = await self.get_client()
# Use table name as endpoint for now
response = await client.post(table_name, item_data)
return response
except Exception as e:
logger.error(f"Error creating item in table {table_name}: {e}")
return None
async def get_items_from_table_with_params(self, table_name: str, params: List[tuple]) -> List[Dict[str, Any]]:
""" """
Get items from a specific table with parameters. Get items from a specific table with parameters.
Args: Args:
table_name: Name of the table table_name: Name of the table
params: Query parameters params: Query parameters
Returns: Returns:
List of item dictionaries List of item dictionaries
""" """
try: try:
client = await self.get_client() client = await self.get_client()
data = await client.get(table_name, params=params) data = await client.get(table_name, params=params)
if not data: if not data:
return [] return []
# Handle response format # Handle response format
items, _ = self._extract_items_and_count_from_response(data) items, _ = self._extract_items_and_count_from_response(data)
return items return items
except Exception as e: except Exception as e:
logger.error(f"Error getting items from table {table_name}: {e}") logger.error(f"Error getting items from table {table_name}: {e}")
return [] return []
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')" return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')"

View File

@ -552,9 +552,8 @@ class CustomCommandsService(BaseService[CustomCommand]):
"active_commands": 0, "active_commands": 0,
} }
result = await self.create_item_in_table( client = await self.get_client()
"custom_commands/creators", creator_data result = await client.post("custom_commands/creators", creator_data)
)
if not result: if not result:
raise BotException("Failed to create command creator") raise BotException("Failed to create command creator")

View File

@ -3,6 +3,7 @@ Decision Service
Manages pitching decision operations for game submission. Manages pitching decision operations for game submission.
""" """
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional, Tuple
from utils.logging import get_contextual_logger from utils.logging import get_contextual_logger
@ -16,17 +17,14 @@ class DecisionService:
def __init__(self): def __init__(self):
"""Initialize decision service.""" """Initialize decision service."""
self.logger = get_contextual_logger(f'{__name__}.DecisionService') self.logger = get_contextual_logger(f"{__name__}.DecisionService")
self._get_client = get_global_client self._get_client = get_global_client
async def get_client(self): async def get_client(self):
"""Get the API client.""" """Get the API client."""
return await self._get_client() return await self._get_client()
async def create_decisions_batch( async def create_decisions_batch(self, decisions: List[Dict[str, Any]]) -> bool:
self,
decisions: List[Dict[str, Any]]
) -> bool:
""" """
POST batch of decisions to /decisions endpoint. POST batch of decisions to /decisions endpoint.
@ -42,8 +40,10 @@ class DecisionService:
try: try:
client = await self.get_client() client = await self.get_client()
payload = {'decisions': decisions} payload = {"decisions": decisions}
await client.post('decisions', payload) # Trailing slash required: without it, the server returns a 307 redirect
# and aiohttp drops the POST body when following the redirect
await client.post("decisions/", payload)
self.logger.info(f"Created {len(decisions)} decisions") self.logger.info(f"Created {len(decisions)} decisions")
return True return True
@ -70,7 +70,7 @@ class DecisionService:
""" """
try: try:
client = await self.get_client() client = await self.get_client()
await client.delete(f'decisions/game/{game_id}') await client.delete(f"decisions/game/{game_id}")
self.logger.info(f"Deleted decisions for game {game_id}") self.logger.info(f"Deleted decisions for game {game_id}")
return True return True
@ -80,9 +80,10 @@ class DecisionService:
raise APIException(f"Failed to delete decisions: {e}") raise APIException(f"Failed to delete decisions: {e}")
async def find_winning_losing_pitchers( async def find_winning_losing_pitchers(
self, self, decisions_data: List[Dict[str, Any]]
decisions_data: List[Dict[str, Any]] ) -> Tuple[
) -> Tuple[Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]]: Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]
]:
""" """
Extract WP, LP, SV, Holds, Blown Saves from decisions list and fetch Player objects. Extract WP, LP, SV, Holds, Blown Saves from decisions list and fetch Player objects.
@ -110,17 +111,17 @@ class DecisionService:
# First pass: Extract IDs # First pass: Extract IDs
for decision in decisions_data: for decision in decisions_data:
pitcher_id = int(decision.get('pitcher_id', 0)) pitcher_id = int(decision.get("pitcher_id", 0))
if int(decision.get('win', 0)) == 1: if int(decision.get("win", 0)) == 1:
wp_id = pitcher_id wp_id = pitcher_id
if int(decision.get('loss', 0)) == 1: if int(decision.get("loss", 0)) == 1:
lp_id = pitcher_id lp_id = pitcher_id
if int(decision.get('is_save', 0)) == 1: if int(decision.get("is_save", 0)) == 1:
sv_id = pitcher_id sv_id = pitcher_id
if int(decision.get('hold', 0)) == 1: if int(decision.get("hold", 0)) == 1:
hold_ids.append(pitcher_id) hold_ids.append(pitcher_id)
if int(decision.get('b_save', 0)) == 1: if int(decision.get("b_save", 0)) == 1:
bsv_ids.append(pitcher_id) bsv_ids.append(pitcher_id)
# Second pass: Fetch Player objects # Second pass: Fetch Player objects
@ -154,9 +155,9 @@ class DecisionService:
""" """
error_str = str(error) error_str = str(error)
if 'Player ID' in error_str and 'not found' in error_str: if "Player ID" in error_str and "not found" in error_str:
return "Invalid pitcher ID in decision data." return "Invalid pitcher ID in decision data."
elif 'Game ID' in error_str and 'not found' in error_str: elif "Game ID" in error_str and "not found" in error_str:
return "Game not found for decisions." return "Game not found for decisions."
else: else:
return f"Error submitting decisions: {error_str}" return f"Error submitting decisions: {error_str}"

View File

@ -3,13 +3,14 @@ Draft list service for Discord Bot v2.0
Handles team draft list (auto-draft queue) operations. NO CACHING - lists change frequently. Handles team draft list (auto-draft queue) operations. NO CACHING - lists change frequently.
""" """
import logging import logging
from typing import Optional, List from typing import Optional, List
from services.base_service import BaseService from services.base_service import BaseService
from models.draft_list import DraftList from models.draft_list import DraftList
logger = logging.getLogger(f'{__name__}.DraftListService') logger = logging.getLogger(f"{__name__}.DraftListService")
class DraftListService(BaseService[DraftList]): class DraftListService(BaseService[DraftList]):
@ -32,7 +33,7 @@ class DraftListService(BaseService[DraftList]):
def __init__(self): def __init__(self):
"""Initialize draft list service.""" """Initialize draft list service."""
super().__init__(DraftList, 'draftlist') super().__init__(DraftList, "draftlist")
logger.debug("DraftListService initialized") logger.debug("DraftListService initialized")
def _extract_items_and_count_from_response(self, data): def _extract_items_and_count_from_response(self, data):
@ -54,20 +55,16 @@ class DraftListService(BaseService[DraftList]):
return [], 0 return [], 0
# Get count # Get count
count = data.get('count', 0) count = data.get("count", 0)
# API returns items under 'picks' key (not 'draftlist') # API returns items under 'picks' key (not 'draftlist')
if 'picks' in data and isinstance(data['picks'], list): if "picks" in data and isinstance(data["picks"], list):
return data['picks'], count or len(data['picks']) return data["picks"], count or len(data["picks"])
# Fallback to standard extraction # Fallback to standard extraction
return super()._extract_items_and_count_from_response(data) return super()._extract_items_and_count_from_response(data)
async def get_team_list( async def get_team_list(self, season: int, team_id: int) -> List[DraftList]:
self,
season: int,
team_id: int
) -> List[DraftList]:
""" """
Get team's draft list ordered by rank. Get team's draft list ordered by rank.
@ -82,8 +79,8 @@ class DraftListService(BaseService[DraftList]):
""" """
try: try:
params = [ params = [
('season', str(season)), ("season", str(season)),
('team_id', str(team_id)) ("team_id", str(team_id)),
# NOTE: API does not support 'sort' param - results must be sorted client-side # NOTE: API does not support 'sort' param - results must be sorted client-side
] ]
@ -100,11 +97,7 @@ class DraftListService(BaseService[DraftList]):
return [] return []
async def add_to_list( async def add_to_list(
self, self, season: int, team_id: int, player_id: int, rank: Optional[int] = None
season: int,
team_id: int,
player_id: int,
rank: Optional[int] = None
) -> Optional[List[DraftList]]: ) -> Optional[List[DraftList]]:
""" """
Add player to team's draft list. Add player to team's draft list.
@ -133,10 +126,10 @@ class DraftListService(BaseService[DraftList]):
# Create new entry data # Create new entry data
new_entry_data = { new_entry_data = {
'season': season, "season": season,
'team_id': team_id, "team_id": team_id,
'player_id': player_id, "player_id": player_id,
'rank': rank "rank": rank,
} }
# Build complete list for bulk replacement # Build complete list for bulk replacement
@ -146,36 +139,42 @@ class DraftListService(BaseService[DraftList]):
for entry in current_list: for entry in current_list:
if entry.rank >= rank: if entry.rank >= rank:
# Shift down entries at or after insertion point # Shift down entries at or after insertion point
draft_list_entries.append({ draft_list_entries.append(
'season': entry.season, {
'team_id': entry.team_id, "season": entry.season,
'player_id': entry.player_id, "team_id": entry.team_id,
'rank': entry.rank + 1 "player_id": entry.player_id,
}) "rank": entry.rank + 1,
}
)
else: else:
# Keep existing rank for entries before insertion point # Keep existing rank for entries before insertion point
draft_list_entries.append({ draft_list_entries.append(
'season': entry.season, {
'team_id': entry.team_id, "season": entry.season,
'player_id': entry.player_id, "team_id": entry.team_id,
'rank': entry.rank "player_id": entry.player_id,
}) "rank": entry.rank,
}
)
# Add new entry # Add new entry
draft_list_entries.append(new_entry_data) draft_list_entries.append(new_entry_data)
# Sort by rank for consistency # Sort by rank for consistency
draft_list_entries.sort(key=lambda x: x['rank']) draft_list_entries.sort(key=lambda x: x["rank"])
# POST entire list (bulk replacement) # POST entire list (bulk replacement)
client = await self.get_client() client = await self.get_client()
payload = { payload = {
'count': len(draft_list_entries), "count": len(draft_list_entries),
'draft_list': draft_list_entries "draft_list": draft_list_entries,
} }
logger.debug(f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries") logger.debug(
response = await client.post(self.endpoint, payload) f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries"
)
response = await client.post(f"{self.endpoint}/", payload)
logger.debug(f"POST response: {response}") logger.debug(f"POST response: {response}")
# Verify by fetching the list back (API returns full objects) # Verify by fetching the list back (API returns full objects)
@ -184,20 +183,21 @@ class DraftListService(BaseService[DraftList]):
# Verify the player was added # Verify the player was added
if not any(entry.player_id == player_id for entry in verification): if not any(entry.player_id == player_id for entry in verification):
logger.error(f"Player {player_id} not found in list after POST - operation may have failed") logger.error(
f"Player {player_id} not found in list after POST - operation may have failed"
)
return None return None
logger.info(f"Added player {player_id} to team {team_id} draft list at rank {rank}") logger.info(
f"Added player {player_id} to team {team_id} draft list at rank {rank}"
)
return verification # Return full updated list return verification # Return full updated list
except Exception as e: except Exception as e:
logger.error(f"Error adding player {player_id} to draft list: {e}") logger.error(f"Error adding player {player_id} to draft list: {e}")
return None return None
async def remove_from_list( async def remove_from_list(self, entry_id: int) -> bool:
self,
entry_id: int
) -> bool:
""" """
Remove entry from draft list by ID. Remove entry from draft list by ID.
@ -209,14 +209,13 @@ class DraftListService(BaseService[DraftList]):
Returns: Returns:
True if deletion succeeded True if deletion succeeded
""" """
logger.warning("remove_from_list() called with entry_id - use remove_player_from_list() instead") logger.warning(
"remove_from_list() called with entry_id - use remove_player_from_list() instead"
)
return False return False
async def remove_player_from_list( async def remove_player_from_list(
self, self, season: int, team_id: int, player_id: int
season: int,
team_id: int,
player_id: int
) -> bool: ) -> bool:
""" """
Remove specific player from team's draft list. Remove specific player from team's draft list.
@ -238,7 +237,9 @@ class DraftListService(BaseService[DraftList]):
# Check if player is in list # Check if player is in list
player_found = any(entry.player_id == player_id for entry in current_list) player_found = any(entry.player_id == player_id for entry in current_list)
if not player_found: if not player_found:
logger.warning(f"Player {player_id} not found in team {team_id} draft list") logger.warning(
f"Player {player_id} not found in team {team_id} draft list"
)
return False return False
# Build new list without the player, adjusting ranks # Build new list without the player, adjusting ranks
@ -246,22 +247,24 @@ class DraftListService(BaseService[DraftList]):
new_rank = 1 new_rank = 1
for entry in current_list: for entry in current_list:
if entry.player_id != player_id: if entry.player_id != player_id:
draft_list_entries.append({ draft_list_entries.append(
'season': entry.season, {
'team_id': entry.team_id, "season": entry.season,
'player_id': entry.player_id, "team_id": entry.team_id,
'rank': new_rank "player_id": entry.player_id,
}) "rank": new_rank,
}
)
new_rank += 1 new_rank += 1
# POST updated list (bulk replacement) # POST updated list (bulk replacement)
client = await self.get_client() client = await self.get_client()
payload = { payload = {
'count': len(draft_list_entries), "count": len(draft_list_entries),
'draft_list': draft_list_entries "draft_list": draft_list_entries,
} }
await client.post(self.endpoint, payload) await client.post(f"{self.endpoint}/", payload)
logger.info(f"Removed player {player_id} from team {team_id} draft list") logger.info(f"Removed player {player_id} from team {team_id} draft list")
return True return True
@ -270,11 +273,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error removing player {player_id} from draft list: {e}") logger.error(f"Error removing player {player_id} from draft list: {e}")
return False return False
async def clear_list( async def clear_list(self, season: int, team_id: int) -> bool:
self,
season: int,
team_id: int
) -> bool:
""" """
Clear entire draft list for team. Clear entire draft list for team.
@ -309,10 +308,7 @@ class DraftListService(BaseService[DraftList]):
return False return False
async def reorder_list( async def reorder_list(
self, self, season: int, team_id: int, new_order: List[int]
season: int,
team_id: int,
new_order: List[int]
) -> bool: ) -> bool:
""" """
Reorder team's draft list. Reorder team's draft list.
@ -342,21 +338,23 @@ class DraftListService(BaseService[DraftList]):
continue continue
entry = entry_map[player_id] entry = entry_map[player_id]
draft_list_entries.append({ draft_list_entries.append(
'season': entry.season, {
'team_id': entry.team_id, "season": entry.season,
'player_id': entry.player_id, "team_id": entry.team_id,
'rank': new_rank "player_id": entry.player_id,
}) "rank": new_rank,
}
)
# POST reordered list (bulk replacement) # POST reordered list (bulk replacement)
client = await self.get_client() client = await self.get_client()
payload = { payload = {
'count': len(draft_list_entries), "count": len(draft_list_entries),
'draft_list': draft_list_entries "draft_list": draft_list_entries,
} }
await client.post(self.endpoint, payload) await client.post(f"{self.endpoint}/", payload)
logger.info(f"Reordered draft list for team {team_id}") logger.info(f"Reordered draft list for team {team_id}")
return True return True
@ -365,12 +363,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error reordering draft list for team {team_id}: {e}") logger.error(f"Error reordering draft list for team {team_id}: {e}")
return False return False
async def move_entry_up( async def move_entry_up(self, season: int, team_id: int, player_id: int) -> bool:
self,
season: int,
team_id: int,
player_id: int
) -> bool:
""" """
Move player up one position in draft list (higher priority). Move player up one position in draft list (higher priority).
@ -403,7 +396,9 @@ class DraftListService(BaseService[DraftList]):
return False return False
# Find entry above (rank - 1) # Find entry above (rank - 1)
above_entry = next((e for e in entries if e.rank == current_entry.rank - 1), None) above_entry = next(
(e for e in entries if e.rank == current_entry.rank - 1), None
)
if not above_entry: if not above_entry:
logger.error(f"Could not find entry above rank {current_entry.rank}") logger.error(f"Could not find entry above rank {current_entry.rank}")
return False return False
@ -421,24 +416,26 @@ class DraftListService(BaseService[DraftList]):
# Keep existing rank # Keep existing rank
new_rank = entry.rank new_rank = entry.rank
draft_list_entries.append({ draft_list_entries.append(
'season': entry.season, {
'team_id': entry.team_id, "season": entry.season,
'player_id': entry.player_id, "team_id": entry.team_id,
'rank': new_rank "player_id": entry.player_id,
}) "rank": new_rank,
}
)
# Sort by rank # Sort by rank
draft_list_entries.sort(key=lambda x: x['rank']) draft_list_entries.sort(key=lambda x: x["rank"])
# POST updated list (bulk replacement) # POST updated list (bulk replacement)
client = await self.get_client() client = await self.get_client()
payload = { payload = {
'count': len(draft_list_entries), "count": len(draft_list_entries),
'draft_list': draft_list_entries "draft_list": draft_list_entries,
} }
await client.post(self.endpoint, payload) await client.post(f"{self.endpoint}/", payload)
logger.info(f"Moved player {player_id} up to rank {current_entry.rank - 1}") logger.info(f"Moved player {player_id} up to rank {current_entry.rank - 1}")
return True return True
@ -447,12 +444,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error moving player {player_id} up in draft list: {e}") logger.error(f"Error moving player {player_id} up in draft list: {e}")
return False return False
async def move_entry_down( async def move_entry_down(self, season: int, team_id: int, player_id: int) -> bool:
self,
season: int,
team_id: int,
player_id: int
) -> bool:
""" """
Move player down one position in draft list (lower priority). Move player down one position in draft list (lower priority).
@ -485,7 +477,9 @@ class DraftListService(BaseService[DraftList]):
return False return False
# Find entry below (rank + 1) # Find entry below (rank + 1)
below_entry = next((e for e in entries if e.rank == current_entry.rank + 1), None) below_entry = next(
(e for e in entries if e.rank == current_entry.rank + 1), None
)
if not below_entry: if not below_entry:
logger.error(f"Could not find entry below rank {current_entry.rank}") logger.error(f"Could not find entry below rank {current_entry.rank}")
return False return False
@ -503,25 +497,29 @@ class DraftListService(BaseService[DraftList]):
# Keep existing rank # Keep existing rank
new_rank = entry.rank new_rank = entry.rank
draft_list_entries.append({ draft_list_entries.append(
'season': entry.season, {
'team_id': entry.team_id, "season": entry.season,
'player_id': entry.player_id, "team_id": entry.team_id,
'rank': new_rank "player_id": entry.player_id,
}) "rank": new_rank,
}
)
# Sort by rank # Sort by rank
draft_list_entries.sort(key=lambda x: x['rank']) draft_list_entries.sort(key=lambda x: x["rank"])
# POST updated list (bulk replacement) # POST updated list (bulk replacement)
client = await self.get_client() client = await self.get_client()
payload = { payload = {
'count': len(draft_list_entries), "count": len(draft_list_entries),
'draft_list': draft_list_entries "draft_list": draft_list_entries,
} }
await client.post(self.endpoint, payload) await client.post(f"{self.endpoint}/", payload)
logger.info(f"Moved player {player_id} down to rank {current_entry.rank + 1}") logger.info(
f"Moved player {player_id} down to rank {current_entry.rank + 1}"
)
return True return True

View File

@ -3,13 +3,14 @@ Injury service for Discord Bot v2.0
Handles injury-related operations including checking, creating, and clearing injuries. Handles injury-related operations including checking, creating, and clearing injuries.
""" """
import logging import logging
from typing import Optional, List from typing import Optional, List
from services.base_service import BaseService from services.base_service import BaseService
from models.injury import Injury from models.injury import Injury
logger = logging.getLogger(f'{__name__}.InjuryService') logger = logging.getLogger(f"{__name__}.InjuryService")
class InjuryService(BaseService[Injury]): class InjuryService(BaseService[Injury]):
@ -25,7 +26,7 @@ class InjuryService(BaseService[Injury]):
def __init__(self): def __init__(self):
"""Initialize injury service.""" """Initialize injury service."""
super().__init__(Injury, 'injuries') super().__init__(Injury, "injuries")
logger.debug("InjuryService initialized") logger.debug("InjuryService initialized")
async def get_active_injury(self, player_id: int, season: int) -> Optional[Injury]: async def get_active_injury(self, player_id: int, season: int) -> Optional[Injury]:
@ -41,25 +42,31 @@ class InjuryService(BaseService[Injury]):
""" """
try: try:
params = [ params = [
('player_id', str(player_id)), ("player_id", str(player_id)),
('season', str(season)), ("season", str(season)),
('is_active', 'true') ("is_active", "true"),
] ]
injuries = await self.get_all_items(params=params) injuries = await self.get_all_items(params=params)
if injuries: if injuries:
logger.debug(f"Found active injury for player {player_id} in season {season}") logger.debug(
f"Found active injury for player {player_id} in season {season}"
)
return injuries[0] return injuries[0]
logger.debug(f"No active injury found for player {player_id} in season {season}") logger.debug(
f"No active injury found for player {player_id} in season {season}"
)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error getting active injury for player {player_id}: {e}") logger.error(f"Error getting active injury for player {player_id}: {e}")
return None return None
async def get_injuries_by_player(self, player_id: int, season: int, active_only: bool = False) -> List[Injury]: async def get_injuries_by_player(
self, player_id: int, season: int, active_only: bool = False
) -> List[Injury]:
""" """
Get all injuries for a player in a specific season. Get all injuries for a player in a specific season.
@ -72,13 +79,10 @@ class InjuryService(BaseService[Injury]):
List of injuries for the player List of injuries for the player
""" """
try: try:
params = [ params = [("player_id", str(player_id)), ("season", str(season))]
('player_id', str(player_id)),
('season', str(season))
]
if active_only: if active_only:
params.append(('is_active', 'true')) params.append(("is_active", "true"))
injuries = await self.get_all_items(params=params) injuries = await self.get_all_items(params=params)
logger.debug(f"Retrieved {len(injuries)} injuries for player {player_id}") logger.debug(f"Retrieved {len(injuries)} injuries for player {player_id}")
@ -88,7 +92,9 @@ class InjuryService(BaseService[Injury]):
logger.error(f"Error getting injuries for player {player_id}: {e}") logger.error(f"Error getting injuries for player {player_id}: {e}")
return [] return []
async def get_injuries_by_team(self, team_id: int, season: int, active_only: bool = True) -> List[Injury]: async def get_injuries_by_team(
self, team_id: int, season: int, active_only: bool = True
) -> List[Injury]:
""" """
Get all injuries for a team in a specific season. Get all injuries for a team in a specific season.
@ -101,13 +107,10 @@ class InjuryService(BaseService[Injury]):
List of injuries for the team List of injuries for the team
""" """
try: try:
params = [ params = [("team_id", str(team_id)), ("season", str(season))]
('team_id', str(team_id)),
('season', str(season))
]
if active_only: if active_only:
params.append(('is_active', 'true')) params.append(("is_active", "true"))
injuries = await self.get_all_items(params=params) injuries = await self.get_all_items(params=params)
logger.debug(f"Retrieved {len(injuries)} injuries for team {team_id}") logger.debug(f"Retrieved {len(injuries)} injuries for team {team_id}")
@ -125,7 +128,7 @@ class InjuryService(BaseService[Injury]):
start_week: int, start_week: int,
start_game: int, start_game: int,
end_week: int, end_week: int,
end_game: int end_game: int,
) -> Optional[Injury]: ) -> Optional[Injury]:
""" """
Create a new injury record. Create a new injury record.
@ -144,22 +147,24 @@ class InjuryService(BaseService[Injury]):
""" """
try: try:
injury_data = { injury_data = {
'season': season, "season": season,
'player_id': player_id, "player_id": player_id,
'total_games': total_games, "total_games": total_games,
'start_week': start_week, "start_week": start_week,
'start_game': start_game, "start_game": start_game,
'end_week': end_week, "end_week": end_week,
'end_game': end_game, "end_game": end_game,
'is_active': True "is_active": True,
} }
# Call the API to create the injury # Call the API to create the injury
client = await self.get_client() client = await self.get_client()
response = await client.post(self.endpoint, injury_data) response = await client.post(f"{self.endpoint}/", injury_data)
if not response: if not response:
logger.error(f"Failed to create injury for player {player_id}: No response from API") logger.error(
f"Failed to create injury for player {player_id}: No response from API"
)
return None return None
# Merge the request data with the response to ensure all required fields are present # Merge the request data with the response to ensure all required fields are present
@ -187,7 +192,9 @@ class InjuryService(BaseService[Injury]):
""" """
try: try:
# Note: API expects is_active as query parameter, not JSON body # Note: API expects is_active as query parameter, not JSON body
updated_injury = await self.patch(injury_id, {'is_active': False}, use_query_params=True) updated_injury = await self.patch(
injury_id, {"is_active": False}, use_query_params=True
)
if updated_injury: if updated_injury:
logger.info(f"Cleared injury {injury_id}") logger.info(f"Cleared injury {injury_id}")
@ -216,16 +223,18 @@ class InjuryService(BaseService[Injury]):
try: try:
client = await self.get_client() client = await self.get_client()
params = [ params = [
('season', str(season)), ("season", str(season)),
('is_active', 'true'), ("is_active", "true"),
('sort', 'return-asc') ("sort", "return-asc"),
] ]
response = await client.get(self.endpoint, params=params) response = await client.get(self.endpoint, params=params)
if response and 'injuries' in response: if response and "injuries" in response:
logger.debug(f"Retrieved {len(response['injuries'])} active injuries for season {season}") logger.debug(
return response['injuries'] f"Retrieved {len(response['injuries'])} active injuries for season {season}"
)
return response["injuries"]
logger.debug(f"No active injuries found for season {season}") logger.debug(f"No active injuries found for season {season}")
return [] return []

View File

@ -3,6 +3,7 @@ Play Service
Manages play-by-play data operations for game submission. Manages play-by-play data operations for game submission.
""" """
from typing import List, Dict, Any from typing import List, Dict, Any
from utils.logging import get_contextual_logger from utils.logging import get_contextual_logger
@ -16,7 +17,7 @@ class PlayService:
def __init__(self): def __init__(self):
"""Initialize play service.""" """Initialize play service."""
self.logger = get_contextual_logger(f'{__name__}.PlayService') self.logger = get_contextual_logger(f"{__name__}.PlayService")
self._get_client = get_global_client self._get_client = get_global_client
async def get_client(self): async def get_client(self):
@ -39,8 +40,10 @@ class PlayService:
try: try:
client = await self.get_client() client = await self.get_client()
payload = {'plays': plays} payload = {"plays": plays}
response = await client.post('plays', payload) # Trailing slash required: without it, the server returns a 307 redirect
# and aiohttp drops the POST body when following the redirect
response = await client.post("plays/", payload)
self.logger.info(f"Created {len(plays)} plays") self.logger.info(f"Created {len(plays)} plays")
return True return True
@ -68,7 +71,7 @@ class PlayService:
""" """
try: try:
client = await self.get_client() client = await self.get_client()
response = await client.delete(f'plays/game/{game_id}') response = await client.delete(f"plays/game/{game_id}")
self.logger.info(f"Deleted plays for game {game_id}") self.logger.info(f"Deleted plays for game {game_id}")
return True return True
@ -77,11 +80,7 @@ class PlayService:
self.logger.error(f"Failed to delete plays for game {game_id}: {e}") self.logger.error(f"Failed to delete plays for game {game_id}: {e}")
raise APIException(f"Failed to delete plays: {e}") raise APIException(f"Failed to delete plays: {e}")
async def get_top_plays_by_wpa( async def get_top_plays_by_wpa(self, game_id: int, limit: int = 3) -> List[Play]:
self,
game_id: int,
limit: int = 3
) -> List[Play]:
""" """
Get top plays by WPA (absolute value) for key plays display. Get top plays by WPA (absolute value) for key plays display.
@ -95,19 +94,15 @@ class PlayService:
try: try:
client = await self.get_client() client = await self.get_client()
params = [ params = [("game_id", game_id), ("sort", "wpa-desc"), ("limit", limit)]
('game_id', game_id),
('sort', 'wpa-desc'),
('limit', limit)
]
response = await client.get('plays', params=params) response = await client.get("plays", params=params)
if not response or 'plays' not in response: if not response or "plays" not in response:
self.logger.info(f'No plays found for game ID {game_id}') self.logger.info(f"No plays found for game ID {game_id}")
return [] return []
plays = [Play.from_api_data(p) for p in response['plays']] plays = [Play.from_api_data(p) for p in response["plays"]]
self.logger.debug(f"Retrieved {len(plays)} top plays for game {game_id}") self.logger.debug(f"Retrieved {len(plays)} top plays for game {game_id}")
return plays return plays
@ -129,11 +124,11 @@ class PlayService:
error_str = str(error) error_str = str(error)
# Common error patterns # Common error patterns
if 'Player ID' in error_str and 'not found' in error_str: if "Player ID" in error_str and "not found" in error_str:
return "Invalid player ID in scorecard data. Please check player IDs." return "Invalid player ID in scorecard data. Please check player IDs."
elif 'Game ID' in error_str and 'not found' in error_str: elif "Game ID" in error_str and "not found" in error_str:
return "Game not found in database. Please contact an admin." return "Game not found in database. Please contact an admin."
elif 'validation' in error_str.lower(): elif "validation" in error_str.lower():
return f"Data validation error: {error_str}" return f"Data validation error: {error_str}"
else: else:
return f"Error submitting plays: {error_str}" return f"Error submitting plays: {error_str}"

View File

@ -416,6 +416,8 @@ class SheetsService:
self.logger.info(f"Read {len(pit_data)} valid pitching decisions") self.logger.info(f"Read {len(pit_data)} valid pitching decisions")
return pit_data return pit_data
except SheetsException:
raise
except Exception as e: except Exception as e:
self.logger.error(f"Failed to read pitching decisions: {e}") self.logger.error(f"Failed to read pitching decisions: {e}")
raise SheetsException("Unable to read pitching decisions") from e raise SheetsException("Unable to read pitching decisions") from e
@ -458,6 +460,8 @@ class SheetsService:
"home": [int(x) for x in score_table[1]], # [R, H, E] "home": [int(x) for x in score_table[1]], # [R, H, E]
} }
except SheetsException:
raise
except Exception as e: except Exception as e:
self.logger.error(f"Failed to read box score: {e}") self.logger.error(f"Failed to read box score: {e}")
raise SheetsException("Unable to read box score") from e raise SheetsException("Unable to read box score") from e

View File

@ -3,6 +3,7 @@ Trade Builder Service
Extends the TransactionBuilder to support multi-team trades and player exchanges. Extends the TransactionBuilder to support multi-team trades and player exchanges.
""" """
import logging import logging
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional, Set
from datetime import datetime, timezone from datetime import datetime, timezone
@ -12,10 +13,14 @@ from config import get_config
from models.trade import Trade, TradeMove, TradeStatus from models.trade import Trade, TradeMove, TradeStatus
from models.team import Team, RosterType from models.team import Team, RosterType
from models.player import Player from models.player import Player
from services.transaction_builder import TransactionBuilder, RosterValidationResult, TransactionMove from services.transaction_builder import (
TransactionBuilder,
RosterValidationResult,
TransactionMove,
)
from services.team_service import team_service from services.team_service import team_service
logger = logging.getLogger(f'{__name__}.TradeBuilder') logger = logging.getLogger(f"{__name__}.TradeBuilder")
class TradeValidationResult: class TradeValidationResult:
@ -52,7 +57,9 @@ class TradeValidationResult:
suggestions.extend(validation.suggestions) suggestions.extend(validation.suggestions)
return suggestions return suggestions
def get_participant_validation(self, team_id: int) -> Optional[RosterValidationResult]: def get_participant_validation(
self, team_id: int
) -> Optional[RosterValidationResult]:
"""Get validation result for a specific team.""" """Get validation result for a specific team."""
return self.participant_validations.get(team_id) return self.participant_validations.get(team_id)
@ -64,7 +71,12 @@ class TradeBuilder:
Extends the functionality of TransactionBuilder to support trades between teams. Extends the functionality of TransactionBuilder to support trades between teams.
""" """
def __init__(self, initiated_by: int, initiating_team: Team, season: int = get_config().sba_season): def __init__(
self,
initiated_by: int,
initiating_team: Team,
season: int = get_config().sba_season,
):
""" """
Initialize trade builder. Initialize trade builder.
@ -79,7 +91,7 @@ class TradeBuilder:
status=TradeStatus.DRAFT, status=TradeStatus.DRAFT,
initiated_by=initiated_by, initiated_by=initiated_by,
created_at=datetime.now(timezone.utc).isoformat(), created_at=datetime.now(timezone.utc).isoformat(),
season=season season=season,
) )
# Add the initiating team as first participant # Add the initiating team as first participant
@ -91,7 +103,9 @@ class TradeBuilder:
# Track which teams have accepted the trade (team_id -> True) # Track which teams have accepted the trade (team_id -> True)
self.accepted_teams: Set[int] = set() self.accepted_teams: Set[int] = set()
logger.info(f"TradeBuilder initialized: {self.trade.trade_id} by user {initiated_by} for {initiating_team.abbrev}") logger.info(
f"TradeBuilder initialized: {self.trade.trade_id} by user {initiated_by} for {initiating_team.abbrev}"
)
@property @property
def trade_id(self) -> str: def trade_id(self) -> str:
@ -127,7 +141,11 @@ class TradeBuilder:
@property @property
def pending_teams(self) -> List[Team]: def pending_teams(self) -> List[Team]:
"""Get list of teams that haven't accepted yet.""" """Get list of teams that haven't accepted yet."""
return [team for team in self.participating_teams if team.id not in self.accepted_teams] return [
team
for team in self.participating_teams
if team.id not in self.accepted_teams
]
def accept_trade(self, team_id: int) -> bool: def accept_trade(self, team_id: int) -> bool:
""" """
@ -140,7 +158,9 @@ class TradeBuilder:
True if all teams have now accepted, False otherwise True if all teams have now accepted, False otherwise
""" """
self.accepted_teams.add(team_id) self.accepted_teams.add(team_id)
logger.info(f"Team {team_id} accepted trade {self.trade_id}. Accepted: {len(self.accepted_teams)}/{self.team_count}") logger.info(
f"Team {team_id} accepted trade {self.trade_id}. Accepted: {len(self.accepted_teams)}/{self.team_count}"
)
return self.all_teams_accepted return self.all_teams_accepted
def reject_trade(self) -> None: def reject_trade(self) -> None:
@ -160,7 +180,9 @@ class TradeBuilder:
Returns: Returns:
Dict mapping team_id to acceptance status (True/False) Dict mapping team_id to acceptance status (True/False)
""" """
return {team.id: team.id in self.accepted_teams for team in self.participating_teams} return {
team.id: team.id in self.accepted_teams for team in self.participating_teams
}
def has_team_accepted(self, team_id: int) -> bool: def has_team_accepted(self, team_id: int) -> bool:
"""Check if a specific team has accepted.""" """Check if a specific team has accepted."""
@ -184,7 +206,9 @@ class TradeBuilder:
participant = self.trade.add_participant(team) participant = self.trade.add_participant(team)
# Create transaction builder for this team # Create transaction builder for this team
self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season) self._team_builders[team.id] = TransactionBuilder(
team, self.trade.initiated_by, self.trade.season
)
# Register team in secondary index for multi-GM access # Register team in secondary index for multi-GM access
trade_key = f"{self.trade.initiated_by}:trade" trade_key = f"{self.trade.initiated_by}:trade"
@ -209,7 +233,10 @@ class TradeBuilder:
# Check if team has moves - prevent removal if they do # Check if team has moves - prevent removal if they do
if participant.all_moves: if participant.all_moves:
return False, f"{participant.team.abbrev} has moves in this trade and cannot be removed" return (
False,
f"{participant.team.abbrev} has moves in this trade and cannot be removed",
)
# Remove team # Remove team
removed = self.trade.remove_participant(team_id) removed = self.trade.remove_participant(team_id)
@ -229,7 +256,7 @@ class TradeBuilder:
from_team: Team, from_team: Team,
to_team: Team, to_team: Team,
from_roster: RosterType, from_roster: RosterType,
to_roster: RosterType to_roster: RosterType,
) -> tuple[bool, str]: ) -> tuple[bool, str]:
""" """
Add a player move to the trade. Add a player move to the trade.
@ -246,7 +273,10 @@ class TradeBuilder:
""" """
# Validate player is not from Free Agency # Validate player is not from Free Agency
if player.team_id == get_config().free_agent_team_id: if player.team_id == get_config().free_agent_team_id:
return False, f"Cannot add {player.name} from Free Agency. Players must be traded from teams within the organizations involved in the trade." return (
False,
f"Cannot add {player.name} from Free Agency. Players must be traded from teams within the organizations involved in the trade.",
)
# Validate player has a valid team assignment # Validate player has a valid team assignment
if not player.team_id: if not player.team_id:
@ -259,7 +289,10 @@ class TradeBuilder:
# Check if player's team is in the same organization as from_team # Check if player's team is in the same organization as from_team
if not player_team.is_same_organization(from_team): if not player_team.is_same_organization(from_team):
return False, f"{player.name} is on {player_team.abbrev}, they are not eligible to be added to the trade." return (
False,
f"{player.name} is on {player_team.abbrev}, they are not eligible to be added to the trade.",
)
# Ensure both teams are participating (check by organization for ML authority) # Ensure both teams are participating (check by organization for ML authority)
from_participant = self.trade.get_participant_by_organization(from_team) from_participant = self.trade.get_participant_by_organization(from_team)
@ -274,7 +307,10 @@ class TradeBuilder:
for participant in self.trade.participants: for participant in self.trade.participants:
for existing_move in participant.all_moves: for existing_move in participant.all_moves:
if existing_move.player.id == player.id: if existing_move.player.id == player.id:
return False, f"{player.name} is already involved in a move in this trade" return (
False,
f"{player.name} is already involved in a move in this trade",
)
# Create trade move # Create trade move
trade_move = TradeMove( trade_move = TradeMove(
@ -284,7 +320,7 @@ class TradeBuilder:
from_team=from_team, from_team=from_team,
to_team=to_team, to_team=to_team,
source_team=from_team, source_team=from_team,
destination_team=to_team destination_team=to_team,
) )
# Add to giving team's moves # Add to giving team's moves
@ -303,7 +339,7 @@ class TradeBuilder:
from_roster=from_roster, from_roster=from_roster,
to_roster=RosterType.FREE_AGENCY, # Conceptually leaving the org to_roster=RosterType.FREE_AGENCY, # Conceptually leaving the org
from_team=from_team, from_team=from_team,
to_team=None to_team=None,
) )
# Move for receiving team (player joining) # Move for receiving team (player joining)
@ -312,19 +348,23 @@ class TradeBuilder:
from_roster=RosterType.FREE_AGENCY, # Conceptually joining from outside from_roster=RosterType.FREE_AGENCY, # Conceptually joining from outside
to_roster=to_roster, to_roster=to_roster,
from_team=None, from_team=None,
to_team=to_team to_team=to_team,
) )
# Add moves to respective builders # Add moves to respective builders
# Skip pending transaction check for trades - they have their own validation workflow # Skip pending transaction check for trades - they have their own validation workflow
from_success, from_error = await from_builder.add_move(from_move, check_pending_transactions=False) from_success, from_error = await from_builder.add_move(
from_move, check_pending_transactions=False
)
if not from_success: if not from_success:
# Remove from trade if builder failed # Remove from trade if builder failed
from_participant.moves_giving.remove(trade_move) from_participant.moves_giving.remove(trade_move)
to_participant.moves_receiving.remove(trade_move) to_participant.moves_receiving.remove(trade_move)
return False, f"Error adding move to {from_team.abbrev}: {from_error}" return False, f"Error adding move to {from_team.abbrev}: {from_error}"
to_success, to_error = await to_builder.add_move(to_move, check_pending_transactions=False) to_success, to_error = await to_builder.add_move(
to_move, check_pending_transactions=False
)
if not to_success: if not to_success:
# Rollback both if second failed # Rollback both if second failed
from_builder.remove_move(player.id) from_builder.remove_move(player.id)
@ -332,15 +372,13 @@ class TradeBuilder:
to_participant.moves_receiving.remove(trade_move) to_participant.moves_receiving.remove(trade_move)
return False, f"Error adding move to {to_team.abbrev}: {to_error}" return False, f"Error adding move to {to_team.abbrev}: {to_error}"
logger.info(f"Added player move to trade {self.trade_id}: {trade_move.description}") logger.info(
f"Added player move to trade {self.trade_id}: {trade_move.description}"
)
return True, "" return True, ""
async def add_supplementary_move( async def add_supplementary_move(
self, self, team: Team, player: Player, from_roster: RosterType, to_roster: RosterType
team: Team,
player: Player,
from_roster: RosterType,
to_roster: RosterType
) -> tuple[bool, str]: ) -> tuple[bool, str]:
""" """
Add a supplementary move (internal organizational move) for roster legality. Add a supplementary move (internal organizational move) for roster legality.
@ -366,7 +404,7 @@ class TradeBuilder:
from_team=team, from_team=team,
to_team=team, to_team=team,
source_team=team, source_team=team,
destination_team=team destination_team=team,
) )
# Add to participant's supplementary moves # Add to participant's supplementary moves
@ -379,16 +417,20 @@ class TradeBuilder:
from_roster=from_roster, from_roster=from_roster,
to_roster=to_roster, to_roster=to_roster,
from_team=team, from_team=team,
to_team=team to_team=team,
) )
# Skip pending transaction check for trade supplementary moves # Skip pending transaction check for trade supplementary moves
success, error = await builder.add_move(trans_move, check_pending_transactions=False) success, error = await builder.add_move(
trans_move, check_pending_transactions=False
)
if not success: if not success:
participant.supplementary_moves.remove(supp_move) participant.supplementary_moves.remove(supp_move)
return False, error return False, error
logger.info(f"Added supplementary move for {team.abbrev}: {supp_move.description}") logger.info(
f"Added supplementary move for {team.abbrev}: {supp_move.description}"
)
return True, "" return True, ""
async def remove_move(self, player_id: int) -> tuple[bool, str]: async def remove_move(self, player_id: int) -> tuple[bool, str]:
@ -432,21 +474,41 @@ class TradeBuilder:
for builder in self._team_builders.values(): for builder in self._team_builders.values():
builder.remove_move(player_id) builder.remove_move(player_id)
logger.info(f"Removed move from trade {self.trade_id}: {removed_move.description}") logger.info(
f"Removed move from trade {self.trade_id}: {removed_move.description}"
)
return True, "" return True, ""
async def validate_trade(self, next_week: Optional[int] = None) -> TradeValidationResult: async def validate_trade(
self, next_week: Optional[int] = None
) -> TradeValidationResult:
""" """
Validate the entire trade including all teams' roster legality. Validate the entire trade including all teams' roster legality.
Validates against next week's projected roster (current roster + pending
transactions), matching the behavior of /dropadd validation.
Args: Args:
next_week: Week to validate for (optional) next_week: Week to validate for (auto-fetched if not provided)
Returns: Returns:
TradeValidationResult with comprehensive validation TradeValidationResult with comprehensive validation
""" """
result = TradeValidationResult() result = TradeValidationResult()
# Auto-fetch next week so validation includes pending transactions
if next_week is None:
try:
from services.league_service import league_service
current_state = await league_service.get_current_state()
next_week = (current_state.week + 1) if current_state else 1
except Exception as e:
logger.warning(
f"Could not determine next week for trade validation: {e}"
)
next_week = None
# Validate trade structure # Validate trade structure
is_balanced, balance_errors = self.trade.validate_trade_balance() is_balanced, balance_errors = self.trade.validate_trade_balance()
if not is_balanced: if not is_balanced:
@ -472,13 +534,17 @@ class TradeBuilder:
if self.team_count < 2: if self.team_count < 2:
result.trade_suggestions.append("Add another team to create a trade") result.trade_suggestions.append("Add another team to create a trade")
logger.debug(f"Trade validation for {self.trade_id}: Legal={result.is_legal}, Errors={len(result.all_errors)}") logger.debug(
f"Trade validation for {self.trade_id}: Legal={result.is_legal}, Errors={len(result.all_errors)}"
)
return result return result
def _get_or_create_builder(self, team: Team) -> TransactionBuilder: def _get_or_create_builder(self, team: Team) -> TransactionBuilder:
"""Get or create a transaction builder for a team.""" """Get or create a transaction builder for a team."""
if team.id not in self._team_builders: if team.id not in self._team_builders:
self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season) self._team_builders[team.id] = TransactionBuilder(
team, self.trade.initiated_by, self.trade.season
)
return self._team_builders[team.id] return self._team_builders[team.id]
def clear_trade(self) -> None: def clear_trade(self) -> None:
@ -592,4 +658,4 @@ def clear_trade_builder_by_team(team_id: int) -> bool:
def get_active_trades() -> Dict[str, TradeBuilder]: def get_active_trades() -> Dict[str, TradeBuilder]:
"""Get all active trade builders (for debugging/admin purposes).""" """Get all active trade builders (for debugging/admin purposes)."""
return _active_trade_builders.copy() return _active_trade_builders.copy()

View File

@ -248,7 +248,7 @@ class TransactionService(BaseService[Transaction]):
# POST batch to API # POST batch to API
client = await self.get_client() client = await self.get_client()
response = await client.post(self.endpoint, data=batch_data) response = await client.post(f"{self.endpoint}/", data=batch_data)
# API returns a string like "2 transactions have been added" # API returns a string like "2 transactions have been added"
# We need to return the original Transaction objects (they won't have IDs assigned by API) # We need to return the original Transaction objects (they won't have IDs assigned by API)

View File

@ -1,18 +1,24 @@
""" """
API client tests using aioresponses for clean HTTP mocking API client tests using aioresponses for clean HTTP mocking
""" """
import pytest import pytest
import asyncio import asyncio
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from aioresponses import aioresponses from aioresponses import aioresponses
from api.client import APIClient, get_api_client, get_global_client, cleanup_global_client from api.client import (
APIClient,
get_api_client,
get_global_client,
cleanup_global_client,
)
from exceptions import APIException from exceptions import APIException
class TestAPIClientWithAioresponses: class TestAPIClientWithAioresponses:
"""Test API client with aioresponses for HTTP mocking.""" """Test API client with aioresponses for HTTP mocking."""
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Mock configuration for testing.""" """Mock configuration for testing."""
@ -20,66 +26,57 @@ class TestAPIClientWithAioresponses:
config.db_url = "https://api.example.com" config.db_url = "https://api.example.com"
config.api_token = "test-token" config.api_token = "test-token"
return config return config
@pytest.fixture @pytest.fixture
def api_client(self, mock_config): def api_client(self, mock_config):
"""Create API client with mocked config.""" """Create API client with mocked config."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
return APIClient() return APIClient()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_request_success(self, api_client): async def test_get_request_success(self, api_client):
"""Test successful GET request.""" """Test successful GET request."""
expected_data = {"id": 1, "name": "Test Player"} expected_data = {"id": 1, "name": "Test Player"}
with aioresponses() as m: with aioresponses() as m:
m.get( m.get(
"https://api.example.com/v3/players/1", "https://api.example.com/v3/players/1",
payload=expected_data, payload=expected_data,
status=200 status=200,
) )
result = await api_client.get("players", object_id=1) result = await api_client.get("players", object_id=1)
assert result == expected_data assert result == expected_data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_request_404(self, api_client): async def test_get_request_404(self, api_client):
"""Test GET request returning 404.""" """Test GET request returning 404."""
with aioresponses() as m: with aioresponses() as m:
m.get( m.get("https://api.example.com/v3/players/999", status=404)
"https://api.example.com/v3/players/999",
status=404
)
result = await api_client.get("players", object_id=999) result = await api_client.get("players", object_id=999)
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_request_401_auth_error(self, api_client): async def test_get_request_401_auth_error(self, api_client):
"""Test GET request with authentication error.""" """Test GET request with authentication error."""
with aioresponses() as m: with aioresponses() as m:
m.get( m.get("https://api.example.com/v3/players", status=401)
"https://api.example.com/v3/players",
status=401
)
with pytest.raises(APIException, match="Authentication failed"): with pytest.raises(APIException, match="Authentication failed"):
await api_client.get("players") await api_client.get("players")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_request_403_forbidden(self, api_client): async def test_get_request_403_forbidden(self, api_client):
"""Test GET request with forbidden error.""" """Test GET request with forbidden error."""
with aioresponses() as m: with aioresponses() as m:
m.get( m.get("https://api.example.com/v3/players", status=403)
"https://api.example.com/v3/players",
status=403
)
with pytest.raises(APIException, match="Access forbidden"): with pytest.raises(APIException, match="Access forbidden"):
await api_client.get("players") await api_client.get("players")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_request_500_server_error(self, api_client): async def test_get_request_500_server_error(self, api_client):
"""Test GET request with server error.""" """Test GET request with server error."""
@ -87,135 +84,127 @@ class TestAPIClientWithAioresponses:
m.get( m.get(
"https://api.example.com/v3/players", "https://api.example.com/v3/players",
status=500, status=500,
body="Internal Server Error" body="Internal Server Error",
) )
with pytest.raises(APIException, match="API request failed with status 500"): with pytest.raises(
APIException, match="API request failed with status 500"
):
await api_client.get("players") await api_client.get("players")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_request_with_params(self, api_client): async def test_get_request_with_params(self, api_client):
"""Test GET request with query parameters.""" """Test GET request with query parameters."""
expected_data = {"count": 2, "players": [{"id": 1}, {"id": 2}]} expected_data = {"count": 2, "players": [{"id": 1}, {"id": 2}]}
with aioresponses() as m: with aioresponses() as m:
m.get( m.get(
"https://api.example.com/v3/players?team_id=5&season=12", "https://api.example.com/v3/players?team_id=5&season=12",
payload=expected_data, payload=expected_data,
status=200 status=200,
) )
result = await api_client.get("players", params=[("team_id", "5"), ("season", "12")]) result = await api_client.get(
"players", params=[("team_id", "5"), ("season", "12")]
)
assert result == expected_data assert result == expected_data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_post_request_success(self, api_client): async def test_post_request_success(self, api_client):
"""Test successful POST request.""" """Test successful POST request."""
input_data = {"name": "New Player", "position": "C"} input_data = {"name": "New Player", "position": "C"}
expected_response = {"id": 1, "name": "New Player", "position": "C"} expected_response = {"id": 1, "name": "New Player", "position": "C"}
with aioresponses() as m: with aioresponses() as m:
m.post( m.post(
"https://api.example.com/v3/players", "https://api.example.com/v3/players",
payload=expected_response, payload=expected_response,
status=201 status=201,
) )
result = await api_client.post("players", input_data) result = await api_client.post("players", input_data)
assert result == expected_response assert result == expected_response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_post_request_400_error(self, api_client): async def test_post_request_400_error(self, api_client):
"""Test POST request with validation error.""" """Test POST request with validation error."""
input_data = {"invalid": "data"} input_data = {"invalid": "data"}
with aioresponses() as m: with aioresponses() as m:
m.post( m.post(
"https://api.example.com/v3/players", "https://api.example.com/v3/players", status=400, body="Invalid data"
status=400,
body="Invalid data"
) )
with pytest.raises(APIException, match="POST request failed with status 400"): with pytest.raises(
APIException, match="POST request failed with status 400"
):
await api_client.post("players", input_data) await api_client.post("players", input_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_put_request_success(self, api_client): async def test_put_request_success(self, api_client):
"""Test successful PUT request.""" """Test successful PUT request."""
update_data = {"name": "Updated Player"} update_data = {"name": "Updated Player"}
expected_response = {"id": 1, "name": "Updated Player"} expected_response = {"id": 1, "name": "Updated Player"}
with aioresponses() as m: with aioresponses() as m:
m.put( m.put(
"https://api.example.com/v3/players/1", "https://api.example.com/v3/players/1",
payload=expected_response, payload=expected_response,
status=200 status=200,
) )
result = await api_client.put("players", update_data, object_id=1) result = await api_client.put("players", update_data, object_id=1)
assert result == expected_response assert result == expected_response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_put_request_404(self, api_client): async def test_put_request_404(self, api_client):
"""Test PUT request with 404.""" """Test PUT request with 404."""
update_data = {"name": "Updated Player"} update_data = {"name": "Updated Player"}
with aioresponses() as m: with aioresponses() as m:
m.put( m.put("https://api.example.com/v3/players/999", status=404)
"https://api.example.com/v3/players/999",
status=404
)
result = await api_client.put("players", update_data, object_id=999) result = await api_client.put("players", update_data, object_id=999)
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_request_success(self, api_client): async def test_delete_request_success(self, api_client):
"""Test successful DELETE request.""" """Test successful DELETE request."""
with aioresponses() as m: with aioresponses() as m:
m.delete( m.delete("https://api.example.com/v3/players/1", status=204)
"https://api.example.com/v3/players/1",
status=204
)
result = await api_client.delete("players", object_id=1) result = await api_client.delete("players", object_id=1)
assert result is True assert result is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_request_404(self, api_client): async def test_delete_request_404(self, api_client):
"""Test DELETE request with 404.""" """Test DELETE request with 404."""
with aioresponses() as m: with aioresponses() as m:
m.delete( m.delete("https://api.example.com/v3/players/999", status=404)
"https://api.example.com/v3/players/999",
status=404
)
result = await api_client.delete("players", object_id=999) result = await api_client.delete("players", object_id=999)
assert result is False assert result is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_request_200_success(self, api_client): async def test_delete_request_200_success(self, api_client):
"""Test DELETE request with 200 success.""" """Test DELETE request with 200 success."""
with aioresponses() as m: with aioresponses() as m:
m.delete( m.delete("https://api.example.com/v3/players/1", status=200)
"https://api.example.com/v3/players/1",
status=200
)
result = await api_client.delete("players", object_id=1) result = await api_client.delete("players", object_id=1)
assert result is True assert result is True
class TestAPIClientHelpers: class TestAPIClientHelpers:
"""Test API client helper functions.""" """Test API client helper functions."""
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Mock configuration for testing.""" """Mock configuration for testing."""
@ -223,49 +212,49 @@ class TestAPIClientHelpers:
config.db_url = "https://api.example.com" config.db_url = "https://api.example.com"
config.api_token = "test-token" config.api_token = "test-token"
return config return config
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_api_client_context_manager(self, mock_config): async def test_get_api_client_context_manager(self, mock_config):
"""Test get_api_client context manager.""" """Test get_api_client context manager."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m: with aioresponses() as m:
m.get( m.get(
"https://api.example.com/v3/test", "https://api.example.com/v3/test",
payload={"success": True}, payload={"success": True},
status=200 status=200,
) )
async with get_api_client() as client: async with get_api_client() as client:
assert isinstance(client, APIClient) assert isinstance(client, APIClient)
result = await client.get("test") result = await client.get("test")
assert result == {"success": True} assert result == {"success": True}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_global_client_management(self, mock_config): async def test_global_client_management(self, mock_config):
"""Test global client getter and cleanup.""" """Test global client getter and cleanup."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
# Get global client # Get global client
client1 = await get_global_client() client1 = await get_global_client()
client2 = await get_global_client() client2 = await get_global_client()
# Should return same instance # Should return same instance
assert client1 is client2 assert client1 is client2
assert isinstance(client1, APIClient) assert isinstance(client1, APIClient)
# Test cleanup # Test cleanup
await cleanup_global_client() await cleanup_global_client()
# New client should be different instance # New client should be different instance
client3 = await get_global_client() client3 = await get_global_client()
assert client3 is not client1 assert client3 is not client1
# Clean up for other tests # Clean up for other tests
await cleanup_global_client() await cleanup_global_client()
class TestIntegrationScenarios: class TestIntegrationScenarios:
"""Test realistic integration scenarios.""" """Test realistic integration scenarios."""
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Mock configuration for testing.""" """Mock configuration for testing."""
@ -273,11 +262,11 @@ class TestIntegrationScenarios:
config.db_url = "https://api.example.com" config.db_url = "https://api.example.com"
config.api_token = "test-token" config.api_token = "test-token"
return config return config
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_player_retrieval_with_team_lookup(self, mock_config): async def test_player_retrieval_with_team_lookup(self, mock_config):
"""Test realistic scenario: get player with team data.""" """Test realistic scenario: get player with team data."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m: with aioresponses() as m:
# Mock player data response # Mock player data response
player_data = { player_data = {
@ -287,43 +276,41 @@ class TestIntegrationScenarios:
"season": 12, "season": 12,
"team_id": 5, "team_id": 5,
"image": "https://example.com/player1.jpg", "image": "https://example.com/player1.jpg",
"pos_1": "C" "pos_1": "C",
} }
m.get( m.get(
"https://api.example.com/v3/players/1", "https://api.example.com/v3/players/1",
payload=player_data, payload=player_data,
status=200 status=200,
) )
# Mock team data response # Mock team data response
team_data = { team_data = {
"id": 5, "id": 5,
"abbrev": "TST", "abbrev": "TST",
"sname": "Test Team", "sname": "Test Team",
"lname": "Test Team Full Name", "lname": "Test Team Full Name",
"season": 12 "season": 12,
} }
m.get( m.get(
"https://api.example.com/v3/teams/5", "https://api.example.com/v3/teams/5", payload=team_data, status=200
payload=team_data,
status=200
) )
client = APIClient() client = APIClient()
# Get player # Get player
player = await client.get("players", object_id=1) player = await client.get("players", object_id=1)
assert player["name"] == "Test Player" assert player["name"] == "Test Player"
assert player["team_id"] == 5 assert player["team_id"] == 5
# Get team for player # Get team for player
team = await client.get("teams", object_id=player["team_id"]) team = await client.get("teams", object_id=player["team_id"])
assert team["sname"] == "Test Team" assert team["sname"] == "Test Team"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_response_format_handling(self, mock_config): async def test_api_response_format_handling(self, mock_config):
"""Test handling of the API's count + list format.""" """Test handling of the API's count + list format."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m: with aioresponses() as m:
# Mock API response with count format # Mock API response with count format
api_response = { api_response = {
@ -336,7 +323,7 @@ class TestIntegrationScenarios:
"season": 12, "season": 12,
"team_id": 5, "team_id": 5,
"image": "https://example.com/player1.jpg", "image": "https://example.com/player1.jpg",
"pos_1": "C" "pos_1": "C",
}, },
{ {
"id": 2, "id": 2,
@ -345,93 +332,93 @@ class TestIntegrationScenarios:
"season": 12, "season": 12,
"team_id": 6, "team_id": 6,
"image": "https://example.com/player2.jpg", "image": "https://example.com/player2.jpg",
"pos_1": "1B" "pos_1": "1B",
} },
] ],
} }
m.get( m.get(
"https://api.example.com/v3/players?team_id=5", "https://api.example.com/v3/players?team_id=5",
payload=api_response, payload=api_response,
status=200 status=200,
) )
client = APIClient() client = APIClient()
result = await client.get("players", params=[("team_id", "5")]) result = await client.get("players", params=[("team_id", "5")])
assert result["count"] == 25 assert result["count"] == 25
assert len(result["players"]) == 2 assert len(result["players"]) == 2
assert result["players"][0]["name"] == "Player 1" assert result["players"][0]["name"] == "Player 1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_error_recovery_scenarios(self, mock_config): async def test_error_recovery_scenarios(self, mock_config):
"""Test error handling and recovery.""" """Test error handling and recovery."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m: with aioresponses() as m:
# First request fails with 500 # First request fails with 500
m.get( m.get(
"https://api.example.com/v3/players/1", "https://api.example.com/v3/players/1",
status=500, status=500,
body="Internal Server Error" body="Internal Server Error",
) )
# Second request succeeds # Second request succeeds
m.get( m.get(
"https://api.example.com/v3/players/2", "https://api.example.com/v3/players/2",
payload={"id": 2, "name": "Working Player"}, payload={"id": 2, "name": "Working Player"},
status=200 status=200,
) )
client = APIClient() client = APIClient()
# First request should raise exception # First request should raise exception
with pytest.raises(APIException, match="API request failed"): with pytest.raises(APIException, match="API request failed"):
await client.get("players", object_id=1) await client.get("players", object_id=1)
# Second request should work fine # Second request should work fine
result = await client.get("players", object_id=2) result = await client.get("players", object_id=2)
assert result["name"] == "Working Player" assert result["name"] == "Working Player"
# Client should still be functional # Client should still be functional
await client.close() await client.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_concurrent_requests(self, mock_config): async def test_concurrent_requests(self, mock_config):
"""Test multiple concurrent requests.""" """Test multiple concurrent requests."""
import asyncio import asyncio
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m: with aioresponses() as m:
# Mock multiple endpoints # Mock multiple endpoints
for i in range(1, 4): for i in range(1, 4):
m.get( m.get(
f"https://api.example.com/v3/players/{i}", f"https://api.example.com/v3/players/{i}",
payload={"id": i, "name": f"Player {i}"}, payload={"id": i, "name": f"Player {i}"},
status=200 status=200,
) )
client = APIClient() client = APIClient()
# Make concurrent requests # Make concurrent requests
tasks = [ tasks = [
client.get("players", object_id=1), client.get("players", object_id=1),
client.get("players", object_id=2), client.get("players", object_id=2),
client.get("players", object_id=3) client.get("players", object_id=3),
] ]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
assert len(results) == 3 assert len(results) == 3
assert results[0]["name"] == "Player 1" assert results[0]["name"] == "Player 1"
assert results[1]["name"] == "Player 2" assert results[1]["name"] == "Player 2"
assert results[2]["name"] == "Player 3" assert results[2]["name"] == "Player 3"
await client.close() await client.close()
class TestAPIClientCoverageExtras: class TestAPIClientCoverageExtras:
"""Additional coverage tests for API client edge cases.""" """Additional coverage tests for API client edge cases."""
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Mock configuration for testing.""" """Mock configuration for testing."""
@ -439,98 +426,104 @@ class TestAPIClientCoverageExtras:
config.db_url = "https://api.example.com" config.db_url = "https://api.example.com"
config.api_token = "test-token" config.api_token = "test-token"
return config return config
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_global_client_cleanup_when_none(self): async def test_global_client_cleanup_when_none(self):
"""Test cleanup when no global client exists.""" """Test cleanup when no global client exists."""
# Ensure no global client exists # Ensure no global client exists
await cleanup_global_client() await cleanup_global_client()
# Should not raise error # Should not raise error
await cleanup_global_client() await cleanup_global_client()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_url_building_edge_cases(self, mock_config): async def test_url_building_edge_cases(self, mock_config):
"""Test URL building with various edge cases.""" """Test URL building with various edge cases."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
client = APIClient() client = APIClient()
# Test trailing slash handling # Test trailing slash handling
client.base_url = "https://api.example.com/" client.base_url = "https://api.example.com/"
url = client._build_url("players") url = client._build_url("players")
assert url == "https://api.example.com/v3/players" assert url == "https://api.example.com/v3/players"
assert "//" not in url.replace("https://", "") assert "//" not in url.replace("https://", "")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parameter_handling_edge_cases(self, mock_config): async def test_parameter_handling_edge_cases(self, mock_config):
"""Test parameter handling with various scenarios.""" """Test parameter handling with various scenarios."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
client = APIClient() client = APIClient()
# Test with existing query string # Test with existing query string
url = client._add_params("https://example.com/api?existing=true", [("new", "param")]) url = client._add_params(
"https://example.com/api?existing=true", [("new", "param")]
)
assert url == "https://example.com/api?existing=true&new=param" assert url == "https://example.com/api?existing=true&new=param"
# Test with no parameters # Test with no parameters
url = client._add_params("https://example.com/api") url = client._add_params("https://example.com/api")
assert url == "https://example.com/api" assert url == "https://example.com/api"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_timeout_error_handling(self, mock_config): async def test_timeout_error_handling(self, mock_config):
"""Test timeout error handling using aioresponses.""" """Test timeout error handling using aioresponses."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
client = APIClient() client = APIClient()
# Test timeout using aioresponses exception parameter # Test timeout using aioresponses exception parameter
with aioresponses() as m: with aioresponses() as m:
m.get( m.get(
"https://api.example.com/v3/players", "https://api.example.com/v3/players",
exception=asyncio.TimeoutError("Request timed out") exception=asyncio.TimeoutError("Request timed out"),
) )
with pytest.raises(APIException, match="API call failed.*Request timed out"): with pytest.raises(
APIException, match="API call failed.*Request timed out"
):
await client.get("players") await client.get("players")
await client.close() await client.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generic_exception_handling(self, mock_config): async def test_generic_exception_handling(self, mock_config):
"""Test generic exception handling.""" """Test generic exception handling."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
client = APIClient() client = APIClient()
# Test generic exception # Test generic exception
with aioresponses() as m: with aioresponses() as m:
m.get( m.get(
"https://api.example.com/v3/players", "https://api.example.com/v3/players",
exception=Exception("Generic error") exception=Exception("Generic error"),
) )
with pytest.raises(APIException, match="API call failed.*Generic error"): with pytest.raises(
APIException, match="API call failed.*Generic error"
):
await client.get("players") await client.get("players")
await client.close() await client.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_session_closed_handling(self, mock_config): async def test_session_closed_handling(self, mock_config):
"""Test handling of closed session.""" """Test handling of closed session."""
with patch('api.client.get_config', return_value=mock_config): with patch("api.client.get_config", return_value=mock_config):
# Test that the client recreates session when needed # Test that the client recreates session when needed
with aioresponses() as m: with aioresponses() as m:
m.get( m.get(
"https://api.example.com/v3/players", "https://api.example.com/v3/players",
payload={"success": True}, payload={"success": True},
status=200 status=200,
) )
client = APIClient() client = APIClient()
# Close the session manually # Close the session manually
await client._ensure_session() await client._ensure_session()
await client._session.close() await client._session.close()
# Client should recreate session and work fine # Client should recreate session and work fine
result = await client.get("players") result = await client.get("players")
assert result == {"success": True} assert result == {"success": True}
await client.close() await client.close()

View File

@ -1,6 +1,7 @@
""" """
Tests for BaseService functionality Tests for BaseService functionality
""" """
import pytest import pytest
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@ -10,6 +11,7 @@ from models.base import SBABaseModel
class MockModel(SBABaseModel): class MockModel(SBABaseModel):
"""Mock model for testing BaseService.""" """Mock model for testing BaseService."""
id: int id: int
name: str name: str
value: int = 100 value: int = 100
@ -17,240 +19,229 @@ class MockModel(SBABaseModel):
class TestBaseService: class TestBaseService:
"""Test BaseService functionality.""" """Test BaseService functionality."""
@pytest.fixture @pytest.fixture
def mock_client(self): def mock_client(self):
"""Mock API client.""" """Mock API client."""
client = AsyncMock() client = AsyncMock()
return client return client
@pytest.fixture @pytest.fixture
def base_service(self, mock_client): def base_service(self, mock_client):
"""Create BaseService instance for testing.""" """Create BaseService instance for testing."""
service = BaseService(MockModel, 'mocks', client=mock_client) service = BaseService(MockModel, "mocks", client=mock_client)
return service return service
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init(self): async def test_init(self):
"""Test service initialization.""" """Test service initialization."""
service = BaseService(MockModel, 'test_endpoint') service = BaseService(MockModel, "test_endpoint")
assert service.model_class == MockModel assert service.model_class == MockModel
assert service.endpoint == 'test_endpoint' assert service.endpoint == "test_endpoint"
assert service._client is None assert service._client is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_id_success(self, base_service, mock_client): async def test_get_by_id_success(self, base_service, mock_client):
"""Test successful get_by_id.""" """Test successful get_by_id."""
mock_data = {'id': 1, 'name': 'Test', 'value': 200} mock_data = {"id": 1, "name": "Test", "value": 200}
mock_client.get.return_value = mock_data mock_client.get.return_value = mock_data
result = await base_service.get_by_id(1) result = await base_service.get_by_id(1)
assert isinstance(result, MockModel) assert isinstance(result, MockModel)
assert result.id == 1 assert result.id == 1
assert result.name == 'Test' assert result.name == "Test"
assert result.value == 200 assert result.value == 200
mock_client.get.assert_called_once_with('mocks', object_id=1) mock_client.get.assert_called_once_with("mocks", object_id=1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_id_not_found(self, base_service, mock_client): async def test_get_by_id_not_found(self, base_service, mock_client):
"""Test get_by_id when object not found.""" """Test get_by_id when object not found."""
mock_client.get.return_value = None mock_client.get.return_value = None
result = await base_service.get_by_id(999) result = await base_service.get_by_id(999)
assert result is None assert result is None
mock_client.get.assert_called_once_with('mocks', object_id=999) mock_client.get.assert_called_once_with("mocks", object_id=999)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_all_with_count(self, base_service, mock_client): async def test_get_all_with_count(self, base_service, mock_client):
"""Test get_all with count response format.""" """Test get_all with count response format."""
mock_data = { mock_data = {
'count': 2, "count": 2,
'mocks': [ "mocks": [
{'id': 1, 'name': 'Test1', 'value': 100}, {"id": 1, "name": "Test1", "value": 100},
{'id': 2, 'name': 'Test2', 'value': 200} {"id": 2, "name": "Test2", "value": 200},
] ],
} }
mock_client.get.return_value = mock_data mock_client.get.return_value = mock_data
result, count = await base_service.get_all() result, count = await base_service.get_all()
assert len(result) == 2 assert len(result) == 2
assert count == 2 assert count == 2
assert all(isinstance(item, MockModel) for item in result) assert all(isinstance(item, MockModel) for item in result)
mock_client.get.assert_called_once_with('mocks', params=None) mock_client.get.assert_called_once_with("mocks", params=None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_all_items_convenience(self, base_service, mock_client): async def test_get_all_items_convenience(self, base_service, mock_client):
"""Test get_all_items convenience method.""" """Test get_all_items convenience method."""
mock_data = { mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]}
'count': 1,
'mocks': [{'id': 1, 'name': 'Test', 'value': 100}]
}
mock_client.get.return_value = mock_data mock_client.get.return_value = mock_data
result = await base_service.get_all_items() result = await base_service.get_all_items()
assert len(result) == 1 assert len(result) == 1
assert isinstance(result[0], MockModel) assert isinstance(result[0], MockModel)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_success(self, base_service, mock_client): async def test_create_success(self, base_service, mock_client):
"""Test successful object creation.""" """Test successful object creation."""
input_data = {'name': 'New Item', 'value': 300} input_data = {"name": "New Item", "value": 300}
response_data = {'id': 3, 'name': 'New Item', 'value': 300} response_data = {"id": 3, "name": "New Item", "value": 300}
mock_client.post.return_value = response_data mock_client.post.return_value = response_data
result = await base_service.create(input_data) result = await base_service.create(input_data)
assert isinstance(result, MockModel) assert isinstance(result, MockModel)
assert result.id == 3 assert result.id == 3
assert result.name == 'New Item' assert result.name == "New Item"
mock_client.post.assert_called_once_with('mocks', input_data) mock_client.post.assert_called_once_with("mocks/", input_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_success(self, base_service, mock_client): async def test_update_success(self, base_service, mock_client):
"""Test successful object update.""" """Test successful object update."""
update_data = {'name': 'Updated'} update_data = {"name": "Updated"}
response_data = {'id': 1, 'name': 'Updated', 'value': 100} response_data = {"id": 1, "name": "Updated", "value": 100}
mock_client.put.return_value = response_data mock_client.put.return_value = response_data
result = await base_service.update(1, update_data) result = await base_service.update(1, update_data)
assert isinstance(result, MockModel) assert isinstance(result, MockModel)
assert result.name == 'Updated' assert result.name == "Updated"
mock_client.put.assert_called_once_with('mocks', update_data, object_id=1) mock_client.put.assert_called_once_with("mocks", update_data, object_id=1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_success(self, base_service, mock_client): async def test_delete_success(self, base_service, mock_client):
"""Test successful object deletion.""" """Test successful object deletion."""
mock_client.delete.return_value = True mock_client.delete.return_value = True
result = await base_service.delete(1) result = await base_service.delete(1)
assert result is True assert result is True
mock_client.delete.assert_called_once_with('mocks', object_id=1) mock_client.delete.assert_called_once_with("mocks", object_id=1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_by_field(self, base_service, mock_client): async def test_get_by_field(self, base_service, mock_client):
"""Test get_by_field functionality.""" """Test get_by_field functionality."""
mock_data = { mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]}
'count': 1,
'mocks': [{'id': 1, 'name': 'Test', 'value': 100}]
}
mock_client.get.return_value = mock_data mock_client.get.return_value = mock_data
result = await base_service.get_by_field('name', 'Test') result = await base_service.get_by_field("name", "Test")
assert len(result) == 1 assert len(result) == 1
mock_client.get.assert_called_once_with('mocks', params=[('name', 'Test')]) mock_client.get.assert_called_once_with("mocks", params=[("name", "Test")])
def test_extract_items_and_count_standard_format(self, base_service): def test_extract_items_and_count_standard_format(self, base_service):
"""Test response parsing for standard format.""" """Test response parsing for standard format."""
data = { data = {
'count': 3, "count": 3,
'mocks': [ "mocks": [
{'id': 1, 'name': 'Test1'}, {"id": 1, "name": "Test1"},
{'id': 2, 'name': 'Test2'}, {"id": 2, "name": "Test2"},
{'id': 3, 'name': 'Test3'} {"id": 3, "name": "Test3"},
] ],
} }
items, count = base_service._extract_items_and_count_from_response(data) items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 3 assert len(items) == 3
assert count == 3 assert count == 3
assert items[0]['name'] == 'Test1' assert items[0]["name"] == "Test1"
def test_extract_items_and_count_single_object(self, base_service): def test_extract_items_and_count_single_object(self, base_service):
"""Test response parsing for single object.""" """Test response parsing for single object."""
data = {'id': 1, 'name': 'Single'} data = {"id": 1, "name": "Single"}
items, count = base_service._extract_items_and_count_from_response(data) items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 1 assert len(items) == 1
assert count == 1 assert count == 1
assert items[0] == data assert items[0] == data
def test_extract_items_and_count_direct_list(self, base_service): def test_extract_items_and_count_direct_list(self, base_service):
"""Test response parsing for direct list.""" """Test response parsing for direct list."""
data = [ data = [{"id": 1, "name": "Test1"}, {"id": 2, "name": "Test2"}]
{'id': 1, 'name': 'Test1'},
{'id': 2, 'name': 'Test2'}
]
items, count = base_service._extract_items_and_count_from_response(data) items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 2 assert len(items) == 2
assert count == 2 assert count == 2
class TestBaseServiceExtras: class TestBaseServiceExtras:
"""Additional coverage tests for BaseService edge cases.""" """Additional coverage tests for BaseService edge cases."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_base_service_additional_methods(self): async def test_base_service_additional_methods(self):
"""Test additional BaseService methods for coverage.""" """Test additional BaseService methods for coverage."""
from services.base_service import BaseService from services.base_service import BaseService
from models.base import SBABaseModel from models.base import SBABaseModel
class TestModel(SBABaseModel): class TestModel(SBABaseModel):
name: str name: str
value: int = 100 value: int = 100
mock_client = AsyncMock() mock_client = AsyncMock()
service = BaseService(TestModel, 'test', client=mock_client) service = BaseService(TestModel, "test", client=mock_client)
# Test count method # Test count method
mock_client.reset_mock() mock_client.reset_mock()
mock_client.get.return_value = {'count': 42, 'test': []} mock_client.get.return_value = {"count": 42, "test": []}
count = await service.count(params=[('active', 'true')]) count = await service.count(params=[("active", "true")])
assert count == 42 assert count == 42
# Test update_from_model with ID # Test update_from_model with ID
mock_client.reset_mock() mock_client.reset_mock()
model = TestModel(id=1, name="Updated", value=300) model = TestModel(id=1, name="Updated", value=300)
mock_client.put.return_value = {"id": 1, "name": "Updated", "value": 300} mock_client.put.return_value = {"id": 1, "name": "Updated", "value": 300}
result = await service.update_from_model(model) result = await service.update_from_model(model)
assert result.name == "Updated" assert result.name == "Updated"
# Test update_from_model without ID # Test update_from_model without ID
model_no_id = TestModel(name="Test") model_no_id = TestModel(name="Test")
with pytest.raises(ValueError, match="Cannot update TestModel without ID"): with pytest.raises(ValueError, match="Cannot update TestModel without ID"):
await service.update_from_model(model_no_id) await service.update_from_model(model_no_id)
def test_base_service_response_parsing_edge_cases(self): def test_base_service_response_parsing_edge_cases(self):
"""Test edge cases in response parsing.""" """Test edge cases in response parsing."""
from services.base_service import BaseService from services.base_service import BaseService
from models.base import SBABaseModel from models.base import SBABaseModel
class TestModel(SBABaseModel): class TestModel(SBABaseModel):
name: str name: str
service = BaseService(TestModel, 'test') service = BaseService(TestModel, "test")
# Test with 'items' field # Test with 'items' field
data = {'count': 2, 'items': [{'name': 'Item1'}, {'name': 'Item2'}]} data = {"count": 2, "items": [{"name": "Item1"}, {"name": "Item2"}]}
items, count = service._extract_items_and_count_from_response(data) items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 2 assert len(items) == 2
assert count == 2 assert count == 2
# Test with 'data' field # Test with 'data' field
data = {'count': 1, 'data': [{'name': 'DataItem'}]} data = {"count": 1, "data": [{"name": "DataItem"}]}
items, count = service._extract_items_and_count_from_response(data) items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 1 assert len(items) == 1
assert count == 1 assert count == 1
# Test with count but no recognizable list field # Test with count but no recognizable list field
data = {'count': 5, 'unknown_field': [{'name': 'Item'}]} data = {"count": 5, "unknown_field": [{"name": "Item"}]}
items, count = service._extract_items_and_count_from_response(data) items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 0 assert len(items) == 0
assert count == 5 assert count == 5
# Test with unexpected data type # Test with unexpected data type
items, count = service._extract_items_and_count_from_response("unexpected") items, count = service._extract_items_and_count_from_response("unexpected")
assert len(items) == 0 assert len(items) == 0
assert count == 0 assert count == 0

View File

@ -3,6 +3,7 @@ Interactive Trade Embed Views
Handles the Discord embed and button interfaces for the multi-team trade builder. Handles the Discord embed and button interfaces for the multi-team trade builder.
""" """
import discord import discord
from typing import Optional, List from typing import Optional, List
from datetime import datetime, timezone from datetime import datetime, timezone
@ -31,60 +32,56 @@ class TradeEmbedView(discord.ui.View):
"""Check if user has permission to interact with this view.""" """Check if user has permission to interact with this view."""
if interaction.user.id != self.user_id: if interaction.user.id != self.user_id:
await interaction.response.send_message( await interaction.response.send_message(
"You don't have permission to use this trade builder.", "You don't have permission to use this trade builder.",
ephemeral=True ephemeral=True,
) )
return False return False
return True return True
async def on_timeout(self) -> None: async def on_timeout(self) -> None:
"""Handle view timeout.""" """Handle view timeout."""
# Disable all buttons when timeout occurs
for item in self.children: for item in self.children:
if isinstance(item, discord.ui.Button): if isinstance(item, discord.ui.Button):
item.disabled = True item.disabled = True
@discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red, emoji="") @discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red)
async def remove_move_button(self, interaction: discord.Interaction, button: discord.ui.Button): async def remove_move_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle remove move button click.""" """Handle remove move button click."""
if self.builder.is_empty: if self.builder.is_empty:
await interaction.response.send_message( await interaction.response.send_message(
"❌ No moves to remove. Add some moves first!", "No moves to remove. Add some moves first!", ephemeral=True
ephemeral=True
) )
return return
# Create select menu for move removal
select_view = RemoveTradeMovesView(self.builder, self.user_id) select_view = RemoveTradeMovesView(self.builder, self.user_id)
embed = await create_trade_embed(self.builder) embed = await create_trade_embed(self.builder)
await interaction.response.edit_message(embed=embed, view=select_view) await interaction.response.edit_message(embed=embed, view=select_view)
@discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary, emoji="🔍") @discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary)
async def validate_button(self, interaction: discord.Interaction, button: discord.ui.Button): async def validate_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle validate trade button click.""" """Handle validate trade button click."""
await interaction.response.defer(ephemeral=True) await interaction.response.defer(ephemeral=True)
# Perform detailed validation
validation = await self.builder.validate_trade() validation = await self.builder.validate_trade()
# Create validation report
if validation.is_legal: if validation.is_legal:
status_emoji = ""
status_text = "**Trade is LEGAL**" status_text = "**Trade is LEGAL**"
color = EmbedColors.SUCCESS color = EmbedColors.SUCCESS
else: else:
status_emoji = ""
status_text = "**Trade has ERRORS**" status_text = "**Trade has ERRORS**"
color = EmbedColors.ERROR color = EmbedColors.ERROR
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"{status_emoji} Trade Validation Report", title="Trade Validation Report",
description=status_text, description=status_text,
color=color color=color,
) )
# Add team-by-team validation
for participant in self.builder.trade.participants: for participant in self.builder.trade.participants:
team_validation = validation.get_participant_validation(participant.team.id) team_validation = validation.get_participant_validation(participant.team.id)
if team_validation: if team_validation:
@ -98,72 +95,65 @@ class TradeEmbedView(discord.ui.View):
team_status.append(team_validation.pre_existing_transactions_note) team_status.append(team_validation.pre_existing_transactions_note)
embed.add_field( embed.add_field(
name=f"🏟️ {participant.team.abbrev} - {participant.team.sname}", name=f"{participant.team.abbrev} - {participant.team.sname}",
value="\n".join(team_status), value="\n".join(team_status),
inline=False inline=False,
) )
# Add overall errors and suggestions
if validation.all_errors: if validation.all_errors:
error_text = "\n".join([f"{error}" for error in validation.all_errors]) error_text = "\n".join([f"- {error}" for error in validation.all_errors])
embed.add_field( embed.add_field(name="Errors", value=error_text, inline=False)
name="❌ Errors",
value=error_text,
inline=False
)
if validation.all_suggestions: if validation.all_suggestions:
suggestion_text = "\n".join([f"💡 {suggestion}" for suggestion in validation.all_suggestions]) suggestion_text = "\n".join(
embed.add_field( [f"- {suggestion}" for suggestion in validation.all_suggestions]
name="💡 Suggestions",
value=suggestion_text,
inline=False
) )
embed.add_field(name="Suggestions", value=suggestion_text, inline=False)
await interaction.followup.send(embed=embed, ephemeral=True) await interaction.followup.send(embed=embed, ephemeral=True)
@discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary, emoji="📤") @discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary)
async def submit_button(self, interaction: discord.Interaction, button: discord.ui.Button): async def submit_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle submit trade button click.""" """Handle submit trade button click."""
if self.builder.is_empty: if self.builder.is_empty:
await interaction.response.send_message( await interaction.response.send_message(
"❌ Cannot submit empty trade. Add some moves first!", "Cannot submit empty trade. Add some moves first!", ephemeral=True
ephemeral=True
) )
return return
# Validate before submission
validation = await self.builder.validate_trade() validation = await self.builder.validate_trade()
if not validation.is_legal: if not validation.is_legal:
error_msg = "**Cannot submit illegal trade:**\n" error_msg = "**Cannot submit illegal trade:**\n"
error_msg += "\n".join([f" {error}" for error in validation.all_errors]) error_msg += "\n".join([f"- {error}" for error in validation.all_errors])
if validation.all_suggestions: if validation.all_suggestions:
error_msg += "\n\n**Suggestions:**\n" error_msg += "\n\n**Suggestions:**\n"
error_msg += "\n".join([f"💡 {suggestion}" for suggestion in validation.all_suggestions]) error_msg += "\n".join(
[f"- {suggestion}" for suggestion in validation.all_suggestions]
)
await interaction.response.send_message(error_msg, ephemeral=True) await interaction.response.send_message(error_msg, ephemeral=True)
return return
# Show confirmation modal
modal = SubmitTradeConfirmationModal(self.builder) modal = SubmitTradeConfirmationModal(self.builder)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary, emoji="") @discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary)
async def cancel_button(self, interaction: discord.Interaction, button: discord.ui.Button): async def cancel_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle cancel trade button click.""" """Handle cancel trade button click."""
self.builder.clear_trade() self.builder.clear_trade()
embed = await create_trade_embed(self.builder) embed = await create_trade_embed(self.builder)
# Disable all buttons after cancellation
for item in self.children: for item in self.children:
if isinstance(item, discord.ui.Button): if isinstance(item, discord.ui.Button):
item.disabled = True item.disabled = True
await interaction.response.edit_message( await interaction.response.edit_message(
content="❌ **Trade cancelled and cleared.**", content="**Trade cancelled and cleared.**", embed=embed, view=self
embed=embed,
view=self
) )
self.stop() self.stop()
@ -176,12 +166,12 @@ class RemoveTradeMovesView(discord.ui.View):
self.builder = builder self.builder = builder
self.user_id = user_id self.user_id = user_id
# Create select menu with current moves
if not builder.is_empty: if not builder.is_empty:
self.add_item(RemoveTradeMovesSelect(builder)) self.add_item(RemoveTradeMovesSelect(builder))
# Add back button back_button = discord.ui.Button(
back_button = discord.ui.Button(label="Back", style=discord.ButtonStyle.secondary, emoji="⬅️") label="Back", style=discord.ButtonStyle.secondary
)
back_button.callback = self.back_callback back_button.callback = self.back_callback
self.add_item(back_button) self.add_item(back_button)
@ -202,35 +192,36 @@ class RemoveTradeMovesSelect(discord.ui.Select):
def __init__(self, builder: TradeBuilder): def __init__(self, builder: TradeBuilder):
self.builder = builder self.builder = builder
# Create options from all moves (cross-team and supplementary)
options = [] options = []
move_count = 0 move_count = 0
# Add cross-team moves for move in builder.trade.cross_team_moves[
for move in builder.trade.cross_team_moves[:20]: # Limit to avoid Discord's 25 option limit :20
options.append(discord.SelectOption( ]: # Limit to avoid Discord's 25 option limit
label=f"{move.player.name}", options.append(
description=move.description[:100], # Discord description limit discord.SelectOption(
value=str(move.player.id), label=f"{move.player.name}",
emoji="🔄" description=move.description[:100],
)) value=str(move.player.id),
)
)
move_count += 1 move_count += 1
# Add supplementary moves if there's room
remaining_slots = 25 - move_count remaining_slots = 25 - move_count
for move in builder.trade.supplementary_moves[:remaining_slots]: for move in builder.trade.supplementary_moves[:remaining_slots]:
options.append(discord.SelectOption( options.append(
label=f"{move.player.name}", discord.SelectOption(
description=move.description[:100], label=f"{move.player.name}",
value=str(move.player.id), description=move.description[:100],
emoji="⚙️" value=str(move.player.id),
)) )
)
super().__init__( super().__init__(
placeholder="Select a move to remove...", placeholder="Select a move to remove...",
min_values=1, min_values=1,
max_values=1, max_values=1,
options=options options=options,
) )
async def callback(self, interaction: discord.Interaction): async def callback(self, interaction: discord.Interaction):
@ -241,27 +232,25 @@ class RemoveTradeMovesSelect(discord.ui.Select):
if success: if success:
await interaction.response.send_message( await interaction.response.send_message(
f"✅ Removed move for player ID {player_id}", f"Removed move for player ID {player_id}", ephemeral=True
ephemeral=True
) )
# Update the embed
main_view = TradeEmbedView(self.builder, interaction.user.id) main_view = TradeEmbedView(self.builder, interaction.user.id)
embed = await create_trade_embed(self.builder) embed = await create_trade_embed(self.builder)
# Edit the original message
await interaction.edit_original_response(embed=embed, view=main_view) await interaction.edit_original_response(embed=embed, view=main_view)
else: else:
await interaction.response.send_message( await interaction.response.send_message(
f"❌ Could not remove move: {error_msg}", f"Could not remove move: {error_msg}", ephemeral=True
ephemeral=True
) )
class SubmitTradeConfirmationModal(discord.ui.Modal): class SubmitTradeConfirmationModal(discord.ui.Modal):
"""Modal for confirming trade submission - posts acceptance request to trade channel.""" """Modal for confirming trade submission - posts acceptance request to trade channel."""
def __init__(self, builder: TradeBuilder, trade_channel: Optional[discord.TextChannel] = None): def __init__(
self, builder: TradeBuilder, trade_channel: Optional[discord.TextChannel] = None
):
super().__init__(title="Confirm Trade Submission") super().__init__(title="Confirm Trade Submission")
self.builder = builder self.builder = builder
self.trade_channel = trade_channel self.trade_channel = trade_channel
@ -270,7 +259,7 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
label="Type 'CONFIRM' to submit for approval", label="Type 'CONFIRM' to submit for approval",
placeholder="CONFIRM", placeholder="CONFIRM",
required=True, required=True,
max_length=7 max_length=7,
) )
self.add_item(self.confirmation) self.add_item(self.confirmation)
@ -279,56 +268,52 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
"""Handle confirmation submission - posts acceptance view to trade channel.""" """Handle confirmation submission - posts acceptance view to trade channel."""
if self.confirmation.value.upper() != "CONFIRM": if self.confirmation.value.upper() != "CONFIRM":
await interaction.response.send_message( await interaction.response.send_message(
"Trade not submitted. You must type 'CONFIRM' exactly.", "Trade not submitted. You must type 'CONFIRM' exactly.",
ephemeral=True ephemeral=True,
) )
return return
await interaction.response.defer(ephemeral=True) await interaction.response.defer(ephemeral=True)
try: try:
# Update trade status to PROPOSED
from models.trade import TradeStatus from models.trade import TradeStatus
self.builder.trade.status = TradeStatus.PROPOSED self.builder.trade.status = TradeStatus.PROPOSED
# Create acceptance embed and view
acceptance_embed = await create_trade_acceptance_embed(self.builder) acceptance_embed = await create_trade_acceptance_embed(self.builder)
acceptance_view = TradeAcceptanceView(self.builder) acceptance_view = TradeAcceptanceView(self.builder)
# Find the trade channel to post to
channel = self.trade_channel channel = self.trade_channel
if not channel: if not channel:
# Try to find trade channel by name pattern
trade_channel_name = f"trade-{'-'.join(t.abbrev.lower() for t in self.builder.participating_teams)}"
for ch in interaction.guild.text_channels: # type: ignore for ch in interaction.guild.text_channels: # type: ignore
if ch.name.startswith("trade-") and self.builder.trade_id[:4] in ch.name: if (
ch.name.startswith("trade-")
and self.builder.trade_id[:4] in ch.name
):
channel = ch channel = ch
break break
if channel: if channel:
# Post acceptance request to trade channel
await channel.send( await channel.send(
content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.", content="**Trade submitted for approval.** All teams must accept to complete the trade.",
embed=acceptance_embed, embed=acceptance_embed,
view=acceptance_view view=acceptance_view,
) )
await interaction.followup.send( await interaction.followup.send(
f"✅ Trade submitted for approval!\n\nThe acceptance request has been posted to {channel.mention}.\n" f"Trade submitted for approval.\n\nThe acceptance request has been posted to {channel.mention}.\n"
f"All participating teams must click **Accept Trade** to finalize.", f"All participating teams must click **Accept Trade** to finalize.",
ephemeral=True ephemeral=True,
) )
else: else:
# No trade channel found, post in current channel
await interaction.followup.send( await interaction.followup.send(
content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.", content="**Trade submitted for approval.** All teams must accept to complete the trade.",
embed=acceptance_embed, embed=acceptance_embed,
view=acceptance_view view=acceptance_view,
) )
except Exception as e: except Exception as e:
await interaction.followup.send( await interaction.followup.send(
f"❌ Error submitting trade: {str(e)}", f"Error submitting trade: {str(e)}", ephemeral=True
ephemeral=True
) )
@ -343,8 +328,11 @@ class TradeAcceptanceView(discord.ui.View):
"""Get the team owned by the interacting user.""" """Get the team owned by the interacting user."""
from services.team_service import team_service from services.team_service import team_service
from config import get_config from config import get_config
config = get_config() config = get_config()
return await team_service.get_team_by_owner(interaction.user.id, config.sba_season) return await team_service.get_team_by_owner(
interaction.user.id, config.sba_season
)
async def interaction_check(self, interaction: discord.Interaction) -> bool: async def interaction_check(self, interaction: discord.Interaction) -> bool:
"""Check if user is a GM of a participating team.""" """Check if user is a GM of a participating team."""
@ -352,17 +340,14 @@ class TradeAcceptanceView(discord.ui.View):
if not user_team: if not user_team:
await interaction.response.send_message( await interaction.response.send_message(
"❌ You don't own a team in the league.", "You don't own a team in the league.", ephemeral=True
ephemeral=True
) )
return False return False
# Check if their team (or organization) is participating
participant = self.builder.trade.get_participant_by_organization(user_team) participant = self.builder.trade.get_participant_by_organization(user_team)
if not participant: if not participant:
await interaction.response.send_message( await interaction.response.send_message(
"❌ Your team is not part of this trade.", "Your team is not part of this trade.", ephemeral=True
ephemeral=True
) )
return False return False
@ -374,47 +359,45 @@ class TradeAcceptanceView(discord.ui.View):
if isinstance(item, discord.ui.Button): if isinstance(item, discord.ui.Button):
item.disabled = True item.disabled = True
@discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success, emoji="") @discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success)
async def accept_button(self, interaction: discord.Interaction, button: discord.ui.Button): async def accept_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle accept button click.""" """Handle accept button click."""
user_team = await self._get_user_team(interaction) user_team = await self._get_user_team(interaction)
if not user_team: if not user_team:
return return
# Find the participating team (could be org affiliate)
participant = self.builder.trade.get_participant_by_organization(user_team) participant = self.builder.trade.get_participant_by_organization(user_team)
if not participant: if not participant:
return return
team_id = participant.team.id team_id = participant.team.id
# Check if already accepted
if self.builder.has_team_accepted(team_id): if self.builder.has_team_accepted(team_id):
await interaction.response.send_message( await interaction.response.send_message(
f"{participant.team.abbrev} has already accepted this trade.", f"{participant.team.abbrev} has already accepted this trade.",
ephemeral=True ephemeral=True,
) )
return return
# Record acceptance
all_accepted = self.builder.accept_trade(team_id) all_accepted = self.builder.accept_trade(team_id)
if all_accepted: if all_accepted:
# All teams accepted - finalize the trade
await self._finalize_trade(interaction) await self._finalize_trade(interaction)
else: else:
# Update embed to show new acceptance status
embed = await create_trade_acceptance_embed(self.builder) embed = await create_trade_acceptance_embed(self.builder)
await interaction.response.edit_message(embed=embed, view=self) await interaction.response.edit_message(embed=embed, view=self)
# Send confirmation to channel
await interaction.followup.send( await interaction.followup.send(
f"**{participant.team.abbrev}** has accepted the trade! " f"**{participant.team.abbrev}** has accepted the trade. "
f"({len(self.builder.accepted_teams)}/{self.builder.team_count} teams)" f"({len(self.builder.accepted_teams)}/{self.builder.team_count} teams)"
) )
@discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger, emoji="") @discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger)
async def reject_button(self, interaction: discord.Interaction, button: discord.ui.Button): async def reject_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle reject button click - moves trade back to DRAFT.""" """Handle reject button click - moves trade back to DRAFT."""
user_team = await self._get_user_team(interaction) user_team = await self._get_user_team(interaction)
if not user_team: if not user_team:
@ -424,20 +407,16 @@ class TradeAcceptanceView(discord.ui.View):
if not participant: if not participant:
return return
# Reject the trade
self.builder.reject_trade() self.builder.reject_trade()
# Disable buttons
self.accept_button.disabled = True self.accept_button.disabled = True
self.reject_button.disabled = True self.reject_button.disabled = True
# Update embed to show rejection
embed = await create_trade_rejection_embed(self.builder, participant.team) embed = await create_trade_rejection_embed(self.builder, participant.team)
await interaction.response.edit_message(embed=embed, view=self) await interaction.response.edit_message(embed=embed, view=self)
# Notify the channel
await interaction.followup.send( await interaction.followup.send(
f"**{participant.team.abbrev}** has rejected the trade.\n\n" f"**{participant.team.abbrev}** has rejected the trade.\n\n"
f"The trade has been moved back to **DRAFT** status. " f"The trade has been moved back to **DRAFT** status. "
f"Teams can continue negotiating using `/trade` commands." f"Teams can continue negotiating using `/trade` commands."
) )
@ -459,41 +438,52 @@ class TradeAcceptanceView(discord.ui.View):
config = get_config() config = get_config()
# Get next week for transactions
current = await league_service.get_current_state() current = await league_service.get_current_state()
next_week = current.week + 1 if current else 1 next_week = current.week + 1 if current else 1
# Create FA team for reference
fa_team = Team( fa_team = Team(
id=config.free_agent_team_id, id=config.free_agent_team_id,
abbrev="FA", abbrev="FA",
sname="Free Agents", sname="Free Agents",
lname="Free Agency", lname="Free Agency",
season=self.builder.trade.season season=self.builder.trade.season,
) # type: ignore ) # type: ignore
# Create transactions from all moves
transactions: List[Transaction] = [] transactions: List[Transaction] = []
move_id = f"Trade-{self.builder.trade_id}-{int(datetime.now(timezone.utc).timestamp())}" move_id = f"Trade-{self.builder.trade_id}-{int(datetime.now(timezone.utc).timestamp())}"
# Process cross-team moves
for move in self.builder.trade.cross_team_moves: for move in self.builder.trade.cross_team_moves:
# Get actual team affiliates for from/to based on roster type
if move.from_roster == RosterType.MAJOR_LEAGUE: if move.from_roster == RosterType.MAJOR_LEAGUE:
old_team = move.source_team old_team = move.source_team
elif move.from_roster == RosterType.MINOR_LEAGUE: elif move.from_roster == RosterType.MINOR_LEAGUE:
old_team = await move.source_team.minor_league_affiliate() if move.source_team else None old_team = (
await move.source_team.minor_league_affiliate()
if move.source_team
else None
)
elif move.from_roster == RosterType.INJURED_LIST: elif move.from_roster == RosterType.INJURED_LIST:
old_team = await move.source_team.injured_list_affiliate() if move.source_team else None old_team = (
await move.source_team.injured_list_affiliate()
if move.source_team
else None
)
else: else:
old_team = move.source_team old_team = move.source_team
if move.to_roster == RosterType.MAJOR_LEAGUE: if move.to_roster == RosterType.MAJOR_LEAGUE:
new_team = move.destination_team new_team = move.destination_team
elif move.to_roster == RosterType.MINOR_LEAGUE: elif move.to_roster == RosterType.MINOR_LEAGUE:
new_team = await move.destination_team.minor_league_affiliate() if move.destination_team else None new_team = (
await move.destination_team.minor_league_affiliate()
if move.destination_team
else None
)
elif move.to_roster == RosterType.INJURED_LIST: elif move.to_roster == RosterType.INJURED_LIST:
new_team = await move.destination_team.injured_list_affiliate() if move.destination_team else None new_team = (
await move.destination_team.injured_list_affiliate()
if move.destination_team
else None
)
else: else:
new_team = move.destination_team new_team = move.destination_team
@ -507,18 +497,25 @@ class TradeAcceptanceView(discord.ui.View):
oldteam=old_team, oldteam=old_team,
newteam=new_team, newteam=new_team,
cancelled=False, cancelled=False,
frozen=False # Trades are NOT frozen - immediately effective frozen=False,
) )
transactions.append(transaction) transactions.append(transaction)
# Process supplementary moves
for move in self.builder.trade.supplementary_moves: for move in self.builder.trade.supplementary_moves:
if move.from_roster == RosterType.MAJOR_LEAGUE: if move.from_roster == RosterType.MAJOR_LEAGUE:
old_team = move.source_team old_team = move.source_team
elif move.from_roster == RosterType.MINOR_LEAGUE: elif move.from_roster == RosterType.MINOR_LEAGUE:
old_team = await move.source_team.minor_league_affiliate() if move.source_team else None old_team = (
await move.source_team.minor_league_affiliate()
if move.source_team
else None
)
elif move.from_roster == RosterType.INJURED_LIST: elif move.from_roster == RosterType.INJURED_LIST:
old_team = await move.source_team.injured_list_affiliate() if move.source_team else None old_team = (
await move.source_team.injured_list_affiliate()
if move.source_team
else None
)
elif move.from_roster == RosterType.FREE_AGENCY: elif move.from_roster == RosterType.FREE_AGENCY:
old_team = fa_team old_team = fa_team
else: else:
@ -527,9 +524,17 @@ class TradeAcceptanceView(discord.ui.View):
if move.to_roster == RosterType.MAJOR_LEAGUE: if move.to_roster == RosterType.MAJOR_LEAGUE:
new_team = move.destination_team new_team = move.destination_team
elif move.to_roster == RosterType.MINOR_LEAGUE: elif move.to_roster == RosterType.MINOR_LEAGUE:
new_team = await move.destination_team.minor_league_affiliate() if move.destination_team else None new_team = (
await move.destination_team.minor_league_affiliate()
if move.destination_team
else None
)
elif move.to_roster == RosterType.INJURED_LIST: elif move.to_roster == RosterType.INJURED_LIST:
new_team = await move.destination_team.injured_list_affiliate() if move.destination_team else None new_team = (
await move.destination_team.injured_list_affiliate()
if move.destination_team
else None
)
elif move.to_roster == RosterType.FREE_AGENCY: elif move.to_roster == RosterType.FREE_AGENCY:
new_team = fa_team new_team = fa_team
else: else:
@ -545,45 +550,42 @@ class TradeAcceptanceView(discord.ui.View):
oldteam=old_team, oldteam=old_team,
newteam=new_team, newteam=new_team,
cancelled=False, cancelled=False,
frozen=False # Trades are NOT frozen - immediately effective frozen=False,
) )
transactions.append(transaction) transactions.append(transaction)
# POST transactions to database
if transactions: if transactions:
created_transactions = await transaction_service.create_transaction_batch(transactions) created_transactions = (
await transaction_service.create_transaction_batch(transactions)
)
else: else:
created_transactions = [] created_transactions = []
# Post to #transaction-log channel
if created_transactions and interaction.client: if created_transactions and interaction.client:
await post_trade_to_log( await post_trade_to_log(
bot=interaction.client, bot=interaction.client,
builder=self.builder, builder=self.builder,
transactions=created_transactions, transactions=created_transactions,
effective_week=next_week effective_week=next_week,
) )
# Update trade status
self.builder.trade.status = TradeStatus.ACCEPTED self.builder.trade.status = TradeStatus.ACCEPTED
# Disable buttons
self.accept_button.disabled = True self.accept_button.disabled = True
self.reject_button.disabled = True self.reject_button.disabled = True
# Update embed to show completion embed = await create_trade_complete_embed(
embed = await create_trade_complete_embed(self.builder, len(created_transactions), next_week) self.builder, len(created_transactions), next_week
)
await interaction.edit_original_response(embed=embed, view=self) await interaction.edit_original_response(embed=embed, view=self)
# Send completion message
await interaction.followup.send( await interaction.followup.send(
f"🎉 **Trade Complete!**\n\n" f"**Trade Complete!**\n\n"
f"All {self.builder.team_count} teams have accepted the trade.\n" f"All {self.builder.team_count} teams have accepted the trade.\n"
f"**{len(created_transactions)} transactions** have been created for **Week {next_week}**.\n\n" f"**{len(created_transactions)} transactions** have been created for **Week {next_week}**.\n\n"
f"Trade ID: `{self.builder.trade_id}`" f"Trade ID: `{self.builder.trade_id}`"
) )
# Clear the trade builder
for team in self.builder.participating_teams: for team in self.builder.participating_teams:
clear_trade_builder_by_team(team.id) clear_trade_builder_by_team(team.id)
@ -591,81 +593,79 @@ class TradeAcceptanceView(discord.ui.View):
except Exception as e: except Exception as e:
await interaction.followup.send( await interaction.followup.send(
f"❌ Error finalizing trade: {str(e)}", f"Error finalizing trade: {str(e)}", ephemeral=True
ephemeral=True
) )
async def create_trade_acceptance_embed(builder: TradeBuilder) -> discord.Embed: async def create_trade_acceptance_embed(builder: TradeBuilder) -> discord.Embed:
"""Create embed showing trade details and acceptance status.""" """Create embed showing trade details and acceptance status."""
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"📋 Trade Pending Acceptance - {builder.trade.get_trade_summary()}", title=f"Trade Pending Acceptance - {builder.trade.get_trade_summary()}",
description="All participating teams must accept to complete the trade.", description="All participating teams must accept to complete the trade.",
color=EmbedColors.WARNING color=EmbedColors.WARNING,
) )
# Show participating teams team_list = [
team_list = [f"{team.abbrev} - {team.sname}" for team in builder.participating_teams] f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams
]
embed.add_field( embed.add_field(
name=f"🏟️ Participating Teams ({builder.team_count})", name=f"Participating Teams ({builder.team_count})",
value="\n".join(team_list), value="\n".join(team_list),
inline=False inline=False,
) )
# Show cross-team moves
if builder.trade.cross_team_moves: if builder.trade.cross_team_moves:
moves_text = "" moves_text = ""
for move in builder.trade.cross_team_moves[:10]: for move in builder.trade.cross_team_moves[:10]:
moves_text += f" {move.description}\n" moves_text += f"- {move.description}\n"
if len(builder.trade.cross_team_moves) > 10: if len(builder.trade.cross_team_moves) > 10:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 10} more" moves_text += f"... and {len(builder.trade.cross_team_moves) - 10} more"
embed.add_field( embed.add_field(
name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})", name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})",
value=moves_text, value=moves_text,
inline=False inline=False,
) )
# Show supplementary moves if any
if builder.trade.supplementary_moves: if builder.trade.supplementary_moves:
supp_text = "" supp_text = ""
for move in builder.trade.supplementary_moves[:5]: for move in builder.trade.supplementary_moves[:5]:
supp_text += f" {move.description}\n" supp_text += f"- {move.description}\n"
if len(builder.trade.supplementary_moves) > 5: if len(builder.trade.supplementary_moves) > 5:
supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more" supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more"
embed.add_field( embed.add_field(
name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})", name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})",
value=supp_text, value=supp_text,
inline=False inline=False,
) )
# Show acceptance status
status_lines = [] status_lines = []
for team in builder.participating_teams: for team in builder.participating_teams:
if team.id in builder.accepted_teams: if team.id in builder.accepted_teams:
status_lines.append(f"**{team.abbrev}** - Accepted") status_lines.append(f"**{team.abbrev}** - Accepted")
else: else:
status_lines.append(f"**{team.abbrev}** - Pending") status_lines.append(f"**{team.abbrev}** - Pending")
embed.add_field( embed.add_field(
name="📊 Acceptance Status", name="Acceptance Status", value="\n".join(status_lines), inline=False
value="\n".join(status_lines),
inline=False
) )
# Add footer embed.set_footer(
embed.set_footer(text=f"Trade ID: {builder.trade_id}{len(builder.accepted_teams)}/{builder.team_count} teams accepted") text=f"Trade ID: {builder.trade_id} | {len(builder.accepted_teams)}/{builder.team_count} teams accepted"
)
return embed return embed
async def create_trade_rejection_embed(builder: TradeBuilder, rejecting_team: Team) -> discord.Embed: async def create_trade_rejection_embed(
builder: TradeBuilder, rejecting_team: Team
) -> discord.Embed:
"""Create embed showing trade was rejected.""" """Create embed showing trade was rejected."""
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"Trade Rejected - {builder.trade.get_trade_summary()}", title=f"Trade Rejected - {builder.trade.get_trade_summary()}",
description=f"**{rejecting_team.abbrev}** has rejected the trade.\n\n" description=f"**{rejecting_team.abbrev}** has rejected the trade.\n\n"
f"The trade has been moved back to **DRAFT** status.\n" f"The trade has been moved back to **DRAFT** status.\n"
f"Teams can continue negotiating using `/trade` commands.", f"Teams can continue negotiating using `/trade` commands.",
color=EmbedColors.ERROR color=EmbedColors.ERROR,
) )
embed.set_footer(text=f"Trade ID: {builder.trade_id}") embed.set_footer(text=f"Trade ID: {builder.trade_id}")
@ -673,37 +673,33 @@ async def create_trade_rejection_embed(builder: TradeBuilder, rejecting_team: Te
return embed return embed
async def create_trade_complete_embed(builder: TradeBuilder, transaction_count: int, effective_week: int) -> discord.Embed: async def create_trade_complete_embed(
builder: TradeBuilder, transaction_count: int, effective_week: int
) -> discord.Embed:
"""Create embed showing trade was completed.""" """Create embed showing trade was completed."""
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"🎉 Trade Complete! - {builder.trade.get_trade_summary()}", title=f"Trade Complete - {builder.trade.get_trade_summary()}",
description=f"All {builder.team_count} teams have accepted the trade!\n\n" description=f"All {builder.team_count} teams have accepted the trade.\n\n"
f"**{transaction_count} transactions** created for **Week {effective_week}**.", f"**{transaction_count} transactions** created for **Week {effective_week}**.",
color=EmbedColors.SUCCESS color=EmbedColors.SUCCESS,
) )
# Show final acceptance status (all green) status_lines = [
status_lines = [f"✅ **{team.abbrev}** - Accepted" for team in builder.participating_teams] f"**{team.abbrev}** - Accepted" for team in builder.participating_teams
embed.add_field( ]
name="📊 Final Status", embed.add_field(name="Final Status", value="\n".join(status_lines), inline=False)
value="\n".join(status_lines),
inline=False
)
# Show cross-team moves
if builder.trade.cross_team_moves: if builder.trade.cross_team_moves:
moves_text = "" moves_text = ""
for move in builder.trade.cross_team_moves[:8]: for move in builder.trade.cross_team_moves[:8]:
moves_text += f" {move.description}\n" moves_text += f"- {move.description}\n"
if len(builder.trade.cross_team_moves) > 8: if len(builder.trade.cross_team_moves) > 8:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more" moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
embed.add_field( embed.add_field(name="Player Exchanges", value=moves_text, inline=False)
name=f"🔄 Player Exchanges",
value=moves_text,
inline=False
)
embed.set_footer(text=f"Trade ID: {builder.trade_id} • Effective: Week {effective_week}") embed.set_footer(
text=f"Trade ID: {builder.trade_id} | Effective: Week {effective_week}"
)
return embed return embed
@ -718,7 +714,6 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
Returns: Returns:
Discord embed with current trade state Discord embed with current trade state
""" """
# Determine embed color based on trade status
if builder.is_empty: if builder.is_empty:
color = EmbedColors.SECONDARY color = EmbedColors.SECONDARY
else: else:
@ -726,79 +721,79 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING
embed = EmbedTemplate.create_base_embed( embed = EmbedTemplate.create_base_embed(
title=f"📋 Trade Builder - {builder.trade.get_trade_summary()}", title=f"Trade Builder - {builder.trade.get_trade_summary()}",
description=f"Build your multi-team trade", description="Build your multi-team trade",
color=color color=color,
) )
# Add participating teams section team_list = [
team_list = [f"{team.abbrev} - {team.sname}" for team in builder.participating_teams] f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams
]
embed.add_field( embed.add_field(
name=f"🏟️ Participating Teams ({builder.team_count})", name=f"Participating Teams ({builder.team_count})",
value="\n".join(team_list) if team_list else "*No teams yet*", value="\n".join(team_list) if team_list else "*No teams yet*",
inline=False inline=False,
) )
# Add current moves section
if builder.is_empty: if builder.is_empty:
embed.add_field( embed.add_field(
name="Current Moves", name="Current Moves",
value="*No moves yet. Use the `/trade` commands to build your trade.*", value="*No moves yet. Use the `/trade` commands to build your trade.*",
inline=False inline=False,
) )
else: else:
# Show cross-team moves
if builder.trade.cross_team_moves: if builder.trade.cross_team_moves:
moves_text = "" moves_text = ""
for i, move in enumerate(builder.trade.cross_team_moves[:8], 1): # Limit display for i, move in enumerate(builder.trade.cross_team_moves[:8], 1):
moves_text += f"{i}. {move.description}\n" moves_text += f"{i}. {move.description}\n"
if len(builder.trade.cross_team_moves) > 8: if len(builder.trade.cross_team_moves) > 8:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more" moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
embed.add_field( embed.add_field(
name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})", name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})",
value=moves_text, value=moves_text,
inline=False inline=False,
) )
# Show supplementary moves
if builder.trade.supplementary_moves: if builder.trade.supplementary_moves:
supp_text = "" supp_text = ""
for i, move in enumerate(builder.trade.supplementary_moves[:5], 1): # Limit display for i, move in enumerate(builder.trade.supplementary_moves[:5], 1):
supp_text += f"{i}. {move.description}\n" supp_text += f"{i}. {move.description}\n"
if len(builder.trade.supplementary_moves) > 5: if len(builder.trade.supplementary_moves) > 5:
supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more" supp_text += (
f"... and {len(builder.trade.supplementary_moves) - 5} more"
)
embed.add_field( embed.add_field(
name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})", name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})",
value=supp_text, value=supp_text,
inline=False inline=False,
) )
# Add quick validation summary
validation = await builder.validate_trade() validation = await builder.validate_trade()
if validation.is_legal: if validation.is_legal:
status_text = "Trade appears legal" status_text = "Trade appears legal"
else: else:
error_count = len(validation.all_errors) error_count = len(validation.all_errors)
status_text = f"{error_count} error{'s' if error_count != 1 else ''} found" status_text = f"{error_count} error{'s' if error_count != 1 else ''} found\n"
status_text += "\n".join(f"- {error}" for error in validation.all_errors)
if validation.all_suggestions:
status_text += "\n" + "\n".join(
f"- {s}" for s in validation.all_suggestions
)
embed.add_field(name="Quick Status", value=status_text, inline=False)
embed.add_field( embed.add_field(
name="🔍 Quick Status", name="Build Your Trade",
value=status_text, value="- `/trade add-player` - Add player exchanges\n- `/trade supplementary` - Add internal moves\n- `/trade add-team` - Add more teams",
inline=False inline=False,
) )
# Add instructions for adding more moves embed.set_footer(
embed.add_field( text=f"Trade ID: {builder.trade_id} | Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}"
name=" Build Your Trade",
value="• `/trade add-player` - Add player exchanges\n• `/trade supplementary` - Add internal moves\n• `/trade add-team` - Add more teams",
inline=False
) )
# Add footer with trade ID and timestamp return embed
embed.set_footer(text=f"Trade ID: {builder.trade_id} • Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}")
return embed