fix: prevent partial DB writes on scorecard submission failure #79

Closed
cal wants to merge 12 commits from fix/scorecard-submission-resilience into next-release
13 changed files with 1006 additions and 922 deletions

View File

@ -210,17 +210,41 @@ class SubmitScorecardCommands(commands.Cog):
game_id = scheduled_game.id
# Phase 6: Read Scorecard Data
# Phase 6: Read ALL Scorecard Data (before any DB writes)
# Reading everything first prevents partial commits if the
# spreadsheet has formula errors (e.g. #N/A in pitching decisions)
await interaction.edit_original_response(
content="📊 Reading play-by-play data..."
content="📊 Reading scorecard data..."
)
plays_data = await self.sheets_service.read_playtable_data(scorecard)
box_score = await self.sheets_service.read_box_score(scorecard)
decisions_data = await self.sheets_service.read_pitching_decisions(
scorecard
)
# Add game_id to each play
for play in plays_data:
play["game_id"] = game_id
# Add game metadata to each decision
for decision in decisions_data:
decision["game_id"] = game_id
decision["season"] = current.season
decision["week"] = setup_data["week"]
decision["game_num"] = setup_data["game_num"]
# Validate WP and LP exist and fetch Player objects
wp, lp, sv, holders, _blown_saves = (
await decision_service.find_winning_losing_pitchers(decisions_data)
)
if wp is None or lp is None:
await interaction.edit_original_response(
content="❌ Your card is missing either a Winning Pitcher or Losing Pitcher"
)
return
# Phase 7: POST Plays
await interaction.edit_original_response(
content="💾 Submitting plays to database..."
@ -244,10 +268,7 @@ class SubmitScorecardCommands(commands.Cog):
)
return
# Phase 8: Read Box Score
box_score = await self.sheets_service.read_box_score(scorecard)
# Phase 9: PATCH Game
# Phase 8: PATCH Game
await interaction.edit_original_response(
content="⚾ Updating game result..."
)
@ -275,33 +296,7 @@ class SubmitScorecardCommands(commands.Cog):
)
return
# Phase 10: Read Pitching Decisions
decisions_data = await self.sheets_service.read_pitching_decisions(
scorecard
)
# Add game metadata to each decision
for decision in decisions_data:
decision["game_id"] = game_id
decision["season"] = current.season
decision["week"] = setup_data["week"]
decision["game_num"] = setup_data["game_num"]
# Validate WP and LP exist and fetch Player objects
wp, lp, sv, holders, _blown_saves = (
await decision_service.find_winning_losing_pitchers(decisions_data)
)
if wp is None or lp is None:
# Rollback
await game_service.wipe_game_data(game_id)
await play_service.delete_plays_for_game(game_id)
await interaction.edit_original_response(
content="❌ Your card is missing either a Winning Pitcher or Losing Pitcher"
)
return
# Phase 11: POST Decisions
# Phase 9: POST Decisions
await interaction.edit_original_response(
content="🎯 Submitting pitching decisions..."
)
@ -361,6 +356,30 @@ class SubmitScorecardCommands(commands.Cog):
# Success!
await interaction.edit_original_response(content="✅ You are all set!")
except SheetsException as e:
# Spreadsheet reading error - show the detailed message to the user
self.logger.error(
f"Spreadsheet error in scorecard submission: {e}", error=e
)
if rollback_state and game_id:
try:
if rollback_state == "GAME_PATCHED":
await game_service.wipe_game_data(game_id)
await play_service.delete_plays_for_game(game_id)
elif rollback_state == "PLAYS_POSTED":
await play_service.delete_plays_for_game(game_id)
except Exception:
pass # Best effort rollback
await interaction.edit_original_response(
content=(
f"❌ There's a problem with your scorecard:\n\n"
f"{str(e)}\n\n"
f"Please fix the issue in your spreadsheet and resubmit."
)
)
except Exception as e:
# Unexpected error - attempt rollback
self.logger.error(f"Unexpected error in scorecard submission: {e}", error=e)

View File

@ -3,6 +3,7 @@ Base service class for Discord Bot v2.0
Provides common CRUD operations and error handling for all data services.
"""
import logging
import hashlib
from typing import Optional, Type, TypeVar, Generic, Dict, Any, List, Tuple
@ -12,15 +13,15 @@ from models.base import SBABaseModel
from exceptions import APIException
from utils.cache import CacheManager
logger = logging.getLogger(f'{__name__}.BaseService')
logger = logging.getLogger(f"{__name__}.BaseService")
T = TypeVar('T', bound=SBABaseModel)
T = TypeVar("T", bound=SBABaseModel)
class BaseService(Generic[T]):
"""
Base service class providing common CRUD operations for SBA models.
Features:
- Generic type support for any SBABaseModel subclass
- Automatic model validation and conversion
@ -28,15 +29,17 @@ class BaseService(Generic[T]):
- API response format handling (count + list format)
- Connection management via global client
"""
def __init__(self,
model_class: Type[T],
endpoint: str,
client: Optional[APIClient] = None,
cache_manager: Optional[CacheManager] = None):
def __init__(
self,
model_class: Type[T],
endpoint: str,
client: Optional[APIClient] = None,
cache_manager: Optional[CacheManager] = None,
):
"""
Initialize base service.
Args:
model_class: Pydantic model class for this service
endpoint: API endpoint path (e.g., 'players', 'teams')
@ -48,40 +51,44 @@ class BaseService(Generic[T]):
self._client = client
self._cached_client: Optional[APIClient] = None
self.cache = cache_manager or CacheManager()
logger.debug(f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'")
def _generate_cache_key(self, method: str, params: Optional[List[Tuple[str, Any]]] = None) -> str:
logger.debug(
f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'"
)
def _generate_cache_key(
self, method: str, params: Optional[List[Tuple[str, Any]]] = None
) -> str:
"""
Generate consistent cache key for API calls.
Args:
method: API method name
params: Query parameters as list of tuples
Returns:
SHA256-hashed cache key
"""
key_parts = [self.endpoint, method]
if params:
# Sort parameters for consistent key generation
sorted_params = sorted(params, key=lambda x: str(x[0]))
param_str = "&".join([f"{k}={v}" for k, v in sorted_params])
key_parts.append(param_str)
key_data = ":".join(key_parts)
key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16] # First 16 chars
return self.cache.cache_key("sba", f"{self.endpoint}_{key_hash}")
async def _get_cached_items(self, cache_key: str) -> Optional[List[T]]:
"""
Get cached list of model items.
Args:
cache_key: Cache key to lookup
Returns:
List of model instances or None if not cached
"""
@ -91,13 +98,15 @@ class BaseService(Generic[T]):
return [self.model_class.from_api_data(item) for item in cached_data]
except Exception as e:
logger.warning(f"Error deserializing cached data for {cache_key}: {e}")
return None
async def _cache_items(self, cache_key: str, items: List[T], ttl: Optional[int] = None) -> None:
async def _cache_items(
self, cache_key: str, items: List[T], ttl: Optional[int] = None
) -> None:
"""
Cache list of model items.
Args:
cache_key: Cache key to store under
items: List of model instances to cache
@ -105,40 +114,40 @@ class BaseService(Generic[T]):
"""
if not items:
return
try:
# Convert to JSON-serializable format
cache_data = [item.model_dump() for item in items]
await self.cache.set(cache_key, cache_data, ttl)
except Exception as e:
logger.warning(f"Error caching items for {cache_key}: {e}")
async def get_client(self) -> APIClient:
"""
Get API client instance with caching to reduce async overhead.
Returns:
APIClient instance (cached after first access)
"""
if self._client:
return self._client
# Cache the global client to avoid repeated async calls
if self._cached_client is None:
self._cached_client = await get_global_client()
return self._cached_client
async def get_by_id(self, object_id: int) -> Optional[T]:
"""
Get single object by ID.
Args:
object_id: Unique identifier for the object
Returns:
Model instance or None if not found
Raises:
APIException: For API errors
ValueError: For invalid data
@ -146,167 +155,181 @@ class BaseService(Generic[T]):
try:
client = await self.get_client()
data = await client.get(self.endpoint, object_id=object_id)
if not data:
logger.debug(f"{self.model_class.__name__} {object_id} not found")
return None
model = self.model_class.from_api_data(data)
logger.debug(f"Retrieved {self.model_class.__name__} {object_id}: {model}")
return model
except APIException:
logger.error(f"API error retrieving {self.model_class.__name__} {object_id}")
logger.error(
f"API error retrieving {self.model_class.__name__} {object_id}"
)
raise
except Exception as e:
logger.error(f"Error retrieving {self.model_class.__name__} {object_id}: {e}")
logger.error(
f"Error retrieving {self.model_class.__name__} {object_id}: {e}"
)
raise APIException(f"Failed to retrieve {self.model_class.__name__}: {e}")
async def get_all(self, params: Optional[List[tuple]] = None) -> Tuple[List[T], int]:
async def get_all(
self, params: Optional[List[tuple]] = None
) -> Tuple[List[T], int]:
"""
Get all objects with optional query parameters.
Args:
params: Query parameters as list of (key, value) tuples
Returns:
Tuple of (list of model instances, total count)
Raises:
APIException: For API errors
"""
try:
client = await self.get_client()
data = await client.get(self.endpoint, params=params)
if not data:
logger.debug(f"No {self.model_class.__name__} objects found")
return [], 0
# Handle API response format: {'count': int, '<endpoint>': [...]}
items, count = self._extract_items_and_count_from_response(data)
models = [self.model_class.from_api_data(item) for item in items]
logger.debug(f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects")
logger.debug(
f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects"
)
return models, count
except APIException:
logger.error(f"API error retrieving {self.model_class.__name__} list")
raise
except Exception as e:
logger.error(f"Error retrieving {self.model_class.__name__} list: {e}")
raise APIException(f"Failed to retrieve {self.model_class.__name__} list: {e}")
raise APIException(
f"Failed to retrieve {self.model_class.__name__} list: {e}"
)
async def get_all_items(self, params: Optional[List[tuple]] = None) -> List[T]:
"""
Get all objects (convenience method that only returns the list).
Args:
params: Query parameters as list of (key, value) tuples
Returns:
List of model instances
"""
items, _ = await self.get_all(params=params)
return items
async def create(self, model_data: Dict[str, Any]) -> Optional[T]:
"""
Create new object from data dictionary.
Args:
model_data: Dictionary of model fields
Returns:
Created model instance or None
Raises:
APIException: For API errors
ValueError: For invalid data
"""
try:
client = await self.get_client()
response = await client.post(self.endpoint, model_data)
response = await client.post(f"{self.endpoint}/", model_data)
if not response:
logger.warning(f"No response from {self.model_class.__name__} creation")
return None
model = self.model_class.from_api_data(response)
logger.debug(f"Created {self.model_class.__name__}: {model}")
return model
except APIException:
logger.error(f"API error creating {self.model_class.__name__}")
raise
except Exception as e:
logger.error(f"Error creating {self.model_class.__name__}: {e}")
raise APIException(f"Failed to create {self.model_class.__name__}: {e}")
async def create_from_model(self, model: T) -> Optional[T]:
"""
Create new object from model instance.
Args:
model: Model instance to create
Returns:
Created model instance or None
"""
return await self.create(model.to_dict(exclude_none=True))
async def update(self, object_id: int, model_data: Dict[str, Any]) -> Optional[T]:
"""
Update existing object.
Args:
object_id: ID of object to update
model_data: Dictionary of fields to update
Returns:
Updated model instance or None if not found
Raises:
APIException: For API errors
"""
try:
client = await self.get_client()
response = await client.put(self.endpoint, model_data, object_id=object_id)
if not response:
logger.debug(f"{self.model_class.__name__} {object_id} not found for update")
logger.debug(
f"{self.model_class.__name__} {object_id} not found for update"
)
return None
model = self.model_class.from_api_data(response)
logger.debug(f"Updated {self.model_class.__name__} {object_id}: {model}")
return model
except APIException:
logger.error(f"API error updating {self.model_class.__name__} {object_id}")
raise
except Exception as e:
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
async def update_from_model(self, model: T) -> Optional[T]:
"""
Update object from model instance.
Args:
model: Model instance to update (must have ID)
Returns:
Updated model instance or None
Raises:
ValueError: If model has no ID
"""
if not model.id:
raise ValueError(f"Cannot update {self.model_class.__name__} without ID")
return await self.update(model.id, model.to_dict(exclude_none=True))
async def patch(self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False) -> Optional[T]:
async def patch(
self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False
) -> Optional[T]:
"""
Update existing object with HTTP PATCH.
@ -323,10 +346,14 @@ class BaseService(Generic[T]):
"""
try:
client = await self.get_client()
response = await client.patch(self.endpoint, model_data, object_id, use_query_params=use_query_params)
response = await client.patch(
self.endpoint, model_data, object_id, use_query_params=use_query_params
)
if not response:
logger.debug(f"{self.model_class.__name__} {object_id} not found for update")
logger.debug(
f"{self.model_class.__name__} {object_id} not found for update"
)
return None
model = self.model_class.from_api_data(response)
@ -340,134 +367,142 @@ class BaseService(Generic[T]):
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
async def delete(self, object_id: int) -> bool:
"""
Delete object by ID.
Args:
object_id: ID of object to delete
Returns:
True if deleted, False if not found
Raises:
APIException: For API errors
"""
try:
client = await self.get_client()
success = await client.delete(self.endpoint, object_id=object_id)
if success:
logger.debug(f"Deleted {self.model_class.__name__} {object_id}")
else:
logger.debug(f"{self.model_class.__name__} {object_id} not found for deletion")
logger.debug(
f"{self.model_class.__name__} {object_id} not found for deletion"
)
return success
except APIException:
logger.error(f"API error deleting {self.model_class.__name__} {object_id}")
raise
except Exception as e:
logger.error(f"Error deleting {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to delete {self.model_class.__name__}: {e}")
async def get_by_field(self, field: str, value: Any) -> List[T]:
"""
Get objects by specific field value.
Args:
field: Field name to search
value: Field value to match
Returns:
List of matching model instances
"""
params = [(field, str(value))]
return await self.get_all_items(params=params)
async def count(self, params: Optional[List[tuple]] = None) -> int:
"""
Get count of objects matching parameters.
Args:
params: Query parameters
Returns:
Number of matching objects (from API count field)
"""
_, count = await self.get_all(params=params)
return count
def _extract_items_and_count_from_response(self, data: Any) -> Tuple[List[Dict[str, Any]], int]:
def _extract_items_and_count_from_response(
self, data: Any
) -> Tuple[List[Dict[str, Any]], int]:
"""
Extract items list and count from API response with optimized parsing.
Expected format: {'count': int, '<endpoint>': [...]}
Single object format: {'id': 1, 'name': '...'}
Args:
data: API response data
Returns:
Tuple of (items list, total count)
"""
if isinstance(data, list):
return data, len(data)
if not isinstance(data, dict):
logger.warning(f"Unexpected response format for {self.model_class.__name__}: {type(data)}")
logger.warning(
f"Unexpected response format for {self.model_class.__name__}: {type(data)}"
)
return [], 0
# Single pass through the response dict - get count first
count = data.get('count', 0)
count = data.get("count", 0)
# Priority order for finding items list (most common first)
field_candidates = [self.endpoint, 'items', 'data', 'results']
field_candidates = [self.endpoint, "items", "data", "results"]
for field_name in field_candidates:
if field_name in data and isinstance(data[field_name], list):
return data[field_name], count or len(data[field_name])
# Single object response (check for common identifying fields)
if any(key in data for key in ['id', 'name', 'abbrev']):
if any(key in data for key in ["id", "name", "abbrev"]):
return [data], 1
return [], count
async def get_items_with_params(self, params: Optional[List[tuple]] = None) -> List[T]:
async def get_items_with_params(
self, params: Optional[List[tuple]] = None
) -> List[T]:
"""
Get all items with parameters (alias for get_all_items for compatibility).
Args:
params: Query parameters as list of (key, value) tuples
Returns:
List of model instances
"""
return await self.get_all_items(params=params)
async def create_item(self, model_data: Dict[str, Any]) -> Optional[T]:
"""
Create item (alias for create for compatibility).
Args:
model_data: Dictionary of model fields
Returns:
Created model instance or None
"""
return await self.create(model_data)
async def update_item_by_field(self, field: str, value: Any, update_data: Dict[str, Any]) -> Optional[T]:
async def update_item_by_field(
self, field: str, value: Any, update_data: Dict[str, Any]
) -> Optional[T]:
"""
Update item by field value.
Args:
field: Field name to search by
value: Field value to match
update_data: Data to update
Returns:
Updated model instance or None if not found
"""
@ -475,22 +510,22 @@ class BaseService(Generic[T]):
items = await self.get_by_field(field, value)
if not items:
return None
# Update the first matching item
item = items[0]
if not item.id:
return None
return await self.update(item.id, update_data)
async def delete_item_by_field(self, field: str, value: Any) -> bool:
"""
Delete item by field value.
Args:
field: Field name to search by
value: Field value to match
Returns:
True if deleted, False if not found
"""
@ -498,62 +533,41 @@ class BaseService(Generic[T]):
items = await self.get_by_field(field, value)
if not items:
return False
# Delete the first matching item
item = items[0]
if not item.id:
return False
return await self.delete(item.id)
async def create_item_in_table(self, table_name: str, item_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Create item in a specific table (simplified for custom commands service).
This is a placeholder - real implementation would need table-specific endpoints.
Args:
table_name: Name of the table
item_data: Data to create
Returns:
Created item data or None
"""
# For now, use the main endpoint - this would need proper implementation
# for different tables like 'custom_command_creators'
try:
client = await self.get_client()
# Use table name as endpoint for now
response = await client.post(table_name, item_data)
return response
except Exception as e:
logger.error(f"Error creating item in table {table_name}: {e}")
return None
async def get_items_from_table_with_params(self, table_name: str, params: List[tuple]) -> List[Dict[str, Any]]:
async def get_items_from_table_with_params(
self, table_name: str, params: List[tuple]
) -> List[Dict[str, Any]]:
"""
Get items from a specific table with parameters.
Args:
table_name: Name of the table
params: Query parameters
Returns:
List of item dictionaries
"""
try:
client = await self.get_client()
data = await client.get(table_name, params=params)
if not data:
return []
# Handle response format
items, _ = self._extract_items_and_count_from_response(data)
return items
except Exception as e:
logger.error(f"Error getting items from table {table_name}: {e}")
return []
def __repr__(self) -> str:
return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')"
return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')"

View File

@ -552,9 +552,8 @@ class CustomCommandsService(BaseService[CustomCommand]):
"active_commands": 0,
}
result = await self.create_item_in_table(
"custom_commands/creators", creator_data
)
client = await self.get_client()
result = await client.post("custom_commands/creators", creator_data)
if not result:
raise BotException("Failed to create command creator")

View File

@ -3,6 +3,7 @@ Decision Service
Manages pitching decision operations for game submission.
"""
from typing import List, Dict, Any, Optional, Tuple
from utils.logging import get_contextual_logger
@ -16,17 +17,14 @@ class DecisionService:
def __init__(self):
"""Initialize decision service."""
self.logger = get_contextual_logger(f'{__name__}.DecisionService')
self.logger = get_contextual_logger(f"{__name__}.DecisionService")
self._get_client = get_global_client
async def get_client(self):
"""Get the API client."""
return await self._get_client()
async def create_decisions_batch(
self,
decisions: List[Dict[str, Any]]
) -> bool:
async def create_decisions_batch(self, decisions: List[Dict[str, Any]]) -> bool:
"""
POST batch of decisions to /decisions endpoint.
@ -42,8 +40,10 @@ class DecisionService:
try:
client = await self.get_client()
payload = {'decisions': decisions}
await client.post('decisions', payload)
payload = {"decisions": decisions}
# Trailing slash required: without it, the server returns a 307 redirect
# and aiohttp drops the POST body when following the redirect
await client.post("decisions/", payload)
self.logger.info(f"Created {len(decisions)} decisions")
return True
@ -70,7 +70,7 @@ class DecisionService:
"""
try:
client = await self.get_client()
await client.delete(f'decisions/game/{game_id}')
await client.delete(f"decisions/game/{game_id}")
self.logger.info(f"Deleted decisions for game {game_id}")
return True
@ -80,9 +80,10 @@ class DecisionService:
raise APIException(f"Failed to delete decisions: {e}")
async def find_winning_losing_pitchers(
self,
decisions_data: List[Dict[str, Any]]
) -> Tuple[Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]]:
self, decisions_data: List[Dict[str, Any]]
) -> Tuple[
Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]
]:
"""
Extract WP, LP, SV, Holds, Blown Saves from decisions list and fetch Player objects.
@ -110,17 +111,17 @@ class DecisionService:
# First pass: Extract IDs
for decision in decisions_data:
pitcher_id = int(decision.get('pitcher_id', 0))
pitcher_id = int(decision.get("pitcher_id", 0))
if int(decision.get('win', 0)) == 1:
if int(decision.get("win", 0)) == 1:
wp_id = pitcher_id
if int(decision.get('loss', 0)) == 1:
if int(decision.get("loss", 0)) == 1:
lp_id = pitcher_id
if int(decision.get('is_save', 0)) == 1:
if int(decision.get("is_save", 0)) == 1:
sv_id = pitcher_id
if int(decision.get('hold', 0)) == 1:
if int(decision.get("hold", 0)) == 1:
hold_ids.append(pitcher_id)
if int(decision.get('b_save', 0)) == 1:
if int(decision.get("b_save", 0)) == 1:
bsv_ids.append(pitcher_id)
# Second pass: Fetch Player objects
@ -154,9 +155,9 @@ class DecisionService:
"""
error_str = str(error)
if 'Player ID' in error_str and 'not found' in error_str:
if "Player ID" in error_str and "not found" in error_str:
return "Invalid pitcher ID in decision data."
elif 'Game ID' in error_str and 'not found' in error_str:
elif "Game ID" in error_str and "not found" in error_str:
return "Game not found for decisions."
else:
return f"Error submitting decisions: {error_str}"

View File

@ -3,13 +3,14 @@ Draft list service for Discord Bot v2.0
Handles team draft list (auto-draft queue) operations. NO CACHING - lists change frequently.
"""
import logging
from typing import Optional, List
from services.base_service import BaseService
from models.draft_list import DraftList
logger = logging.getLogger(f'{__name__}.DraftListService')
logger = logging.getLogger(f"{__name__}.DraftListService")
class DraftListService(BaseService[DraftList]):
@ -32,7 +33,7 @@ class DraftListService(BaseService[DraftList]):
def __init__(self):
"""Initialize draft list service."""
super().__init__(DraftList, 'draftlist')
super().__init__(DraftList, "draftlist")
logger.debug("DraftListService initialized")
def _extract_items_and_count_from_response(self, data):
@ -54,20 +55,16 @@ class DraftListService(BaseService[DraftList]):
return [], 0
# Get count
count = data.get('count', 0)
count = data.get("count", 0)
# API returns items under 'picks' key (not 'draftlist')
if 'picks' in data and isinstance(data['picks'], list):
return data['picks'], count or len(data['picks'])
if "picks" in data and isinstance(data["picks"], list):
return data["picks"], count or len(data["picks"])
# Fallback to standard extraction
return super()._extract_items_and_count_from_response(data)
async def get_team_list(
self,
season: int,
team_id: int
) -> List[DraftList]:
async def get_team_list(self, season: int, team_id: int) -> List[DraftList]:
"""
Get team's draft list ordered by rank.
@ -82,8 +79,8 @@ class DraftListService(BaseService[DraftList]):
"""
try:
params = [
('season', str(season)),
('team_id', str(team_id))
("season", str(season)),
("team_id", str(team_id)),
# NOTE: API does not support 'sort' param - results must be sorted client-side
]
@ -100,11 +97,7 @@ class DraftListService(BaseService[DraftList]):
return []
async def add_to_list(
self,
season: int,
team_id: int,
player_id: int,
rank: Optional[int] = None
self, season: int, team_id: int, player_id: int, rank: Optional[int] = None
) -> Optional[List[DraftList]]:
"""
Add player to team's draft list.
@ -133,10 +126,10 @@ class DraftListService(BaseService[DraftList]):
# Create new entry data
new_entry_data = {
'season': season,
'team_id': team_id,
'player_id': player_id,
'rank': rank
"season": season,
"team_id": team_id,
"player_id": player_id,
"rank": rank,
}
# Build complete list for bulk replacement
@ -146,36 +139,42 @@ class DraftListService(BaseService[DraftList]):
for entry in current_list:
if entry.rank >= rank:
# Shift down entries at or after insertion point
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': entry.rank + 1
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": entry.rank + 1,
}
)
else:
# Keep existing rank for entries before insertion point
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': entry.rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": entry.rank,
}
)
# Add new entry
draft_list_entries.append(new_entry_data)
# Sort by rank for consistency
draft_list_entries.sort(key=lambda x: x['rank'])
draft_list_entries.sort(key=lambda x: x["rank"])
# POST entire list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
logger.debug(f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries")
response = await client.post(self.endpoint, payload)
logger.debug(
f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries"
)
response = await client.post(f"{self.endpoint}/", payload)
logger.debug(f"POST response: {response}")
# Verify by fetching the list back (API returns full objects)
@ -184,20 +183,21 @@ class DraftListService(BaseService[DraftList]):
# Verify the player was added
if not any(entry.player_id == player_id for entry in verification):
logger.error(f"Player {player_id} not found in list after POST - operation may have failed")
logger.error(
f"Player {player_id} not found in list after POST - operation may have failed"
)
return None
logger.info(f"Added player {player_id} to team {team_id} draft list at rank {rank}")
logger.info(
f"Added player {player_id} to team {team_id} draft list at rank {rank}"
)
return verification # Return full updated list
except Exception as e:
logger.error(f"Error adding player {player_id} to draft list: {e}")
return None
async def remove_from_list(
self,
entry_id: int
) -> bool:
async def remove_from_list(self, entry_id: int) -> bool:
"""
Remove entry from draft list by ID.
@ -209,14 +209,13 @@ class DraftListService(BaseService[DraftList]):
Returns:
True if deletion succeeded
"""
logger.warning("remove_from_list() called with entry_id - use remove_player_from_list() instead")
logger.warning(
"remove_from_list() called with entry_id - use remove_player_from_list() instead"
)
return False
async def remove_player_from_list(
self,
season: int,
team_id: int,
player_id: int
self, season: int, team_id: int, player_id: int
) -> bool:
"""
Remove specific player from team's draft list.
@ -238,7 +237,9 @@ class DraftListService(BaseService[DraftList]):
# Check if player is in list
player_found = any(entry.player_id == player_id for entry in current_list)
if not player_found:
logger.warning(f"Player {player_id} not found in team {team_id} draft list")
logger.warning(
f"Player {player_id} not found in team {team_id} draft list"
)
return False
# Build new list without the player, adjusting ranks
@ -246,22 +247,24 @@ class DraftListService(BaseService[DraftList]):
new_rank = 1
for entry in current_list:
if entry.player_id != player_id:
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': new_rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": new_rank,
}
)
new_rank += 1
# POST updated list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
await client.post(self.endpoint, payload)
await client.post(f"{self.endpoint}/", payload)
logger.info(f"Removed player {player_id} from team {team_id} draft list")
return True
@ -270,11 +273,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error removing player {player_id} from draft list: {e}")
return False
async def clear_list(
self,
season: int,
team_id: int
) -> bool:
async def clear_list(self, season: int, team_id: int) -> bool:
"""
Clear entire draft list for team.
@ -309,10 +308,7 @@ class DraftListService(BaseService[DraftList]):
return False
async def reorder_list(
self,
season: int,
team_id: int,
new_order: List[int]
self, season: int, team_id: int, new_order: List[int]
) -> bool:
"""
Reorder team's draft list.
@ -342,21 +338,23 @@ class DraftListService(BaseService[DraftList]):
continue
entry = entry_map[player_id]
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': new_rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": new_rank,
}
)
# POST reordered list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
await client.post(self.endpoint, payload)
await client.post(f"{self.endpoint}/", payload)
logger.info(f"Reordered draft list for team {team_id}")
return True
@ -365,12 +363,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error reordering draft list for team {team_id}: {e}")
return False
async def move_entry_up(
self,
season: int,
team_id: int,
player_id: int
) -> bool:
async def move_entry_up(self, season: int, team_id: int, player_id: int) -> bool:
"""
Move player up one position in draft list (higher priority).
@ -403,7 +396,9 @@ class DraftListService(BaseService[DraftList]):
return False
# Find entry above (rank - 1)
above_entry = next((e for e in entries if e.rank == current_entry.rank - 1), None)
above_entry = next(
(e for e in entries if e.rank == current_entry.rank - 1), None
)
if not above_entry:
logger.error(f"Could not find entry above rank {current_entry.rank}")
return False
@ -421,24 +416,26 @@ class DraftListService(BaseService[DraftList]):
# Keep existing rank
new_rank = entry.rank
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': new_rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": new_rank,
}
)
# Sort by rank
draft_list_entries.sort(key=lambda x: x['rank'])
draft_list_entries.sort(key=lambda x: x["rank"])
# POST updated list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
await client.post(self.endpoint, payload)
await client.post(f"{self.endpoint}/", payload)
logger.info(f"Moved player {player_id} up to rank {current_entry.rank - 1}")
return True
@ -447,12 +444,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error moving player {player_id} up in draft list: {e}")
return False
async def move_entry_down(
self,
season: int,
team_id: int,
player_id: int
) -> bool:
async def move_entry_down(self, season: int, team_id: int, player_id: int) -> bool:
"""
Move player down one position in draft list (lower priority).
@ -485,7 +477,9 @@ class DraftListService(BaseService[DraftList]):
return False
# Find entry below (rank + 1)
below_entry = next((e for e in entries if e.rank == current_entry.rank + 1), None)
below_entry = next(
(e for e in entries if e.rank == current_entry.rank + 1), None
)
if not below_entry:
logger.error(f"Could not find entry below rank {current_entry.rank}")
return False
@ -503,25 +497,29 @@ class DraftListService(BaseService[DraftList]):
# Keep existing rank
new_rank = entry.rank
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': new_rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": new_rank,
}
)
# Sort by rank
draft_list_entries.sort(key=lambda x: x['rank'])
draft_list_entries.sort(key=lambda x: x["rank"])
# POST updated list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
await client.post(self.endpoint, payload)
logger.info(f"Moved player {player_id} down to rank {current_entry.rank + 1}")
await client.post(f"{self.endpoint}/", payload)
logger.info(
f"Moved player {player_id} down to rank {current_entry.rank + 1}"
)
return True

View File

@ -3,13 +3,14 @@ Injury service for Discord Bot v2.0
Handles injury-related operations including checking, creating, and clearing injuries.
"""
import logging
from typing import Optional, List
from services.base_service import BaseService
from models.injury import Injury
logger = logging.getLogger(f'{__name__}.InjuryService')
logger = logging.getLogger(f"{__name__}.InjuryService")
class InjuryService(BaseService[Injury]):
@ -25,7 +26,7 @@ class InjuryService(BaseService[Injury]):
def __init__(self):
"""Initialize injury service."""
super().__init__(Injury, 'injuries')
super().__init__(Injury, "injuries")
logger.debug("InjuryService initialized")
async def get_active_injury(self, player_id: int, season: int) -> Optional[Injury]:
@ -41,25 +42,31 @@ class InjuryService(BaseService[Injury]):
"""
try:
params = [
('player_id', str(player_id)),
('season', str(season)),
('is_active', 'true')
("player_id", str(player_id)),
("season", str(season)),
("is_active", "true"),
]
injuries = await self.get_all_items(params=params)
if injuries:
logger.debug(f"Found active injury for player {player_id} in season {season}")
logger.debug(
f"Found active injury for player {player_id} in season {season}"
)
return injuries[0]
logger.debug(f"No active injury found for player {player_id} in season {season}")
logger.debug(
f"No active injury found for player {player_id} in season {season}"
)
return None
except Exception as e:
logger.error(f"Error getting active injury for player {player_id}: {e}")
return None
async def get_injuries_by_player(self, player_id: int, season: int, active_only: bool = False) -> List[Injury]:
async def get_injuries_by_player(
self, player_id: int, season: int, active_only: bool = False
) -> List[Injury]:
"""
Get all injuries for a player in a specific season.
@ -72,13 +79,10 @@ class InjuryService(BaseService[Injury]):
List of injuries for the player
"""
try:
params = [
('player_id', str(player_id)),
('season', str(season))
]
params = [("player_id", str(player_id)), ("season", str(season))]
if active_only:
params.append(('is_active', 'true'))
params.append(("is_active", "true"))
injuries = await self.get_all_items(params=params)
logger.debug(f"Retrieved {len(injuries)} injuries for player {player_id}")
@ -88,7 +92,9 @@ class InjuryService(BaseService[Injury]):
logger.error(f"Error getting injuries for player {player_id}: {e}")
return []
async def get_injuries_by_team(self, team_id: int, season: int, active_only: bool = True) -> List[Injury]:
async def get_injuries_by_team(
self, team_id: int, season: int, active_only: bool = True
) -> List[Injury]:
"""
Get all injuries for a team in a specific season.
@ -101,13 +107,10 @@ class InjuryService(BaseService[Injury]):
List of injuries for the team
"""
try:
params = [
('team_id', str(team_id)),
('season', str(season))
]
params = [("team_id", str(team_id)), ("season", str(season))]
if active_only:
params.append(('is_active', 'true'))
params.append(("is_active", "true"))
injuries = await self.get_all_items(params=params)
logger.debug(f"Retrieved {len(injuries)} injuries for team {team_id}")
@ -125,7 +128,7 @@ class InjuryService(BaseService[Injury]):
start_week: int,
start_game: int,
end_week: int,
end_game: int
end_game: int,
) -> Optional[Injury]:
"""
Create a new injury record.
@ -144,22 +147,24 @@ class InjuryService(BaseService[Injury]):
"""
try:
injury_data = {
'season': season,
'player_id': player_id,
'total_games': total_games,
'start_week': start_week,
'start_game': start_game,
'end_week': end_week,
'end_game': end_game,
'is_active': True
"season": season,
"player_id": player_id,
"total_games": total_games,
"start_week": start_week,
"start_game": start_game,
"end_week": end_week,
"end_game": end_game,
"is_active": True,
}
# Call the API to create the injury
client = await self.get_client()
response = await client.post(self.endpoint, injury_data)
response = await client.post(f"{self.endpoint}/", injury_data)
if not response:
logger.error(f"Failed to create injury for player {player_id}: No response from API")
logger.error(
f"Failed to create injury for player {player_id}: No response from API"
)
return None
# Merge the request data with the response to ensure all required fields are present
@ -187,7 +192,9 @@ class InjuryService(BaseService[Injury]):
"""
try:
# Note: API expects is_active as query parameter, not JSON body
updated_injury = await self.patch(injury_id, {'is_active': False}, use_query_params=True)
updated_injury = await self.patch(
injury_id, {"is_active": False}, use_query_params=True
)
if updated_injury:
logger.info(f"Cleared injury {injury_id}")
@ -216,16 +223,18 @@ class InjuryService(BaseService[Injury]):
try:
client = await self.get_client()
params = [
('season', str(season)),
('is_active', 'true'),
('sort', 'return-asc')
("season", str(season)),
("is_active", "true"),
("sort", "return-asc"),
]
response = await client.get(self.endpoint, params=params)
if response and 'injuries' in response:
logger.debug(f"Retrieved {len(response['injuries'])} active injuries for season {season}")
return response['injuries']
if response and "injuries" in response:
logger.debug(
f"Retrieved {len(response['injuries'])} active injuries for season {season}"
)
return response["injuries"]
logger.debug(f"No active injuries found for season {season}")
return []

View File

@ -3,6 +3,7 @@ Play Service
Manages play-by-play data operations for game submission.
"""
from typing import List, Dict, Any
from utils.logging import get_contextual_logger
@ -16,7 +17,7 @@ class PlayService:
def __init__(self):
"""Initialize play service."""
self.logger = get_contextual_logger(f'{__name__}.PlayService')
self.logger = get_contextual_logger(f"{__name__}.PlayService")
self._get_client = get_global_client
async def get_client(self):
@ -39,8 +40,10 @@ class PlayService:
try:
client = await self.get_client()
payload = {'plays': plays}
response = await client.post('plays', payload)
payload = {"plays": plays}
# Trailing slash required: without it, the server returns a 307 redirect
# and aiohttp drops the POST body when following the redirect
response = await client.post("plays/", payload)
self.logger.info(f"Created {len(plays)} plays")
return True
@ -68,7 +71,7 @@ class PlayService:
"""
try:
client = await self.get_client()
response = await client.delete(f'plays/game/{game_id}')
response = await client.delete(f"plays/game/{game_id}")
self.logger.info(f"Deleted plays for game {game_id}")
return True
@ -77,11 +80,7 @@ class PlayService:
self.logger.error(f"Failed to delete plays for game {game_id}: {e}")
raise APIException(f"Failed to delete plays: {e}")
async def get_top_plays_by_wpa(
self,
game_id: int,
limit: int = 3
) -> List[Play]:
async def get_top_plays_by_wpa(self, game_id: int, limit: int = 3) -> List[Play]:
"""
Get top plays by WPA (absolute value) for key plays display.
@ -95,19 +94,15 @@ class PlayService:
try:
client = await self.get_client()
params = [
('game_id', game_id),
('sort', 'wpa-desc'),
('limit', limit)
]
params = [("game_id", game_id), ("sort", "wpa-desc"), ("limit", limit)]
response = await client.get('plays', params=params)
response = await client.get("plays", params=params)
if not response or 'plays' not in response:
self.logger.info(f'No plays found for game ID {game_id}')
if not response or "plays" not in response:
self.logger.info(f"No plays found for game ID {game_id}")
return []
plays = [Play.from_api_data(p) for p in response['plays']]
plays = [Play.from_api_data(p) for p in response["plays"]]
self.logger.debug(f"Retrieved {len(plays)} top plays for game {game_id}")
return plays
@ -129,11 +124,11 @@ class PlayService:
error_str = str(error)
# Common error patterns
if 'Player ID' in error_str and 'not found' in error_str:
if "Player ID" in error_str and "not found" in error_str:
return "Invalid player ID in scorecard data. Please check player IDs."
elif 'Game ID' in error_str and 'not found' in error_str:
elif "Game ID" in error_str and "not found" in error_str:
return "Game not found in database. Please contact an admin."
elif 'validation' in error_str.lower():
elif "validation" in error_str.lower():
return f"Data validation error: {error_str}"
else:
return f"Error submitting plays: {error_str}"

View File

@ -416,6 +416,8 @@ class SheetsService:
self.logger.info(f"Read {len(pit_data)} valid pitching decisions")
return pit_data
except SheetsException:
raise
except Exception as e:
self.logger.error(f"Failed to read pitching decisions: {e}")
raise SheetsException("Unable to read pitching decisions") from e
@ -458,6 +460,8 @@ class SheetsService:
"home": [int(x) for x in score_table[1]], # [R, H, E]
}
except SheetsException:
raise
except Exception as e:
self.logger.error(f"Failed to read box score: {e}")
raise SheetsException("Unable to read box score") from e

View File

@ -3,6 +3,7 @@ Trade Builder Service
Extends the TransactionBuilder to support multi-team trades and player exchanges.
"""
import logging
from typing import Dict, List, Optional, Set
from datetime import datetime, timezone
@ -12,10 +13,14 @@ from config import get_config
from models.trade import Trade, TradeMove, TradeStatus
from models.team import Team, RosterType
from models.player import Player
from services.transaction_builder import TransactionBuilder, RosterValidationResult, TransactionMove
from services.transaction_builder import (
TransactionBuilder,
RosterValidationResult,
TransactionMove,
)
from services.team_service import team_service
logger = logging.getLogger(f'{__name__}.TradeBuilder')
logger = logging.getLogger(f"{__name__}.TradeBuilder")
class TradeValidationResult:
@ -52,7 +57,9 @@ class TradeValidationResult:
suggestions.extend(validation.suggestions)
return suggestions
def get_participant_validation(self, team_id: int) -> Optional[RosterValidationResult]:
def get_participant_validation(
self, team_id: int
) -> Optional[RosterValidationResult]:
"""Get validation result for a specific team."""
return self.participant_validations.get(team_id)
@ -64,7 +71,12 @@ class TradeBuilder:
Extends the functionality of TransactionBuilder to support trades between teams.
"""
def __init__(self, initiated_by: int, initiating_team: Team, season: int = get_config().sba_season):
def __init__(
self,
initiated_by: int,
initiating_team: Team,
season: int = get_config().sba_season,
):
"""
Initialize trade builder.
@ -79,7 +91,7 @@ class TradeBuilder:
status=TradeStatus.DRAFT,
initiated_by=initiated_by,
created_at=datetime.now(timezone.utc).isoformat(),
season=season
season=season,
)
# Add the initiating team as first participant
@ -91,7 +103,9 @@ class TradeBuilder:
# Track which teams have accepted the trade (team_id -> True)
self.accepted_teams: Set[int] = set()
logger.info(f"TradeBuilder initialized: {self.trade.trade_id} by user {initiated_by} for {initiating_team.abbrev}")
logger.info(
f"TradeBuilder initialized: {self.trade.trade_id} by user {initiated_by} for {initiating_team.abbrev}"
)
@property
def trade_id(self) -> str:
@ -127,7 +141,11 @@ class TradeBuilder:
@property
def pending_teams(self) -> List[Team]:
"""Get list of teams that haven't accepted yet."""
return [team for team in self.participating_teams if team.id not in self.accepted_teams]
return [
team
for team in self.participating_teams
if team.id not in self.accepted_teams
]
def accept_trade(self, team_id: int) -> bool:
"""
@ -140,7 +158,9 @@ class TradeBuilder:
True if all teams have now accepted, False otherwise
"""
self.accepted_teams.add(team_id)
logger.info(f"Team {team_id} accepted trade {self.trade_id}. Accepted: {len(self.accepted_teams)}/{self.team_count}")
logger.info(
f"Team {team_id} accepted trade {self.trade_id}. Accepted: {len(self.accepted_teams)}/{self.team_count}"
)
return self.all_teams_accepted
def reject_trade(self) -> None:
@ -160,7 +180,9 @@ class TradeBuilder:
Returns:
Dict mapping team_id to acceptance status (True/False)
"""
return {team.id: team.id in self.accepted_teams for team in self.participating_teams}
return {
team.id: team.id in self.accepted_teams for team in self.participating_teams
}
def has_team_accepted(self, team_id: int) -> bool:
"""Check if a specific team has accepted."""
@ -184,7 +206,9 @@ class TradeBuilder:
participant = self.trade.add_participant(team)
# Create transaction builder for this team
self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season)
self._team_builders[team.id] = TransactionBuilder(
team, self.trade.initiated_by, self.trade.season
)
# Register team in secondary index for multi-GM access
trade_key = f"{self.trade.initiated_by}:trade"
@ -209,7 +233,10 @@ class TradeBuilder:
# Check if team has moves - prevent removal if they do
if participant.all_moves:
return False, f"{participant.team.abbrev} has moves in this trade and cannot be removed"
return (
False,
f"{participant.team.abbrev} has moves in this trade and cannot be removed",
)
# Remove team
removed = self.trade.remove_participant(team_id)
@ -229,7 +256,7 @@ class TradeBuilder:
from_team: Team,
to_team: Team,
from_roster: RosterType,
to_roster: RosterType
to_roster: RosterType,
) -> tuple[bool, str]:
"""
Add a player move to the trade.
@ -246,7 +273,10 @@ class TradeBuilder:
"""
# Validate player is not from Free Agency
if player.team_id == get_config().free_agent_team_id:
return False, f"Cannot add {player.name} from Free Agency. Players must be traded from teams within the organizations involved in the trade."
return (
False,
f"Cannot add {player.name} from Free Agency. Players must be traded from teams within the organizations involved in the trade.",
)
# Validate player has a valid team assignment
if not player.team_id:
@ -259,7 +289,10 @@ class TradeBuilder:
# Check if player's team is in the same organization as from_team
if not player_team.is_same_organization(from_team):
return False, f"{player.name} is on {player_team.abbrev}, they are not eligible to be added to the trade."
return (
False,
f"{player.name} is on {player_team.abbrev}, they are not eligible to be added to the trade.",
)
# Ensure both teams are participating (check by organization for ML authority)
from_participant = self.trade.get_participant_by_organization(from_team)
@ -274,7 +307,10 @@ class TradeBuilder:
for participant in self.trade.participants:
for existing_move in participant.all_moves:
if existing_move.player.id == player.id:
return False, f"{player.name} is already involved in a move in this trade"
return (
False,
f"{player.name} is already involved in a move in this trade",
)
# Create trade move
trade_move = TradeMove(
@ -284,7 +320,7 @@ class TradeBuilder:
from_team=from_team,
to_team=to_team,
source_team=from_team,
destination_team=to_team
destination_team=to_team,
)
# Add to giving team's moves
@ -303,7 +339,7 @@ class TradeBuilder:
from_roster=from_roster,
to_roster=RosterType.FREE_AGENCY, # Conceptually leaving the org
from_team=from_team,
to_team=None
to_team=None,
)
# Move for receiving team (player joining)
@ -312,19 +348,23 @@ class TradeBuilder:
from_roster=RosterType.FREE_AGENCY, # Conceptually joining from outside
to_roster=to_roster,
from_team=None,
to_team=to_team
to_team=to_team,
)
# Add moves to respective builders
# Skip pending transaction check for trades - they have their own validation workflow
from_success, from_error = await from_builder.add_move(from_move, check_pending_transactions=False)
from_success, from_error = await from_builder.add_move(
from_move, check_pending_transactions=False
)
if not from_success:
# Remove from trade if builder failed
from_participant.moves_giving.remove(trade_move)
to_participant.moves_receiving.remove(trade_move)
return False, f"Error adding move to {from_team.abbrev}: {from_error}"
to_success, to_error = await to_builder.add_move(to_move, check_pending_transactions=False)
to_success, to_error = await to_builder.add_move(
to_move, check_pending_transactions=False
)
if not to_success:
# Rollback both if second failed
from_builder.remove_move(player.id)
@ -332,15 +372,13 @@ class TradeBuilder:
to_participant.moves_receiving.remove(trade_move)
return False, f"Error adding move to {to_team.abbrev}: {to_error}"
logger.info(f"Added player move to trade {self.trade_id}: {trade_move.description}")
logger.info(
f"Added player move to trade {self.trade_id}: {trade_move.description}"
)
return True, ""
async def add_supplementary_move(
self,
team: Team,
player: Player,
from_roster: RosterType,
to_roster: RosterType
self, team: Team, player: Player, from_roster: RosterType, to_roster: RosterType
) -> tuple[bool, str]:
"""
Add a supplementary move (internal organizational move) for roster legality.
@ -366,7 +404,7 @@ class TradeBuilder:
from_team=team,
to_team=team,
source_team=team,
destination_team=team
destination_team=team,
)
# Add to participant's supplementary moves
@ -379,16 +417,20 @@ class TradeBuilder:
from_roster=from_roster,
to_roster=to_roster,
from_team=team,
to_team=team
to_team=team,
)
# Skip pending transaction check for trade supplementary moves
success, error = await builder.add_move(trans_move, check_pending_transactions=False)
success, error = await builder.add_move(
trans_move, check_pending_transactions=False
)
if not success:
participant.supplementary_moves.remove(supp_move)
return False, error
logger.info(f"Added supplementary move for {team.abbrev}: {supp_move.description}")
logger.info(
f"Added supplementary move for {team.abbrev}: {supp_move.description}"
)
return True, ""
async def remove_move(self, player_id: int) -> tuple[bool, str]:
@ -432,21 +474,41 @@ class TradeBuilder:
for builder in self._team_builders.values():
builder.remove_move(player_id)
logger.info(f"Removed move from trade {self.trade_id}: {removed_move.description}")
logger.info(
f"Removed move from trade {self.trade_id}: {removed_move.description}"
)
return True, ""
async def validate_trade(self, next_week: Optional[int] = None) -> TradeValidationResult:
async def validate_trade(
self, next_week: Optional[int] = None
) -> TradeValidationResult:
"""
Validate the entire trade including all teams' roster legality.
Validates against next week's projected roster (current roster + pending
transactions), matching the behavior of /dropadd validation.
Args:
next_week: Week to validate for (optional)
next_week: Week to validate for (auto-fetched if not provided)
Returns:
TradeValidationResult with comprehensive validation
"""
result = TradeValidationResult()
# Auto-fetch next week so validation includes pending transactions
if next_week is None:
try:
from services.league_service import league_service
current_state = await league_service.get_current_state()
next_week = (current_state.week + 1) if current_state else 1
except Exception as e:
logger.warning(
f"Could not determine next week for trade validation: {e}"
)
next_week = None
# Validate trade structure
is_balanced, balance_errors = self.trade.validate_trade_balance()
if not is_balanced:
@ -472,13 +534,17 @@ class TradeBuilder:
if self.team_count < 2:
result.trade_suggestions.append("Add another team to create a trade")
logger.debug(f"Trade validation for {self.trade_id}: Legal={result.is_legal}, Errors={len(result.all_errors)}")
logger.debug(
f"Trade validation for {self.trade_id}: Legal={result.is_legal}, Errors={len(result.all_errors)}"
)
return result
def _get_or_create_builder(self, team: Team) -> TransactionBuilder:
"""Get or create a transaction builder for a team."""
if team.id not in self._team_builders:
self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season)
self._team_builders[team.id] = TransactionBuilder(
team, self.trade.initiated_by, self.trade.season
)
return self._team_builders[team.id]
def clear_trade(self) -> None:
@ -592,4 +658,4 @@ def clear_trade_builder_by_team(team_id: int) -> bool:
def get_active_trades() -> Dict[str, TradeBuilder]:
"""Get all active trade builders (for debugging/admin purposes)."""
return _active_trade_builders.copy()
return _active_trade_builders.copy()

View File

@ -248,7 +248,7 @@ class TransactionService(BaseService[Transaction]):
# POST batch to API
client = await self.get_client()
response = await client.post(self.endpoint, data=batch_data)
response = await client.post(f"{self.endpoint}/", data=batch_data)
# API returns a string like "2 transactions have been added"
# We need to return the original Transaction objects (they won't have IDs assigned by API)

View File

@ -1,18 +1,24 @@
"""
API client tests using aioresponses for clean HTTP mocking
"""
import pytest
import asyncio
from unittest.mock import MagicMock, patch
from aioresponses import aioresponses
from api.client import APIClient, get_api_client, get_global_client, cleanup_global_client
from api.client import (
APIClient,
get_api_client,
get_global_client,
cleanup_global_client,
)
from exceptions import APIException
class TestAPIClientWithAioresponses:
"""Test API client with aioresponses for HTTP mocking."""
@pytest.fixture
def mock_config(self):
"""Mock configuration for testing."""
@ -20,66 +26,57 @@ class TestAPIClientWithAioresponses:
config.db_url = "https://api.example.com"
config.api_token = "test-token"
return config
@pytest.fixture
def api_client(self, mock_config):
"""Create API client with mocked config."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
return APIClient()
@pytest.mark.asyncio
async def test_get_request_success(self, api_client):
"""Test successful GET request."""
expected_data = {"id": 1, "name": "Test Player"}
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players/1",
payload=expected_data,
status=200
status=200,
)
result = await api_client.get("players", object_id=1)
assert result == expected_data
@pytest.mark.asyncio
async def test_get_request_404(self, api_client):
"""Test GET request returning 404."""
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players/999",
status=404
)
m.get("https://api.example.com/v3/players/999", status=404)
result = await api_client.get("players", object_id=999)
assert result is None
@pytest.mark.asyncio
async def test_get_request_401_auth_error(self, api_client):
"""Test GET request with authentication error."""
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
status=401
)
m.get("https://api.example.com/v3/players", status=401)
with pytest.raises(APIException, match="Authentication failed"):
await api_client.get("players")
@pytest.mark.asyncio
async def test_get_request_403_forbidden(self, api_client):
"""Test GET request with forbidden error."""
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
status=403
)
m.get("https://api.example.com/v3/players", status=403)
with pytest.raises(APIException, match="Access forbidden"):
await api_client.get("players")
@pytest.mark.asyncio
async def test_get_request_500_server_error(self, api_client):
"""Test GET request with server error."""
@ -87,135 +84,127 @@ class TestAPIClientWithAioresponses:
m.get(
"https://api.example.com/v3/players",
status=500,
body="Internal Server Error"
body="Internal Server Error",
)
with pytest.raises(APIException, match="API request failed with status 500"):
with pytest.raises(
APIException, match="API request failed with status 500"
):
await api_client.get("players")
@pytest.mark.asyncio
async def test_get_request_with_params(self, api_client):
"""Test GET request with query parameters."""
expected_data = {"count": 2, "players": [{"id": 1}, {"id": 2}]}
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players?team_id=5&season=12",
payload=expected_data,
status=200
status=200,
)
result = await api_client.get("players", params=[("team_id", "5"), ("season", "12")])
result = await api_client.get(
"players", params=[("team_id", "5"), ("season", "12")]
)
assert result == expected_data
@pytest.mark.asyncio
async def test_post_request_success(self, api_client):
"""Test successful POST request."""
input_data = {"name": "New Player", "position": "C"}
expected_response = {"id": 1, "name": "New Player", "position": "C"}
with aioresponses() as m:
m.post(
"https://api.example.com/v3/players",
payload=expected_response,
status=201
status=201,
)
result = await api_client.post("players", input_data)
assert result == expected_response
@pytest.mark.asyncio
async def test_post_request_400_error(self, api_client):
"""Test POST request with validation error."""
input_data = {"invalid": "data"}
with aioresponses() as m:
m.post(
"https://api.example.com/v3/players",
status=400,
body="Invalid data"
"https://api.example.com/v3/players", status=400, body="Invalid data"
)
with pytest.raises(APIException, match="POST request failed with status 400"):
with pytest.raises(
APIException, match="POST request failed with status 400"
):
await api_client.post("players", input_data)
@pytest.mark.asyncio
async def test_put_request_success(self, api_client):
"""Test successful PUT request."""
update_data = {"name": "Updated Player"}
expected_response = {"id": 1, "name": "Updated Player"}
with aioresponses() as m:
m.put(
"https://api.example.com/v3/players/1",
payload=expected_response,
status=200
status=200,
)
result = await api_client.put("players", update_data, object_id=1)
assert result == expected_response
@pytest.mark.asyncio
async def test_put_request_404(self, api_client):
"""Test PUT request with 404."""
update_data = {"name": "Updated Player"}
with aioresponses() as m:
m.put(
"https://api.example.com/v3/players/999",
status=404
)
m.put("https://api.example.com/v3/players/999", status=404)
result = await api_client.put("players", update_data, object_id=999)
assert result is None
@pytest.mark.asyncio
async def test_delete_request_success(self, api_client):
"""Test successful DELETE request."""
with aioresponses() as m:
m.delete(
"https://api.example.com/v3/players/1",
status=204
)
m.delete("https://api.example.com/v3/players/1", status=204)
result = await api_client.delete("players", object_id=1)
assert result is True
@pytest.mark.asyncio
async def test_delete_request_404(self, api_client):
"""Test DELETE request with 404."""
with aioresponses() as m:
m.delete(
"https://api.example.com/v3/players/999",
status=404
)
m.delete("https://api.example.com/v3/players/999", status=404)
result = await api_client.delete("players", object_id=999)
assert result is False
@pytest.mark.asyncio
async def test_delete_request_200_success(self, api_client):
"""Test DELETE request with 200 success."""
with aioresponses() as m:
m.delete(
"https://api.example.com/v3/players/1",
status=200
)
m.delete("https://api.example.com/v3/players/1", status=200)
result = await api_client.delete("players", object_id=1)
assert result is True
class TestAPIClientHelpers:
"""Test API client helper functions."""
@pytest.fixture
def mock_config(self):
"""Mock configuration for testing."""
@ -223,49 +212,49 @@ class TestAPIClientHelpers:
config.db_url = "https://api.example.com"
config.api_token = "test-token"
return config
@pytest.mark.asyncio
async def test_get_api_client_context_manager(self, mock_config):
"""Test get_api_client context manager."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
m.get(
"https://api.example.com/v3/test",
payload={"success": True},
status=200
status=200,
)
async with get_api_client() as client:
assert isinstance(client, APIClient)
result = await client.get("test")
assert result == {"success": True}
@pytest.mark.asyncio
async def test_global_client_management(self, mock_config):
"""Test global client getter and cleanup."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
# Get global client
client1 = await get_global_client()
client2 = await get_global_client()
# Should return same instance
assert client1 is client2
assert isinstance(client1, APIClient)
# Test cleanup
await cleanup_global_client()
# New client should be different instance
client3 = await get_global_client()
assert client3 is not client1
# Clean up for other tests
await cleanup_global_client()
class TestIntegrationScenarios:
"""Test realistic integration scenarios."""
@pytest.fixture
def mock_config(self):
"""Mock configuration for testing."""
@ -273,11 +262,11 @@ class TestIntegrationScenarios:
config.db_url = "https://api.example.com"
config.api_token = "test-token"
return config
@pytest.mark.asyncio
async def test_player_retrieval_with_team_lookup(self, mock_config):
"""Test realistic scenario: get player with team data."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
# Mock player data response
player_data = {
@ -287,43 +276,41 @@ class TestIntegrationScenarios:
"season": 12,
"team_id": 5,
"image": "https://example.com/player1.jpg",
"pos_1": "C"
"pos_1": "C",
}
m.get(
"https://api.example.com/v3/players/1",
payload=player_data,
status=200
status=200,
)
# Mock team data response
team_data = {
"id": 5,
"abbrev": "TST",
"sname": "Test Team",
"lname": "Test Team Full Name",
"season": 12
"season": 12,
}
m.get(
"https://api.example.com/v3/teams/5",
payload=team_data,
status=200
"https://api.example.com/v3/teams/5", payload=team_data, status=200
)
client = APIClient()
# Get player
player = await client.get("players", object_id=1)
assert player["name"] == "Test Player"
assert player["team_id"] == 5
# Get team for player
team = await client.get("teams", object_id=player["team_id"])
assert team["sname"] == "Test Team"
@pytest.mark.asyncio
async def test_api_response_format_handling(self, mock_config):
"""Test handling of the API's count + list format."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
# Mock API response with count format
api_response = {
@ -336,7 +323,7 @@ class TestIntegrationScenarios:
"season": 12,
"team_id": 5,
"image": "https://example.com/player1.jpg",
"pos_1": "C"
"pos_1": "C",
},
{
"id": 2,
@ -345,93 +332,93 @@ class TestIntegrationScenarios:
"season": 12,
"team_id": 6,
"image": "https://example.com/player2.jpg",
"pos_1": "1B"
}
]
"pos_1": "1B",
},
],
}
m.get(
"https://api.example.com/v3/players?team_id=5",
payload=api_response,
status=200
status=200,
)
client = APIClient()
result = await client.get("players", params=[("team_id", "5")])
assert result["count"] == 25
assert len(result["players"]) == 2
assert result["players"][0]["name"] == "Player 1"
@pytest.mark.asyncio
async def test_error_recovery_scenarios(self, mock_config):
"""Test error handling and recovery."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
# First request fails with 500
m.get(
"https://api.example.com/v3/players/1",
status=500,
body="Internal Server Error"
body="Internal Server Error",
)
# Second request succeeds
m.get(
"https://api.example.com/v3/players/2",
payload={"id": 2, "name": "Working Player"},
status=200
status=200,
)
client = APIClient()
# First request should raise exception
with pytest.raises(APIException, match="API request failed"):
await client.get("players", object_id=1)
# Second request should work fine
result = await client.get("players", object_id=2)
assert result["name"] == "Working Player"
# Client should still be functional
await client.close()
@pytest.mark.asyncio
async def test_concurrent_requests(self, mock_config):
"""Test multiple concurrent requests."""
import asyncio
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
# Mock multiple endpoints
for i in range(1, 4):
m.get(
f"https://api.example.com/v3/players/{i}",
payload={"id": i, "name": f"Player {i}"},
status=200
status=200,
)
client = APIClient()
# Make concurrent requests
tasks = [
client.get("players", object_id=1),
client.get("players", object_id=2),
client.get("players", object_id=3)
client.get("players", object_id=3),
]
results = await asyncio.gather(*tasks)
assert len(results) == 3
assert results[0]["name"] == "Player 1"
assert results[1]["name"] == "Player 2"
assert results[2]["name"] == "Player 3"
await client.close()
class TestAPIClientCoverageExtras:
"""Additional coverage tests for API client edge cases."""
@pytest.fixture
def mock_config(self):
"""Mock configuration for testing."""
@ -439,98 +426,104 @@ class TestAPIClientCoverageExtras:
config.db_url = "https://api.example.com"
config.api_token = "test-token"
return config
@pytest.mark.asyncio
async def test_global_client_cleanup_when_none(self):
"""Test cleanup when no global client exists."""
# Ensure no global client exists
await cleanup_global_client()
# Should not raise error
await cleanup_global_client()
@pytest.mark.asyncio
async def test_url_building_edge_cases(self, mock_config):
"""Test URL building with various edge cases."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
client = APIClient()
# Test trailing slash handling
client.base_url = "https://api.example.com/"
url = client._build_url("players")
assert url == "https://api.example.com/v3/players"
assert "//" not in url.replace("https://", "")
@pytest.mark.asyncio
async def test_parameter_handling_edge_cases(self, mock_config):
"""Test parameter handling with various scenarios."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
client = APIClient()
# Test with existing query string
url = client._add_params("https://example.com/api?existing=true", [("new", "param")])
url = client._add_params(
"https://example.com/api?existing=true", [("new", "param")]
)
assert url == "https://example.com/api?existing=true&new=param"
# Test with no parameters
url = client._add_params("https://example.com/api")
assert url == "https://example.com/api"
@pytest.mark.asyncio
async def test_timeout_error_handling(self, mock_config):
"""Test timeout error handling using aioresponses."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
client = APIClient()
# Test timeout using aioresponses exception parameter
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
exception=asyncio.TimeoutError("Request timed out")
exception=asyncio.TimeoutError("Request timed out"),
)
with pytest.raises(APIException, match="API call failed.*Request timed out"):
with pytest.raises(
APIException, match="API call failed.*Request timed out"
):
await client.get("players")
await client.close()
@pytest.mark.asyncio
async def test_generic_exception_handling(self, mock_config):
"""Test generic exception handling."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
client = APIClient()
# Test generic exception
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
exception=Exception("Generic error")
exception=Exception("Generic error"),
)
with pytest.raises(APIException, match="API call failed.*Generic error"):
with pytest.raises(
APIException, match="API call failed.*Generic error"
):
await client.get("players")
await client.close()
@pytest.mark.asyncio
async def test_session_closed_handling(self, mock_config):
"""Test handling of closed session."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
# Test that the client recreates session when needed
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
payload={"success": True},
status=200
status=200,
)
client = APIClient()
# Close the session manually
await client._ensure_session()
await client._session.close()
# Client should recreate session and work fine
result = await client.get("players")
assert result == {"success": True}
await client.close()
await client.close()

View File

@ -1,6 +1,7 @@
"""
Tests for BaseService functionality
"""
import pytest
from unittest.mock import AsyncMock
@ -10,6 +11,7 @@ from models.base import SBABaseModel
class MockModel(SBABaseModel):
"""Mock model for testing BaseService."""
id: int
name: str
value: int = 100
@ -17,240 +19,229 @@ class MockModel(SBABaseModel):
class TestBaseService:
"""Test BaseService functionality."""
@pytest.fixture
def mock_client(self):
"""Mock API client."""
client = AsyncMock()
return client
@pytest.fixture
def base_service(self, mock_client):
"""Create BaseService instance for testing."""
service = BaseService(MockModel, 'mocks', client=mock_client)
service = BaseService(MockModel, "mocks", client=mock_client)
return service
@pytest.mark.asyncio
async def test_init(self):
"""Test service initialization."""
service = BaseService(MockModel, 'test_endpoint')
service = BaseService(MockModel, "test_endpoint")
assert service.model_class == MockModel
assert service.endpoint == 'test_endpoint'
assert service.endpoint == "test_endpoint"
assert service._client is None
@pytest.mark.asyncio
async def test_get_by_id_success(self, base_service, mock_client):
"""Test successful get_by_id."""
mock_data = {'id': 1, 'name': 'Test', 'value': 200}
mock_data = {"id": 1, "name": "Test", "value": 200}
mock_client.get.return_value = mock_data
result = await base_service.get_by_id(1)
assert isinstance(result, MockModel)
assert result.id == 1
assert result.name == 'Test'
assert result.name == "Test"
assert result.value == 200
mock_client.get.assert_called_once_with('mocks', object_id=1)
mock_client.get.assert_called_once_with("mocks", object_id=1)
@pytest.mark.asyncio
async def test_get_by_id_not_found(self, base_service, mock_client):
"""Test get_by_id when object not found."""
mock_client.get.return_value = None
result = await base_service.get_by_id(999)
assert result is None
mock_client.get.assert_called_once_with('mocks', object_id=999)
mock_client.get.assert_called_once_with("mocks", object_id=999)
@pytest.mark.asyncio
async def test_get_all_with_count(self, base_service, mock_client):
"""Test get_all with count response format."""
mock_data = {
'count': 2,
'mocks': [
{'id': 1, 'name': 'Test1', 'value': 100},
{'id': 2, 'name': 'Test2', 'value': 200}
]
"count": 2,
"mocks": [
{"id": 1, "name": "Test1", "value": 100},
{"id": 2, "name": "Test2", "value": 200},
],
}
mock_client.get.return_value = mock_data
result, count = await base_service.get_all()
assert len(result) == 2
assert count == 2
assert all(isinstance(item, MockModel) for item in result)
mock_client.get.assert_called_once_with('mocks', params=None)
mock_client.get.assert_called_once_with("mocks", params=None)
@pytest.mark.asyncio
async def test_get_all_items_convenience(self, base_service, mock_client):
"""Test get_all_items convenience method."""
mock_data = {
'count': 1,
'mocks': [{'id': 1, 'name': 'Test', 'value': 100}]
}
mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]}
mock_client.get.return_value = mock_data
result = await base_service.get_all_items()
assert len(result) == 1
assert isinstance(result[0], MockModel)
@pytest.mark.asyncio
async def test_create_success(self, base_service, mock_client):
"""Test successful object creation."""
input_data = {'name': 'New Item', 'value': 300}
response_data = {'id': 3, 'name': 'New Item', 'value': 300}
input_data = {"name": "New Item", "value": 300}
response_data = {"id": 3, "name": "New Item", "value": 300}
mock_client.post.return_value = response_data
result = await base_service.create(input_data)
assert isinstance(result, MockModel)
assert result.id == 3
assert result.name == 'New Item'
mock_client.post.assert_called_once_with('mocks', input_data)
assert result.name == "New Item"
mock_client.post.assert_called_once_with("mocks/", input_data)
@pytest.mark.asyncio
async def test_update_success(self, base_service, mock_client):
"""Test successful object update."""
update_data = {'name': 'Updated'}
response_data = {'id': 1, 'name': 'Updated', 'value': 100}
update_data = {"name": "Updated"}
response_data = {"id": 1, "name": "Updated", "value": 100}
mock_client.put.return_value = response_data
result = await base_service.update(1, update_data)
assert isinstance(result, MockModel)
assert result.name == 'Updated'
mock_client.put.assert_called_once_with('mocks', update_data, object_id=1)
assert result.name == "Updated"
mock_client.put.assert_called_once_with("mocks", update_data, object_id=1)
@pytest.mark.asyncio
async def test_delete_success(self, base_service, mock_client):
"""Test successful object deletion."""
mock_client.delete.return_value = True
result = await base_service.delete(1)
assert result is True
mock_client.delete.assert_called_once_with('mocks', object_id=1)
mock_client.delete.assert_called_once_with("mocks", object_id=1)
@pytest.mark.asyncio
async def test_get_by_field(self, base_service, mock_client):
"""Test get_by_field functionality."""
mock_data = {
'count': 1,
'mocks': [{'id': 1, 'name': 'Test', 'value': 100}]
}
mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]}
mock_client.get.return_value = mock_data
result = await base_service.get_by_field('name', 'Test')
result = await base_service.get_by_field("name", "Test")
assert len(result) == 1
mock_client.get.assert_called_once_with('mocks', params=[('name', 'Test')])
mock_client.get.assert_called_once_with("mocks", params=[("name", "Test")])
def test_extract_items_and_count_standard_format(self, base_service):
"""Test response parsing for standard format."""
data = {
'count': 3,
'mocks': [
{'id': 1, 'name': 'Test1'},
{'id': 2, 'name': 'Test2'},
{'id': 3, 'name': 'Test3'}
]
"count": 3,
"mocks": [
{"id": 1, "name": "Test1"},
{"id": 2, "name": "Test2"},
{"id": 3, "name": "Test3"},
],
}
items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 3
assert count == 3
assert items[0]['name'] == 'Test1'
assert items[0]["name"] == "Test1"
def test_extract_items_and_count_single_object(self, base_service):
"""Test response parsing for single object."""
data = {'id': 1, 'name': 'Single'}
data = {"id": 1, "name": "Single"}
items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 1
assert count == 1
assert items[0] == data
def test_extract_items_and_count_direct_list(self, base_service):
"""Test response parsing for direct list."""
data = [
{'id': 1, 'name': 'Test1'},
{'id': 2, 'name': 'Test2'}
]
data = [{"id": 1, "name": "Test1"}, {"id": 2, "name": "Test2"}]
items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 2
assert count == 2
class TestBaseServiceExtras:
"""Additional coverage tests for BaseService edge cases."""
@pytest.mark.asyncio
async def test_base_service_additional_methods(self):
"""Test additional BaseService methods for coverage."""
from services.base_service import BaseService
from models.base import SBABaseModel
class TestModel(SBABaseModel):
name: str
value: int = 100
mock_client = AsyncMock()
service = BaseService(TestModel, 'test', client=mock_client)
service = BaseService(TestModel, "test", client=mock_client)
# Test count method
mock_client.reset_mock()
mock_client.get.return_value = {'count': 42, 'test': []}
count = await service.count(params=[('active', 'true')])
mock_client.get.return_value = {"count": 42, "test": []}
count = await service.count(params=[("active", "true")])
assert count == 42
# Test update_from_model with ID
mock_client.reset_mock()
model = TestModel(id=1, name="Updated", value=300)
mock_client.put.return_value = {"id": 1, "name": "Updated", "value": 300}
result = await service.update_from_model(model)
assert result.name == "Updated"
# Test update_from_model without ID
model_no_id = TestModel(name="Test")
with pytest.raises(ValueError, match="Cannot update TestModel without ID"):
await service.update_from_model(model_no_id)
def test_base_service_response_parsing_edge_cases(self):
"""Test edge cases in response parsing."""
from services.base_service import BaseService
from models.base import SBABaseModel
class TestModel(SBABaseModel):
name: str
service = BaseService(TestModel, 'test')
service = BaseService(TestModel, "test")
# Test with 'items' field
data = {'count': 2, 'items': [{'name': 'Item1'}, {'name': 'Item2'}]}
data = {"count": 2, "items": [{"name": "Item1"}, {"name": "Item2"}]}
items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 2
assert count == 2
# Test with 'data' field
data = {'count': 1, 'data': [{'name': 'DataItem'}]}
data = {"count": 1, "data": [{"name": "DataItem"}]}
items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 1
assert count == 1
# Test with count but no recognizable list field
data = {'count': 5, 'unknown_field': [{'name': 'Item'}]}
data = {"count": 5, "unknown_field": [{"name": "Item"}]}
items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 0
assert count == 5
# Test with unexpected data type
items, count = service._extract_items_and_count_from_response("unexpected")
assert len(items) == 0
assert count == 0
assert count == 0

View File

@ -3,6 +3,7 @@ Interactive Trade Embed Views
Handles the Discord embed and button interfaces for the multi-team trade builder.
"""
import discord
from typing import Optional, List
from datetime import datetime, timezone
@ -31,60 +32,56 @@ class TradeEmbedView(discord.ui.View):
"""Check if user has permission to interact with this view."""
if interaction.user.id != self.user_id:
await interaction.response.send_message(
"You don't have permission to use this trade builder.",
ephemeral=True
"You don't have permission to use this trade builder.",
ephemeral=True,
)
return False
return True
async def on_timeout(self) -> None:
"""Handle view timeout."""
# Disable all buttons when timeout occurs
for item in self.children:
if isinstance(item, discord.ui.Button):
item.disabled = True
@discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red, emoji="")
async def remove_move_button(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red)
async def remove_move_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle remove move button click."""
if self.builder.is_empty:
await interaction.response.send_message(
"❌ No moves to remove. Add some moves first!",
ephemeral=True
"No moves to remove. Add some moves first!", ephemeral=True
)
return
# Create select menu for move removal
select_view = RemoveTradeMovesView(self.builder, self.user_id)
embed = await create_trade_embed(self.builder)
await interaction.response.edit_message(embed=embed, view=select_view)
@discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary, emoji="🔍")
async def validate_button(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary)
async def validate_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle validate trade button click."""
await interaction.response.defer(ephemeral=True)
# Perform detailed validation
validation = await self.builder.validate_trade()
# Create validation report
if validation.is_legal:
status_emoji = ""
status_text = "**Trade is LEGAL**"
color = EmbedColors.SUCCESS
else:
status_emoji = ""
status_text = "**Trade has ERRORS**"
color = EmbedColors.ERROR
embed = EmbedTemplate.create_base_embed(
title=f"{status_emoji} Trade Validation Report",
title="Trade Validation Report",
description=status_text,
color=color
color=color,
)
# Add team-by-team validation
for participant in self.builder.trade.participants:
team_validation = validation.get_participant_validation(participant.team.id)
if team_validation:
@ -98,72 +95,65 @@ class TradeEmbedView(discord.ui.View):
team_status.append(team_validation.pre_existing_transactions_note)
embed.add_field(
name=f"🏟️ {participant.team.abbrev} - {participant.team.sname}",
name=f"{participant.team.abbrev} - {participant.team.sname}",
value="\n".join(team_status),
inline=False
inline=False,
)
# Add overall errors and suggestions
if validation.all_errors:
error_text = "\n".join([f"{error}" for error in validation.all_errors])
embed.add_field(
name="❌ Errors",
value=error_text,
inline=False
)
error_text = "\n".join([f"- {error}" for error in validation.all_errors])
embed.add_field(name="Errors", value=error_text, inline=False)
if validation.all_suggestions:
suggestion_text = "\n".join([f"💡 {suggestion}" for suggestion in validation.all_suggestions])
embed.add_field(
name="💡 Suggestions",
value=suggestion_text,
inline=False
suggestion_text = "\n".join(
[f"- {suggestion}" for suggestion in validation.all_suggestions]
)
embed.add_field(name="Suggestions", value=suggestion_text, inline=False)
await interaction.followup.send(embed=embed, ephemeral=True)
@discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary, emoji="📤")
async def submit_button(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary)
async def submit_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle submit trade button click."""
if self.builder.is_empty:
await interaction.response.send_message(
"❌ Cannot submit empty trade. Add some moves first!",
ephemeral=True
"Cannot submit empty trade. Add some moves first!", ephemeral=True
)
return
# Validate before submission
validation = await self.builder.validate_trade()
if not validation.is_legal:
error_msg = "**Cannot submit illegal trade:**\n"
error_msg += "\n".join([f" {error}" for error in validation.all_errors])
error_msg = "**Cannot submit illegal trade:**\n"
error_msg += "\n".join([f"- {error}" for error in validation.all_errors])
if validation.all_suggestions:
error_msg += "\n\n**Suggestions:**\n"
error_msg += "\n".join([f"💡 {suggestion}" for suggestion in validation.all_suggestions])
error_msg += "\n".join(
[f"- {suggestion}" for suggestion in validation.all_suggestions]
)
await interaction.response.send_message(error_msg, ephemeral=True)
return
# Show confirmation modal
modal = SubmitTradeConfirmationModal(self.builder)
await interaction.response.send_modal(modal)
@discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary, emoji="")
async def cancel_button(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary)
async def cancel_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle cancel trade button click."""
self.builder.clear_trade()
embed = await create_trade_embed(self.builder)
# Disable all buttons after cancellation
for item in self.children:
if isinstance(item, discord.ui.Button):
item.disabled = True
await interaction.response.edit_message(
content="❌ **Trade cancelled and cleared.**",
embed=embed,
view=self
content="**Trade cancelled and cleared.**", embed=embed, view=self
)
self.stop()
@ -176,12 +166,12 @@ class RemoveTradeMovesView(discord.ui.View):
self.builder = builder
self.user_id = user_id
# Create select menu with current moves
if not builder.is_empty:
self.add_item(RemoveTradeMovesSelect(builder))
# Add back button
back_button = discord.ui.Button(label="Back", style=discord.ButtonStyle.secondary, emoji="⬅️")
back_button = discord.ui.Button(
label="Back", style=discord.ButtonStyle.secondary
)
back_button.callback = self.back_callback
self.add_item(back_button)
@ -202,35 +192,36 @@ class RemoveTradeMovesSelect(discord.ui.Select):
def __init__(self, builder: TradeBuilder):
self.builder = builder
# Create options from all moves (cross-team and supplementary)
options = []
move_count = 0
# Add cross-team moves
for move in builder.trade.cross_team_moves[:20]: # Limit to avoid Discord's 25 option limit
options.append(discord.SelectOption(
label=f"{move.player.name}",
description=move.description[:100], # Discord description limit
value=str(move.player.id),
emoji="🔄"
))
for move in builder.trade.cross_team_moves[
:20
]: # Limit to avoid Discord's 25 option limit
options.append(
discord.SelectOption(
label=f"{move.player.name}",
description=move.description[:100],
value=str(move.player.id),
)
)
move_count += 1
# Add supplementary moves if there's room
remaining_slots = 25 - move_count
for move in builder.trade.supplementary_moves[:remaining_slots]:
options.append(discord.SelectOption(
label=f"{move.player.name}",
description=move.description[:100],
value=str(move.player.id),
emoji="⚙️"
))
options.append(
discord.SelectOption(
label=f"{move.player.name}",
description=move.description[:100],
value=str(move.player.id),
)
)
super().__init__(
placeholder="Select a move to remove...",
min_values=1,
max_values=1,
options=options
options=options,
)
async def callback(self, interaction: discord.Interaction):
@ -241,27 +232,25 @@ class RemoveTradeMovesSelect(discord.ui.Select):
if success:
await interaction.response.send_message(
f"✅ Removed move for player ID {player_id}",
ephemeral=True
f"Removed move for player ID {player_id}", ephemeral=True
)
# Update the embed
main_view = TradeEmbedView(self.builder, interaction.user.id)
embed = await create_trade_embed(self.builder)
# Edit the original message
await interaction.edit_original_response(embed=embed, view=main_view)
else:
await interaction.response.send_message(
f"❌ Could not remove move: {error_msg}",
ephemeral=True
f"Could not remove move: {error_msg}", ephemeral=True
)
class SubmitTradeConfirmationModal(discord.ui.Modal):
"""Modal for confirming trade submission - posts acceptance request to trade channel."""
def __init__(self, builder: TradeBuilder, trade_channel: Optional[discord.TextChannel] = None):
def __init__(
self, builder: TradeBuilder, trade_channel: Optional[discord.TextChannel] = None
):
super().__init__(title="Confirm Trade Submission")
self.builder = builder
self.trade_channel = trade_channel
@ -270,7 +259,7 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
label="Type 'CONFIRM' to submit for approval",
placeholder="CONFIRM",
required=True,
max_length=7
max_length=7,
)
self.add_item(self.confirmation)
@ -279,56 +268,52 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
"""Handle confirmation submission - posts acceptance view to trade channel."""
if self.confirmation.value.upper() != "CONFIRM":
await interaction.response.send_message(
"Trade not submitted. You must type 'CONFIRM' exactly.",
ephemeral=True
"Trade not submitted. You must type 'CONFIRM' exactly.",
ephemeral=True,
)
return
await interaction.response.defer(ephemeral=True)
try:
# Update trade status to PROPOSED
from models.trade import TradeStatus
self.builder.trade.status = TradeStatus.PROPOSED
# Create acceptance embed and view
acceptance_embed = await create_trade_acceptance_embed(self.builder)
acceptance_view = TradeAcceptanceView(self.builder)
# Find the trade channel to post to
channel = self.trade_channel
if not channel:
# Try to find trade channel by name pattern
trade_channel_name = f"trade-{'-'.join(t.abbrev.lower() for t in self.builder.participating_teams)}"
for ch in interaction.guild.text_channels: # type: ignore
if ch.name.startswith("trade-") and self.builder.trade_id[:4] in ch.name:
if (
ch.name.startswith("trade-")
and self.builder.trade_id[:4] in ch.name
):
channel = ch
break
if channel:
# Post acceptance request to trade channel
await channel.send(
content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.",
content="**Trade submitted for approval.** All teams must accept to complete the trade.",
embed=acceptance_embed,
view=acceptance_view
view=acceptance_view,
)
await interaction.followup.send(
f"✅ Trade submitted for approval!\n\nThe acceptance request has been posted to {channel.mention}.\n"
f"Trade submitted for approval.\n\nThe acceptance request has been posted to {channel.mention}.\n"
f"All participating teams must click **Accept Trade** to finalize.",
ephemeral=True
ephemeral=True,
)
else:
# No trade channel found, post in current channel
await interaction.followup.send(
content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.",
content="**Trade submitted for approval.** All teams must accept to complete the trade.",
embed=acceptance_embed,
view=acceptance_view
view=acceptance_view,
)
except Exception as e:
await interaction.followup.send(
f"❌ Error submitting trade: {str(e)}",
ephemeral=True
f"Error submitting trade: {str(e)}", ephemeral=True
)
@ -343,8 +328,11 @@ class TradeAcceptanceView(discord.ui.View):
"""Get the team owned by the interacting user."""
from services.team_service import team_service
from config import get_config
config = get_config()
return await team_service.get_team_by_owner(interaction.user.id, config.sba_season)
return await team_service.get_team_by_owner(
interaction.user.id, config.sba_season
)
async def interaction_check(self, interaction: discord.Interaction) -> bool:
"""Check if user is a GM of a participating team."""
@ -352,17 +340,14 @@ class TradeAcceptanceView(discord.ui.View):
if not user_team:
await interaction.response.send_message(
"❌ You don't own a team in the league.",
ephemeral=True
"You don't own a team in the league.", ephemeral=True
)
return False
# Check if their team (or organization) is participating
participant = self.builder.trade.get_participant_by_organization(user_team)
if not participant:
await interaction.response.send_message(
"❌ Your team is not part of this trade.",
ephemeral=True
"Your team is not part of this trade.", ephemeral=True
)
return False
@ -374,47 +359,45 @@ class TradeAcceptanceView(discord.ui.View):
if isinstance(item, discord.ui.Button):
item.disabled = True
@discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success, emoji="")
async def accept_button(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success)
async def accept_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle accept button click."""
user_team = await self._get_user_team(interaction)
if not user_team:
return
# Find the participating team (could be org affiliate)
participant = self.builder.trade.get_participant_by_organization(user_team)
if not participant:
return
team_id = participant.team.id
# Check if already accepted
if self.builder.has_team_accepted(team_id):
await interaction.response.send_message(
f"{participant.team.abbrev} has already accepted this trade.",
ephemeral=True
f"{participant.team.abbrev} has already accepted this trade.",
ephemeral=True,
)
return
# Record acceptance
all_accepted = self.builder.accept_trade(team_id)
if all_accepted:
# All teams accepted - finalize the trade
await self._finalize_trade(interaction)
else:
# Update embed to show new acceptance status
embed = await create_trade_acceptance_embed(self.builder)
await interaction.response.edit_message(embed=embed, view=self)
# Send confirmation to channel
await interaction.followup.send(
f"**{participant.team.abbrev}** has accepted the trade! "
f"**{participant.team.abbrev}** has accepted the trade. "
f"({len(self.builder.accepted_teams)}/{self.builder.team_count} teams)"
)
@discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger, emoji="")
async def reject_button(self, interaction: discord.Interaction, button: discord.ui.Button):
@discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger)
async def reject_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle reject button click - moves trade back to DRAFT."""
user_team = await self._get_user_team(interaction)
if not user_team:
@ -424,20 +407,16 @@ class TradeAcceptanceView(discord.ui.View):
if not participant:
return
# Reject the trade
self.builder.reject_trade()
# Disable buttons
self.accept_button.disabled = True
self.reject_button.disabled = True
# Update embed to show rejection
embed = await create_trade_rejection_embed(self.builder, participant.team)
await interaction.response.edit_message(embed=embed, view=self)
# Notify the channel
await interaction.followup.send(
f"**{participant.team.abbrev}** has rejected the trade.\n\n"
f"**{participant.team.abbrev}** has rejected the trade.\n\n"
f"The trade has been moved back to **DRAFT** status. "
f"Teams can continue negotiating using `/trade` commands."
)
@ -459,41 +438,52 @@ class TradeAcceptanceView(discord.ui.View):
config = get_config()
# Get next week for transactions
current = await league_service.get_current_state()
next_week = current.week + 1 if current else 1
# Create FA team for reference
fa_team = Team(
id=config.free_agent_team_id,
abbrev="FA",
sname="Free Agents",
lname="Free Agency",
season=self.builder.trade.season
season=self.builder.trade.season,
) # type: ignore
# Create transactions from all moves
transactions: List[Transaction] = []
move_id = f"Trade-{self.builder.trade_id}-{int(datetime.now(timezone.utc).timestamp())}"
# Process cross-team moves
for move in self.builder.trade.cross_team_moves:
# Get actual team affiliates for from/to based on roster type
if move.from_roster == RosterType.MAJOR_LEAGUE:
old_team = move.source_team
elif move.from_roster == RosterType.MINOR_LEAGUE:
old_team = await move.source_team.minor_league_affiliate() if move.source_team else None
old_team = (
await move.source_team.minor_league_affiliate()
if move.source_team
else None
)
elif move.from_roster == RosterType.INJURED_LIST:
old_team = await move.source_team.injured_list_affiliate() if move.source_team else None
old_team = (
await move.source_team.injured_list_affiliate()
if move.source_team
else None
)
else:
old_team = move.source_team
if move.to_roster == RosterType.MAJOR_LEAGUE:
new_team = move.destination_team
elif move.to_roster == RosterType.MINOR_LEAGUE:
new_team = await move.destination_team.minor_league_affiliate() if move.destination_team else None
new_team = (
await move.destination_team.minor_league_affiliate()
if move.destination_team
else None
)
elif move.to_roster == RosterType.INJURED_LIST:
new_team = await move.destination_team.injured_list_affiliate() if move.destination_team else None
new_team = (
await move.destination_team.injured_list_affiliate()
if move.destination_team
else None
)
else:
new_team = move.destination_team
@ -507,18 +497,25 @@ class TradeAcceptanceView(discord.ui.View):
oldteam=old_team,
newteam=new_team,
cancelled=False,
frozen=False # Trades are NOT frozen - immediately effective
frozen=False,
)
transactions.append(transaction)
# Process supplementary moves
for move in self.builder.trade.supplementary_moves:
if move.from_roster == RosterType.MAJOR_LEAGUE:
old_team = move.source_team
elif move.from_roster == RosterType.MINOR_LEAGUE:
old_team = await move.source_team.minor_league_affiliate() if move.source_team else None
old_team = (
await move.source_team.minor_league_affiliate()
if move.source_team
else None
)
elif move.from_roster == RosterType.INJURED_LIST:
old_team = await move.source_team.injured_list_affiliate() if move.source_team else None
old_team = (
await move.source_team.injured_list_affiliate()
if move.source_team
else None
)
elif move.from_roster == RosterType.FREE_AGENCY:
old_team = fa_team
else:
@ -527,9 +524,17 @@ class TradeAcceptanceView(discord.ui.View):
if move.to_roster == RosterType.MAJOR_LEAGUE:
new_team = move.destination_team
elif move.to_roster == RosterType.MINOR_LEAGUE:
new_team = await move.destination_team.minor_league_affiliate() if move.destination_team else None
new_team = (
await move.destination_team.minor_league_affiliate()
if move.destination_team
else None
)
elif move.to_roster == RosterType.INJURED_LIST:
new_team = await move.destination_team.injured_list_affiliate() if move.destination_team else None
new_team = (
await move.destination_team.injured_list_affiliate()
if move.destination_team
else None
)
elif move.to_roster == RosterType.FREE_AGENCY:
new_team = fa_team
else:
@ -545,45 +550,42 @@ class TradeAcceptanceView(discord.ui.View):
oldteam=old_team,
newteam=new_team,
cancelled=False,
frozen=False # Trades are NOT frozen - immediately effective
frozen=False,
)
transactions.append(transaction)
# POST transactions to database
if transactions:
created_transactions = await transaction_service.create_transaction_batch(transactions)
created_transactions = (
await transaction_service.create_transaction_batch(transactions)
)
else:
created_transactions = []
# Post to #transaction-log channel
if created_transactions and interaction.client:
await post_trade_to_log(
bot=interaction.client,
builder=self.builder,
transactions=created_transactions,
effective_week=next_week
effective_week=next_week,
)
# Update trade status
self.builder.trade.status = TradeStatus.ACCEPTED
# Disable buttons
self.accept_button.disabled = True
self.reject_button.disabled = True
# Update embed to show completion
embed = await create_trade_complete_embed(self.builder, len(created_transactions), next_week)
embed = await create_trade_complete_embed(
self.builder, len(created_transactions), next_week
)
await interaction.edit_original_response(embed=embed, view=self)
# Send completion message
await interaction.followup.send(
f"🎉 **Trade Complete!**\n\n"
f"**Trade Complete!**\n\n"
f"All {self.builder.team_count} teams have accepted the trade.\n"
f"**{len(created_transactions)} transactions** have been created for **Week {next_week}**.\n\n"
f"Trade ID: `{self.builder.trade_id}`"
)
# Clear the trade builder
for team in self.builder.participating_teams:
clear_trade_builder_by_team(team.id)
@ -591,81 +593,79 @@ class TradeAcceptanceView(discord.ui.View):
except Exception as e:
await interaction.followup.send(
f"❌ Error finalizing trade: {str(e)}",
ephemeral=True
f"Error finalizing trade: {str(e)}", ephemeral=True
)
async def create_trade_acceptance_embed(builder: TradeBuilder) -> discord.Embed:
"""Create embed showing trade details and acceptance status."""
embed = EmbedTemplate.create_base_embed(
title=f"📋 Trade Pending Acceptance - {builder.trade.get_trade_summary()}",
title=f"Trade Pending Acceptance - {builder.trade.get_trade_summary()}",
description="All participating teams must accept to complete the trade.",
color=EmbedColors.WARNING
color=EmbedColors.WARNING,
)
# Show participating teams
team_list = [f"{team.abbrev} - {team.sname}" for team in builder.participating_teams]
team_list = [
f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams
]
embed.add_field(
name=f"🏟️ Participating Teams ({builder.team_count})",
name=f"Participating Teams ({builder.team_count})",
value="\n".join(team_list),
inline=False
inline=False,
)
# Show cross-team moves
if builder.trade.cross_team_moves:
moves_text = ""
for move in builder.trade.cross_team_moves[:10]:
moves_text += f" {move.description}\n"
moves_text += f"- {move.description}\n"
if len(builder.trade.cross_team_moves) > 10:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 10} more"
embed.add_field(
name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})",
name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})",
value=moves_text,
inline=False
inline=False,
)
# Show supplementary moves if any
if builder.trade.supplementary_moves:
supp_text = ""
for move in builder.trade.supplementary_moves[:5]:
supp_text += f" {move.description}\n"
supp_text += f"- {move.description}\n"
if len(builder.trade.supplementary_moves) > 5:
supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more"
embed.add_field(
name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})",
name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})",
value=supp_text,
inline=False
inline=False,
)
# Show acceptance status
status_lines = []
for team in builder.participating_teams:
if team.id in builder.accepted_teams:
status_lines.append(f"**{team.abbrev}** - Accepted")
status_lines.append(f"**{team.abbrev}** - Accepted")
else:
status_lines.append(f"**{team.abbrev}** - Pending")
status_lines.append(f"**{team.abbrev}** - Pending")
embed.add_field(
name="📊 Acceptance Status",
value="\n".join(status_lines),
inline=False
name="Acceptance Status", value="\n".join(status_lines), inline=False
)
# Add footer
embed.set_footer(text=f"Trade ID: {builder.trade_id}{len(builder.accepted_teams)}/{builder.team_count} teams accepted")
embed.set_footer(
text=f"Trade ID: {builder.trade_id} | {len(builder.accepted_teams)}/{builder.team_count} teams accepted"
)
return embed
async def create_trade_rejection_embed(builder: TradeBuilder, rejecting_team: Team) -> discord.Embed:
async def create_trade_rejection_embed(
builder: TradeBuilder, rejecting_team: Team
) -> discord.Embed:
"""Create embed showing trade was rejected."""
embed = EmbedTemplate.create_base_embed(
title=f"Trade Rejected - {builder.trade.get_trade_summary()}",
title=f"Trade Rejected - {builder.trade.get_trade_summary()}",
description=f"**{rejecting_team.abbrev}** has rejected the trade.\n\n"
f"The trade has been moved back to **DRAFT** status.\n"
f"Teams can continue negotiating using `/trade` commands.",
color=EmbedColors.ERROR
f"The trade has been moved back to **DRAFT** status.\n"
f"Teams can continue negotiating using `/trade` commands.",
color=EmbedColors.ERROR,
)
embed.set_footer(text=f"Trade ID: {builder.trade_id}")
@ -673,37 +673,33 @@ async def create_trade_rejection_embed(builder: TradeBuilder, rejecting_team: Te
return embed
async def create_trade_complete_embed(builder: TradeBuilder, transaction_count: int, effective_week: int) -> discord.Embed:
async def create_trade_complete_embed(
builder: TradeBuilder, transaction_count: int, effective_week: int
) -> discord.Embed:
"""Create embed showing trade was completed."""
embed = EmbedTemplate.create_base_embed(
title=f"🎉 Trade Complete! - {builder.trade.get_trade_summary()}",
description=f"All {builder.team_count} teams have accepted the trade!\n\n"
f"**{transaction_count} transactions** created for **Week {effective_week}**.",
color=EmbedColors.SUCCESS
title=f"Trade Complete - {builder.trade.get_trade_summary()}",
description=f"All {builder.team_count} teams have accepted the trade.\n\n"
f"**{transaction_count} transactions** created for **Week {effective_week}**.",
color=EmbedColors.SUCCESS,
)
# Show final acceptance status (all green)
status_lines = [f"✅ **{team.abbrev}** - Accepted" for team in builder.participating_teams]
embed.add_field(
name="📊 Final Status",
value="\n".join(status_lines),
inline=False
)
status_lines = [
f"**{team.abbrev}** - Accepted" for team in builder.participating_teams
]
embed.add_field(name="Final Status", value="\n".join(status_lines), inline=False)
# Show cross-team moves
if builder.trade.cross_team_moves:
moves_text = ""
for move in builder.trade.cross_team_moves[:8]:
moves_text += f" {move.description}\n"
moves_text += f"- {move.description}\n"
if len(builder.trade.cross_team_moves) > 8:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
embed.add_field(
name=f"🔄 Player Exchanges",
value=moves_text,
inline=False
)
embed.add_field(name="Player Exchanges", value=moves_text, inline=False)
embed.set_footer(text=f"Trade ID: {builder.trade_id} • Effective: Week {effective_week}")
embed.set_footer(
text=f"Trade ID: {builder.trade_id} | Effective: Week {effective_week}"
)
return embed
@ -718,7 +714,6 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
Returns:
Discord embed with current trade state
"""
# Determine embed color based on trade status
if builder.is_empty:
color = EmbedColors.SECONDARY
else:
@ -726,79 +721,79 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING
embed = EmbedTemplate.create_base_embed(
title=f"📋 Trade Builder - {builder.trade.get_trade_summary()}",
description=f"Build your multi-team trade",
color=color
title=f"Trade Builder - {builder.trade.get_trade_summary()}",
description="Build your multi-team trade",
color=color,
)
# Add participating teams section
team_list = [f"{team.abbrev} - {team.sname}" for team in builder.participating_teams]
team_list = [
f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams
]
embed.add_field(
name=f"🏟️ Participating Teams ({builder.team_count})",
name=f"Participating Teams ({builder.team_count})",
value="\n".join(team_list) if team_list else "*No teams yet*",
inline=False
inline=False,
)
# Add current moves section
if builder.is_empty:
embed.add_field(
name="Current Moves",
value="*No moves yet. Use the `/trade` commands to build your trade.*",
inline=False
inline=False,
)
else:
# Show cross-team moves
if builder.trade.cross_team_moves:
moves_text = ""
for i, move in enumerate(builder.trade.cross_team_moves[:8], 1): # Limit display
for i, move in enumerate(builder.trade.cross_team_moves[:8], 1):
moves_text += f"{i}. {move.description}\n"
if len(builder.trade.cross_team_moves) > 8:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
embed.add_field(
name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})",
name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})",
value=moves_text,
inline=False
inline=False,
)
# Show supplementary moves
if builder.trade.supplementary_moves:
supp_text = ""
for i, move in enumerate(builder.trade.supplementary_moves[:5], 1): # Limit display
for i, move in enumerate(builder.trade.supplementary_moves[:5], 1):
supp_text += f"{i}. {move.description}\n"
if len(builder.trade.supplementary_moves) > 5:
supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more"
supp_text += (
f"... and {len(builder.trade.supplementary_moves) - 5} more"
)
embed.add_field(
name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})",
name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})",
value=supp_text,
inline=False
inline=False,
)
# Add quick validation summary
validation = await builder.validate_trade()
if validation.is_legal:
status_text = "Trade appears legal"
status_text = "Trade appears legal"
else:
error_count = len(validation.all_errors)
status_text = f"{error_count} error{'s' if error_count != 1 else ''} found"
status_text = f"{error_count} error{'s' if error_count != 1 else ''} found\n"
status_text += "\n".join(f"- {error}" for error in validation.all_errors)
if validation.all_suggestions:
status_text += "\n" + "\n".join(
f"- {s}" for s in validation.all_suggestions
)
embed.add_field(name="Quick Status", value=status_text, inline=False)
embed.add_field(
name="🔍 Quick Status",
value=status_text,
inline=False
name="Build Your Trade",
value="- `/trade add-player` - Add player exchanges\n- `/trade supplementary` - Add internal moves\n- `/trade add-team` - Add more teams",
inline=False,
)
# Add instructions for adding more moves
embed.add_field(
name=" Build Your Trade",
value="• `/trade add-player` - Add player exchanges\n• `/trade supplementary` - Add internal moves\n• `/trade add-team` - Add more teams",
inline=False
embed.set_footer(
text=f"Trade ID: {builder.trade_id} | Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}"
)
# Add footer with trade ID and timestamp
embed.set_footer(text=f"Trade ID: {builder.trade_id} • Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}")
return embed
return embed