fix: prevent partial DB writes on scorecard submission failure #79
@ -210,17 +210,41 @@ class SubmitScorecardCommands(commands.Cog):
|
|||||||
|
|
||||||
game_id = scheduled_game.id
|
game_id = scheduled_game.id
|
||||||
|
|
||||||
# Phase 6: Read Scorecard Data
|
# Phase 6: Read ALL Scorecard Data (before any DB writes)
|
||||||
|
# Reading everything first prevents partial commits if the
|
||||||
|
# spreadsheet has formula errors (e.g. #N/A in pitching decisions)
|
||||||
await interaction.edit_original_response(
|
await interaction.edit_original_response(
|
||||||
content="📊 Reading play-by-play data..."
|
content="📊 Reading scorecard data..."
|
||||||
)
|
)
|
||||||
|
|
||||||
plays_data = await self.sheets_service.read_playtable_data(scorecard)
|
plays_data = await self.sheets_service.read_playtable_data(scorecard)
|
||||||
|
box_score = await self.sheets_service.read_box_score(scorecard)
|
||||||
|
decisions_data = await self.sheets_service.read_pitching_decisions(
|
||||||
|
scorecard
|
||||||
|
)
|
||||||
|
|
||||||
# Add game_id to each play
|
# Add game_id to each play
|
||||||
for play in plays_data:
|
for play in plays_data:
|
||||||
play["game_id"] = game_id
|
play["game_id"] = game_id
|
||||||
|
|
||||||
|
# Add game metadata to each decision
|
||||||
|
for decision in decisions_data:
|
||||||
|
decision["game_id"] = game_id
|
||||||
|
decision["season"] = current.season
|
||||||
|
decision["week"] = setup_data["week"]
|
||||||
|
decision["game_num"] = setup_data["game_num"]
|
||||||
|
|
||||||
|
# Validate WP and LP exist and fetch Player objects
|
||||||
|
wp, lp, sv, holders, _blown_saves = (
|
||||||
|
await decision_service.find_winning_losing_pitchers(decisions_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
if wp is None or lp is None:
|
||||||
|
await interaction.edit_original_response(
|
||||||
|
content="❌ Your card is missing either a Winning Pitcher or Losing Pitcher"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Phase 7: POST Plays
|
# Phase 7: POST Plays
|
||||||
await interaction.edit_original_response(
|
await interaction.edit_original_response(
|
||||||
content="💾 Submitting plays to database..."
|
content="💾 Submitting plays to database..."
|
||||||
@ -244,10 +268,7 @@ class SubmitScorecardCommands(commands.Cog):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Phase 8: Read Box Score
|
# Phase 8: PATCH Game
|
||||||
box_score = await self.sheets_service.read_box_score(scorecard)
|
|
||||||
|
|
||||||
# Phase 9: PATCH Game
|
|
||||||
await interaction.edit_original_response(
|
await interaction.edit_original_response(
|
||||||
content="⚾ Updating game result..."
|
content="⚾ Updating game result..."
|
||||||
)
|
)
|
||||||
@ -275,33 +296,7 @@ class SubmitScorecardCommands(commands.Cog):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Phase 10: Read Pitching Decisions
|
# Phase 9: POST Decisions
|
||||||
decisions_data = await self.sheets_service.read_pitching_decisions(
|
|
||||||
scorecard
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add game metadata to each decision
|
|
||||||
for decision in decisions_data:
|
|
||||||
decision["game_id"] = game_id
|
|
||||||
decision["season"] = current.season
|
|
||||||
decision["week"] = setup_data["week"]
|
|
||||||
decision["game_num"] = setup_data["game_num"]
|
|
||||||
|
|
||||||
# Validate WP and LP exist and fetch Player objects
|
|
||||||
wp, lp, sv, holders, _blown_saves = (
|
|
||||||
await decision_service.find_winning_losing_pitchers(decisions_data)
|
|
||||||
)
|
|
||||||
|
|
||||||
if wp is None or lp is None:
|
|
||||||
# Rollback
|
|
||||||
await game_service.wipe_game_data(game_id)
|
|
||||||
await play_service.delete_plays_for_game(game_id)
|
|
||||||
await interaction.edit_original_response(
|
|
||||||
content="❌ Your card is missing either a Winning Pitcher or Losing Pitcher"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Phase 11: POST Decisions
|
|
||||||
await interaction.edit_original_response(
|
await interaction.edit_original_response(
|
||||||
content="🎯 Submitting pitching decisions..."
|
content="🎯 Submitting pitching decisions..."
|
||||||
)
|
)
|
||||||
@ -361,6 +356,30 @@ class SubmitScorecardCommands(commands.Cog):
|
|||||||
# Success!
|
# Success!
|
||||||
await interaction.edit_original_response(content="✅ You are all set!")
|
await interaction.edit_original_response(content="✅ You are all set!")
|
||||||
|
|
||||||
|
except SheetsException as e:
|
||||||
|
# Spreadsheet reading error - show the detailed message to the user
|
||||||
|
self.logger.error(
|
||||||
|
f"Spreadsheet error in scorecard submission: {e}", error=e
|
||||||
|
)
|
||||||
|
|
||||||
|
if rollback_state and game_id:
|
||||||
|
try:
|
||||||
|
if rollback_state == "GAME_PATCHED":
|
||||||
|
await game_service.wipe_game_data(game_id)
|
||||||
|
await play_service.delete_plays_for_game(game_id)
|
||||||
|
elif rollback_state == "PLAYS_POSTED":
|
||||||
|
await play_service.delete_plays_for_game(game_id)
|
||||||
|
except Exception:
|
||||||
|
pass # Best effort rollback
|
||||||
|
|
||||||
|
await interaction.edit_original_response(
|
||||||
|
content=(
|
||||||
|
f"❌ There's a problem with your scorecard:\n\n"
|
||||||
|
f"{str(e)}\n\n"
|
||||||
|
f"Please fix the issue in your spreadsheet and resubmit."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Unexpected error - attempt rollback
|
# Unexpected error - attempt rollback
|
||||||
self.logger.error(f"Unexpected error in scorecard submission: {e}", error=e)
|
self.logger.error(f"Unexpected error in scorecard submission: {e}", error=e)
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Base service class for Discord Bot v2.0
|
|||||||
|
|
||||||
Provides common CRUD operations and error handling for all data services.
|
Provides common CRUD operations and error handling for all data services.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Optional, Type, TypeVar, Generic, Dict, Any, List, Tuple
|
from typing import Optional, Type, TypeVar, Generic, Dict, Any, List, Tuple
|
||||||
@ -12,15 +13,15 @@ from models.base import SBABaseModel
|
|||||||
from exceptions import APIException
|
from exceptions import APIException
|
||||||
from utils.cache import CacheManager
|
from utils.cache import CacheManager
|
||||||
|
|
||||||
logger = logging.getLogger(f'{__name__}.BaseService')
|
logger = logging.getLogger(f"{__name__}.BaseService")
|
||||||
|
|
||||||
T = TypeVar('T', bound=SBABaseModel)
|
T = TypeVar("T", bound=SBABaseModel)
|
||||||
|
|
||||||
|
|
||||||
class BaseService(Generic[T]):
|
class BaseService(Generic[T]):
|
||||||
"""
|
"""
|
||||||
Base service class providing common CRUD operations for SBA models.
|
Base service class providing common CRUD operations for SBA models.
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
- Generic type support for any SBABaseModel subclass
|
- Generic type support for any SBABaseModel subclass
|
||||||
- Automatic model validation and conversion
|
- Automatic model validation and conversion
|
||||||
@ -28,15 +29,17 @@ class BaseService(Generic[T]):
|
|||||||
- API response format handling (count + list format)
|
- API response format handling (count + list format)
|
||||||
- Connection management via global client
|
- Connection management via global client
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
model_class: Type[T],
|
self,
|
||||||
endpoint: str,
|
model_class: Type[T],
|
||||||
client: Optional[APIClient] = None,
|
endpoint: str,
|
||||||
cache_manager: Optional[CacheManager] = None):
|
client: Optional[APIClient] = None,
|
||||||
|
cache_manager: Optional[CacheManager] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize base service.
|
Initialize base service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_class: Pydantic model class for this service
|
model_class: Pydantic model class for this service
|
||||||
endpoint: API endpoint path (e.g., 'players', 'teams')
|
endpoint: API endpoint path (e.g., 'players', 'teams')
|
||||||
@ -48,40 +51,44 @@ class BaseService(Generic[T]):
|
|||||||
self._client = client
|
self._client = client
|
||||||
self._cached_client: Optional[APIClient] = None
|
self._cached_client: Optional[APIClient] = None
|
||||||
self.cache = cache_manager or CacheManager()
|
self.cache = cache_manager or CacheManager()
|
||||||
|
|
||||||
logger.debug(f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'")
|
logger.debug(
|
||||||
|
f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'"
|
||||||
def _generate_cache_key(self, method: str, params: Optional[List[Tuple[str, Any]]] = None) -> str:
|
)
|
||||||
|
|
||||||
|
def _generate_cache_key(
|
||||||
|
self, method: str, params: Optional[List[Tuple[str, Any]]] = None
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate consistent cache key for API calls.
|
Generate consistent cache key for API calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
method: API method name
|
method: API method name
|
||||||
params: Query parameters as list of tuples
|
params: Query parameters as list of tuples
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SHA256-hashed cache key
|
SHA256-hashed cache key
|
||||||
"""
|
"""
|
||||||
key_parts = [self.endpoint, method]
|
key_parts = [self.endpoint, method]
|
||||||
|
|
||||||
if params:
|
if params:
|
||||||
# Sort parameters for consistent key generation
|
# Sort parameters for consistent key generation
|
||||||
sorted_params = sorted(params, key=lambda x: str(x[0]))
|
sorted_params = sorted(params, key=lambda x: str(x[0]))
|
||||||
param_str = "&".join([f"{k}={v}" for k, v in sorted_params])
|
param_str = "&".join([f"{k}={v}" for k, v in sorted_params])
|
||||||
key_parts.append(param_str)
|
key_parts.append(param_str)
|
||||||
|
|
||||||
key_data = ":".join(key_parts)
|
key_data = ":".join(key_parts)
|
||||||
key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16] # First 16 chars
|
key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16] # First 16 chars
|
||||||
|
|
||||||
return self.cache.cache_key("sba", f"{self.endpoint}_{key_hash}")
|
return self.cache.cache_key("sba", f"{self.endpoint}_{key_hash}")
|
||||||
|
|
||||||
async def _get_cached_items(self, cache_key: str) -> Optional[List[T]]:
|
async def _get_cached_items(self, cache_key: str) -> Optional[List[T]]:
|
||||||
"""
|
"""
|
||||||
Get cached list of model items.
|
Get cached list of model items.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache_key: Cache key to lookup
|
cache_key: Cache key to lookup
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of model instances or None if not cached
|
List of model instances or None if not cached
|
||||||
"""
|
"""
|
||||||
@ -91,13 +98,15 @@ class BaseService(Generic[T]):
|
|||||||
return [self.model_class.from_api_data(item) for item in cached_data]
|
return [self.model_class.from_api_data(item) for item in cached_data]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error deserializing cached data for {cache_key}: {e}")
|
logger.warning(f"Error deserializing cached data for {cache_key}: {e}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _cache_items(self, cache_key: str, items: List[T], ttl: Optional[int] = None) -> None:
|
async def _cache_items(
|
||||||
|
self, cache_key: str, items: List[T], ttl: Optional[int] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Cache list of model items.
|
Cache list of model items.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache_key: Cache key to store under
|
cache_key: Cache key to store under
|
||||||
items: List of model instances to cache
|
items: List of model instances to cache
|
||||||
@ -105,40 +114,40 @@ class BaseService(Generic[T]):
|
|||||||
"""
|
"""
|
||||||
if not items:
|
if not items:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert to JSON-serializable format
|
# Convert to JSON-serializable format
|
||||||
cache_data = [item.model_dump() for item in items]
|
cache_data = [item.model_dump() for item in items]
|
||||||
await self.cache.set(cache_key, cache_data, ttl)
|
await self.cache.set(cache_key, cache_data, ttl)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error caching items for {cache_key}: {e}")
|
logger.warning(f"Error caching items for {cache_key}: {e}")
|
||||||
|
|
||||||
async def get_client(self) -> APIClient:
|
async def get_client(self) -> APIClient:
|
||||||
"""
|
"""
|
||||||
Get API client instance with caching to reduce async overhead.
|
Get API client instance with caching to reduce async overhead.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
APIClient instance (cached after first access)
|
APIClient instance (cached after first access)
|
||||||
"""
|
"""
|
||||||
if self._client:
|
if self._client:
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
# Cache the global client to avoid repeated async calls
|
# Cache the global client to avoid repeated async calls
|
||||||
if self._cached_client is None:
|
if self._cached_client is None:
|
||||||
self._cached_client = await get_global_client()
|
self._cached_client = await get_global_client()
|
||||||
|
|
||||||
return self._cached_client
|
return self._cached_client
|
||||||
|
|
||||||
async def get_by_id(self, object_id: int) -> Optional[T]:
|
async def get_by_id(self, object_id: int) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Get single object by ID.
|
Get single object by ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
object_id: Unique identifier for the object
|
object_id: Unique identifier for the object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance or None if not found
|
Model instance or None if not found
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
APIException: For API errors
|
APIException: For API errors
|
||||||
ValueError: For invalid data
|
ValueError: For invalid data
|
||||||
@ -146,167 +155,181 @@ class BaseService(Generic[T]):
|
|||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
data = await client.get(self.endpoint, object_id=object_id)
|
data = await client.get(self.endpoint, object_id=object_id)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
logger.debug(f"{self.model_class.__name__} {object_id} not found")
|
logger.debug(f"{self.model_class.__name__} {object_id} not found")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model = self.model_class.from_api_data(data)
|
model = self.model_class.from_api_data(data)
|
||||||
logger.debug(f"Retrieved {self.model_class.__name__} {object_id}: {model}")
|
logger.debug(f"Retrieved {self.model_class.__name__} {object_id}: {model}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
except APIException:
|
except APIException:
|
||||||
logger.error(f"API error retrieving {self.model_class.__name__} {object_id}")
|
logger.error(
|
||||||
|
f"API error retrieving {self.model_class.__name__} {object_id}"
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving {self.model_class.__name__} {object_id}: {e}")
|
logger.error(
|
||||||
|
f"Error retrieving {self.model_class.__name__} {object_id}: {e}"
|
||||||
|
)
|
||||||
raise APIException(f"Failed to retrieve {self.model_class.__name__}: {e}")
|
raise APIException(f"Failed to retrieve {self.model_class.__name__}: {e}")
|
||||||
|
|
||||||
async def get_all(self, params: Optional[List[tuple]] = None) -> Tuple[List[T], int]:
|
async def get_all(
|
||||||
|
self, params: Optional[List[tuple]] = None
|
||||||
|
) -> Tuple[List[T], int]:
|
||||||
"""
|
"""
|
||||||
Get all objects with optional query parameters.
|
Get all objects with optional query parameters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Query parameters as list of (key, value) tuples
|
params: Query parameters as list of (key, value) tuples
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (list of model instances, total count)
|
Tuple of (list of model instances, total count)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
APIException: For API errors
|
APIException: For API errors
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
data = await client.get(self.endpoint, params=params)
|
data = await client.get(self.endpoint, params=params)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
logger.debug(f"No {self.model_class.__name__} objects found")
|
logger.debug(f"No {self.model_class.__name__} objects found")
|
||||||
return [], 0
|
return [], 0
|
||||||
|
|
||||||
# Handle API response format: {'count': int, '<endpoint>': [...]}
|
# Handle API response format: {'count': int, '<endpoint>': [...]}
|
||||||
items, count = self._extract_items_and_count_from_response(data)
|
items, count = self._extract_items_and_count_from_response(data)
|
||||||
|
|
||||||
models = [self.model_class.from_api_data(item) for item in items]
|
models = [self.model_class.from_api_data(item) for item in items]
|
||||||
logger.debug(f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects")
|
logger.debug(
|
||||||
|
f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects"
|
||||||
|
)
|
||||||
return models, count
|
return models, count
|
||||||
|
|
||||||
except APIException:
|
except APIException:
|
||||||
logger.error(f"API error retrieving {self.model_class.__name__} list")
|
logger.error(f"API error retrieving {self.model_class.__name__} list")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving {self.model_class.__name__} list: {e}")
|
logger.error(f"Error retrieving {self.model_class.__name__} list: {e}")
|
||||||
raise APIException(f"Failed to retrieve {self.model_class.__name__} list: {e}")
|
raise APIException(
|
||||||
|
f"Failed to retrieve {self.model_class.__name__} list: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
async def get_all_items(self, params: Optional[List[tuple]] = None) -> List[T]:
|
async def get_all_items(self, params: Optional[List[tuple]] = None) -> List[T]:
|
||||||
"""
|
"""
|
||||||
Get all objects (convenience method that only returns the list).
|
Get all objects (convenience method that only returns the list).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Query parameters as list of (key, value) tuples
|
params: Query parameters as list of (key, value) tuples
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of model instances
|
List of model instances
|
||||||
"""
|
"""
|
||||||
items, _ = await self.get_all(params=params)
|
items, _ = await self.get_all(params=params)
|
||||||
return items
|
return items
|
||||||
|
|
||||||
async def create(self, model_data: Dict[str, Any]) -> Optional[T]:
|
async def create(self, model_data: Dict[str, Any]) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Create new object from data dictionary.
|
Create new object from data dictionary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_data: Dictionary of model fields
|
model_data: Dictionary of model fields
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Created model instance or None
|
Created model instance or None
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
APIException: For API errors
|
APIException: For API errors
|
||||||
ValueError: For invalid data
|
ValueError: For invalid data
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
response = await client.post(self.endpoint, model_data)
|
response = await client.post(f"{self.endpoint}/", model_data)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
logger.warning(f"No response from {self.model_class.__name__} creation")
|
logger.warning(f"No response from {self.model_class.__name__} creation")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model = self.model_class.from_api_data(response)
|
model = self.model_class.from_api_data(response)
|
||||||
logger.debug(f"Created {self.model_class.__name__}: {model}")
|
logger.debug(f"Created {self.model_class.__name__}: {model}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
except APIException:
|
except APIException:
|
||||||
logger.error(f"API error creating {self.model_class.__name__}")
|
logger.error(f"API error creating {self.model_class.__name__}")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating {self.model_class.__name__}: {e}")
|
logger.error(f"Error creating {self.model_class.__name__}: {e}")
|
||||||
raise APIException(f"Failed to create {self.model_class.__name__}: {e}")
|
raise APIException(f"Failed to create {self.model_class.__name__}: {e}")
|
||||||
|
|
||||||
async def create_from_model(self, model: T) -> Optional[T]:
|
async def create_from_model(self, model: T) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Create new object from model instance.
|
Create new object from model instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model instance to create
|
model: Model instance to create
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Created model instance or None
|
Created model instance or None
|
||||||
"""
|
"""
|
||||||
return await self.create(model.to_dict(exclude_none=True))
|
return await self.create(model.to_dict(exclude_none=True))
|
||||||
|
|
||||||
async def update(self, object_id: int, model_data: Dict[str, Any]) -> Optional[T]:
|
async def update(self, object_id: int, model_data: Dict[str, Any]) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Update existing object.
|
Update existing object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
object_id: ID of object to update
|
object_id: ID of object to update
|
||||||
model_data: Dictionary of fields to update
|
model_data: Dictionary of fields to update
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated model instance or None if not found
|
Updated model instance or None if not found
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
APIException: For API errors
|
APIException: For API errors
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
response = await client.put(self.endpoint, model_data, object_id=object_id)
|
response = await client.put(self.endpoint, model_data, object_id=object_id)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
logger.debug(f"{self.model_class.__name__} {object_id} not found for update")
|
logger.debug(
|
||||||
|
f"{self.model_class.__name__} {object_id} not found for update"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model = self.model_class.from_api_data(response)
|
model = self.model_class.from_api_data(response)
|
||||||
logger.debug(f"Updated {self.model_class.__name__} {object_id}: {model}")
|
logger.debug(f"Updated {self.model_class.__name__} {object_id}: {model}")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
except APIException:
|
except APIException:
|
||||||
logger.error(f"API error updating {self.model_class.__name__} {object_id}")
|
logger.error(f"API error updating {self.model_class.__name__} {object_id}")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
|
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
|
||||||
raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
|
raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
|
||||||
|
|
||||||
async def update_from_model(self, model: T) -> Optional[T]:
|
async def update_from_model(self, model: T) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Update object from model instance.
|
Update object from model instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model instance to update (must have ID)
|
model: Model instance to update (must have ID)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated model instance or None
|
Updated model instance or None
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If model has no ID
|
ValueError: If model has no ID
|
||||||
"""
|
"""
|
||||||
if not model.id:
|
if not model.id:
|
||||||
raise ValueError(f"Cannot update {self.model_class.__name__} without ID")
|
raise ValueError(f"Cannot update {self.model_class.__name__} without ID")
|
||||||
|
|
||||||
return await self.update(model.id, model.to_dict(exclude_none=True))
|
return await self.update(model.id, model.to_dict(exclude_none=True))
|
||||||
|
|
||||||
async def patch(self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False) -> Optional[T]:
|
async def patch(
|
||||||
|
self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False
|
||||||
|
) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Update existing object with HTTP PATCH.
|
Update existing object with HTTP PATCH.
|
||||||
|
|
||||||
@ -323,10 +346,14 @@ class BaseService(Generic[T]):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
response = await client.patch(self.endpoint, model_data, object_id, use_query_params=use_query_params)
|
response = await client.patch(
|
||||||
|
self.endpoint, model_data, object_id, use_query_params=use_query_params
|
||||||
|
)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
logger.debug(f"{self.model_class.__name__} {object_id} not found for update")
|
logger.debug(
|
||||||
|
f"{self.model_class.__name__} {object_id} not found for update"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model = self.model_class.from_api_data(response)
|
model = self.model_class.from_api_data(response)
|
||||||
@ -340,134 +367,142 @@ class BaseService(Generic[T]):
|
|||||||
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
|
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
|
||||||
raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
|
raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def delete(self, object_id: int) -> bool:
|
async def delete(self, object_id: int) -> bool:
|
||||||
"""
|
"""
|
||||||
Delete object by ID.
|
Delete object by ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
object_id: ID of object to delete
|
object_id: ID of object to delete
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if deleted, False if not found
|
True if deleted, False if not found
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
APIException: For API errors
|
APIException: For API errors
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
success = await client.delete(self.endpoint, object_id=object_id)
|
success = await client.delete(self.endpoint, object_id=object_id)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.debug(f"Deleted {self.model_class.__name__} {object_id}")
|
logger.debug(f"Deleted {self.model_class.__name__} {object_id}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{self.model_class.__name__} {object_id} not found for deletion")
|
logger.debug(
|
||||||
|
f"{self.model_class.__name__} {object_id} not found for deletion"
|
||||||
|
)
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
except APIException:
|
except APIException:
|
||||||
logger.error(f"API error deleting {self.model_class.__name__} {object_id}")
|
logger.error(f"API error deleting {self.model_class.__name__} {object_id}")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting {self.model_class.__name__} {object_id}: {e}")
|
logger.error(f"Error deleting {self.model_class.__name__} {object_id}: {e}")
|
||||||
raise APIException(f"Failed to delete {self.model_class.__name__}: {e}")
|
raise APIException(f"Failed to delete {self.model_class.__name__}: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def get_by_field(self, field: str, value: Any) -> List[T]:
|
async def get_by_field(self, field: str, value: Any) -> List[T]:
|
||||||
"""
|
"""
|
||||||
Get objects by specific field value.
|
Get objects by specific field value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
field: Field name to search
|
field: Field name to search
|
||||||
value: Field value to match
|
value: Field value to match
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of matching model instances
|
List of matching model instances
|
||||||
"""
|
"""
|
||||||
params = [(field, str(value))]
|
params = [(field, str(value))]
|
||||||
return await self.get_all_items(params=params)
|
return await self.get_all_items(params=params)
|
||||||
|
|
||||||
async def count(self, params: Optional[List[tuple]] = None) -> int:
|
async def count(self, params: Optional[List[tuple]] = None) -> int:
|
||||||
"""
|
"""
|
||||||
Get count of objects matching parameters.
|
Get count of objects matching parameters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Query parameters
|
params: Query parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of matching objects (from API count field)
|
Number of matching objects (from API count field)
|
||||||
"""
|
"""
|
||||||
_, count = await self.get_all(params=params)
|
_, count = await self.get_all(params=params)
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def _extract_items_and_count_from_response(self, data: Any) -> Tuple[List[Dict[str, Any]], int]:
|
def _extract_items_and_count_from_response(
|
||||||
|
self, data: Any
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
"""
|
"""
|
||||||
Extract items list and count from API response with optimized parsing.
|
Extract items list and count from API response with optimized parsing.
|
||||||
|
|
||||||
Expected format: {'count': int, '<endpoint>': [...]}
|
Expected format: {'count': int, '<endpoint>': [...]}
|
||||||
Single object format: {'id': 1, 'name': '...'}
|
Single object format: {'id': 1, 'name': '...'}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: API response data
|
data: API response data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (items list, total count)
|
Tuple of (items list, total count)
|
||||||
"""
|
"""
|
||||||
if isinstance(data, list):
|
if isinstance(data, list):
|
||||||
return data, len(data)
|
return data, len(data)
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
logger.warning(f"Unexpected response format for {self.model_class.__name__}: {type(data)}")
|
logger.warning(
|
||||||
|
f"Unexpected response format for {self.model_class.__name__}: {type(data)}"
|
||||||
|
)
|
||||||
return [], 0
|
return [], 0
|
||||||
|
|
||||||
# Single pass through the response dict - get count first
|
# Single pass through the response dict - get count first
|
||||||
count = data.get('count', 0)
|
count = data.get("count", 0)
|
||||||
|
|
||||||
# Priority order for finding items list (most common first)
|
# Priority order for finding items list (most common first)
|
||||||
field_candidates = [self.endpoint, 'items', 'data', 'results']
|
field_candidates = [self.endpoint, "items", "data", "results"]
|
||||||
for field_name in field_candidates:
|
for field_name in field_candidates:
|
||||||
if field_name in data and isinstance(data[field_name], list):
|
if field_name in data and isinstance(data[field_name], list):
|
||||||
return data[field_name], count or len(data[field_name])
|
return data[field_name], count or len(data[field_name])
|
||||||
|
|
||||||
# Single object response (check for common identifying fields)
|
# Single object response (check for common identifying fields)
|
||||||
if any(key in data for key in ['id', 'name', 'abbrev']):
|
if any(key in data for key in ["id", "name", "abbrev"]):
|
||||||
return [data], 1
|
return [data], 1
|
||||||
|
|
||||||
return [], count
|
return [], count
|
||||||
|
|
||||||
async def get_items_with_params(self, params: Optional[List[tuple]] = None) -> List[T]:
|
async def get_items_with_params(
|
||||||
|
self, params: Optional[List[tuple]] = None
|
||||||
|
) -> List[T]:
|
||||||
"""
|
"""
|
||||||
Get all items with parameters (alias for get_all_items for compatibility).
|
Get all items with parameters (alias for get_all_items for compatibility).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Query parameters as list of (key, value) tuples
|
params: Query parameters as list of (key, value) tuples
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of model instances
|
List of model instances
|
||||||
"""
|
"""
|
||||||
return await self.get_all_items(params=params)
|
return await self.get_all_items(params=params)
|
||||||
|
|
||||||
async def create_item(self, model_data: Dict[str, Any]) -> Optional[T]:
|
async def create_item(self, model_data: Dict[str, Any]) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Create item (alias for create for compatibility).
|
Create item (alias for create for compatibility).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_data: Dictionary of model fields
|
model_data: Dictionary of model fields
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Created model instance or None
|
Created model instance or None
|
||||||
"""
|
"""
|
||||||
return await self.create(model_data)
|
return await self.create(model_data)
|
||||||
|
|
||||||
async def update_item_by_field(self, field: str, value: Any, update_data: Dict[str, Any]) -> Optional[T]:
|
async def update_item_by_field(
|
||||||
|
self, field: str, value: Any, update_data: Dict[str, Any]
|
||||||
|
) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
Update item by field value.
|
Update item by field value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
field: Field name to search by
|
field: Field name to search by
|
||||||
value: Field value to match
|
value: Field value to match
|
||||||
update_data: Data to update
|
update_data: Data to update
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated model instance or None if not found
|
Updated model instance or None if not found
|
||||||
"""
|
"""
|
||||||
@ -475,22 +510,22 @@ class BaseService(Generic[T]):
|
|||||||
items = await self.get_by_field(field, value)
|
items = await self.get_by_field(field, value)
|
||||||
if not items:
|
if not items:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Update the first matching item
|
# Update the first matching item
|
||||||
item = items[0]
|
item = items[0]
|
||||||
if not item.id:
|
if not item.id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return await self.update(item.id, update_data)
|
return await self.update(item.id, update_data)
|
||||||
|
|
||||||
async def delete_item_by_field(self, field: str, value: Any) -> bool:
|
async def delete_item_by_field(self, field: str, value: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
Delete item by field value.
|
Delete item by field value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
field: Field name to search by
|
field: Field name to search by
|
||||||
value: Field value to match
|
value: Field value to match
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if deleted, False if not found
|
True if deleted, False if not found
|
||||||
"""
|
"""
|
||||||
@ -498,62 +533,41 @@ class BaseService(Generic[T]):
|
|||||||
items = await self.get_by_field(field, value)
|
items = await self.get_by_field(field, value)
|
||||||
if not items:
|
if not items:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Delete the first matching item
|
# Delete the first matching item
|
||||||
item = items[0]
|
item = items[0]
|
||||||
if not item.id:
|
if not item.id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await self.delete(item.id)
|
return await self.delete(item.id)
|
||||||
|
|
||||||
async def create_item_in_table(self, table_name: str, item_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
async def get_items_from_table_with_params(
|
||||||
"""
|
self, table_name: str, params: List[tuple]
|
||||||
Create item in a specific table (simplified for custom commands service).
|
) -> List[Dict[str, Any]]:
|
||||||
This is a placeholder - real implementation would need table-specific endpoints.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table_name: Name of the table
|
|
||||||
item_data: Data to create
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created item data or None
|
|
||||||
"""
|
|
||||||
# For now, use the main endpoint - this would need proper implementation
|
|
||||||
# for different tables like 'custom_command_creators'
|
|
||||||
try:
|
|
||||||
client = await self.get_client()
|
|
||||||
# Use table name as endpoint for now
|
|
||||||
response = await client.post(table_name, item_data)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating item in table {table_name}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_items_from_table_with_params(self, table_name: str, params: List[tuple]) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
"""
|
||||||
Get items from a specific table with parameters.
|
Get items from a specific table with parameters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table_name: Name of the table
|
table_name: Name of the table
|
||||||
params: Query parameters
|
params: Query parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of item dictionaries
|
List of item dictionaries
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
data = await client.get(table_name, params=params)
|
data = await client.get(table_name, params=params)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Handle response format
|
# Handle response format
|
||||||
items, _ = self._extract_items_and_count_from_response(data)
|
items, _ = self._extract_items_and_count_from_response(data)
|
||||||
return items
|
return items
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting items from table {table_name}: {e}")
|
logger.error(f"Error getting items from table {table_name}: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')"
|
return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')"
|
||||||
|
|||||||
@ -552,9 +552,8 @@ class CustomCommandsService(BaseService[CustomCommand]):
|
|||||||
"active_commands": 0,
|
"active_commands": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
result = await self.create_item_in_table(
|
client = await self.get_client()
|
||||||
"custom_commands/creators", creator_data
|
result = await client.post("custom_commands/creators", creator_data)
|
||||||
)
|
|
||||||
if not result:
|
if not result:
|
||||||
raise BotException("Failed to create command creator")
|
raise BotException("Failed to create command creator")
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Decision Service
|
|||||||
|
|
||||||
Manages pitching decision operations for game submission.
|
Manages pitching decision operations for game submission.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Dict, Any, Optional, Tuple
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
from utils.logging import get_contextual_logger
|
from utils.logging import get_contextual_logger
|
||||||
@ -16,17 +17,14 @@ class DecisionService:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize decision service."""
|
"""Initialize decision service."""
|
||||||
self.logger = get_contextual_logger(f'{__name__}.DecisionService')
|
self.logger = get_contextual_logger(f"{__name__}.DecisionService")
|
||||||
self._get_client = get_global_client
|
self._get_client = get_global_client
|
||||||
|
|
||||||
async def get_client(self):
|
async def get_client(self):
|
||||||
"""Get the API client."""
|
"""Get the API client."""
|
||||||
return await self._get_client()
|
return await self._get_client()
|
||||||
|
|
||||||
async def create_decisions_batch(
|
async def create_decisions_batch(self, decisions: List[Dict[str, Any]]) -> bool:
|
||||||
self,
|
|
||||||
decisions: List[Dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""
|
"""
|
||||||
POST batch of decisions to /decisions endpoint.
|
POST batch of decisions to /decisions endpoint.
|
||||||
|
|
||||||
@ -42,8 +40,10 @@ class DecisionService:
|
|||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
|
|
||||||
payload = {'decisions': decisions}
|
payload = {"decisions": decisions}
|
||||||
await client.post('decisions', payload)
|
# Trailing slash required: without it, the server returns a 307 redirect
|
||||||
|
# and aiohttp drops the POST body when following the redirect
|
||||||
|
await client.post("decisions/", payload)
|
||||||
|
|
||||||
self.logger.info(f"Created {len(decisions)} decisions")
|
self.logger.info(f"Created {len(decisions)} decisions")
|
||||||
return True
|
return True
|
||||||
@ -70,7 +70,7 @@ class DecisionService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
await client.delete(f'decisions/game/{game_id}')
|
await client.delete(f"decisions/game/{game_id}")
|
||||||
|
|
||||||
self.logger.info(f"Deleted decisions for game {game_id}")
|
self.logger.info(f"Deleted decisions for game {game_id}")
|
||||||
return True
|
return True
|
||||||
@ -80,9 +80,10 @@ class DecisionService:
|
|||||||
raise APIException(f"Failed to delete decisions: {e}")
|
raise APIException(f"Failed to delete decisions: {e}")
|
||||||
|
|
||||||
async def find_winning_losing_pitchers(
|
async def find_winning_losing_pitchers(
|
||||||
self,
|
self, decisions_data: List[Dict[str, Any]]
|
||||||
decisions_data: List[Dict[str, Any]]
|
) -> Tuple[
|
||||||
) -> Tuple[Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]]:
|
Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
Extract WP, LP, SV, Holds, Blown Saves from decisions list and fetch Player objects.
|
Extract WP, LP, SV, Holds, Blown Saves from decisions list and fetch Player objects.
|
||||||
|
|
||||||
@ -110,17 +111,17 @@ class DecisionService:
|
|||||||
|
|
||||||
# First pass: Extract IDs
|
# First pass: Extract IDs
|
||||||
for decision in decisions_data:
|
for decision in decisions_data:
|
||||||
pitcher_id = int(decision.get('pitcher_id', 0))
|
pitcher_id = int(decision.get("pitcher_id", 0))
|
||||||
|
|
||||||
if int(decision.get('win', 0)) == 1:
|
if int(decision.get("win", 0)) == 1:
|
||||||
wp_id = pitcher_id
|
wp_id = pitcher_id
|
||||||
if int(decision.get('loss', 0)) == 1:
|
if int(decision.get("loss", 0)) == 1:
|
||||||
lp_id = pitcher_id
|
lp_id = pitcher_id
|
||||||
if int(decision.get('is_save', 0)) == 1:
|
if int(decision.get("is_save", 0)) == 1:
|
||||||
sv_id = pitcher_id
|
sv_id = pitcher_id
|
||||||
if int(decision.get('hold', 0)) == 1:
|
if int(decision.get("hold", 0)) == 1:
|
||||||
hold_ids.append(pitcher_id)
|
hold_ids.append(pitcher_id)
|
||||||
if int(decision.get('b_save', 0)) == 1:
|
if int(decision.get("b_save", 0)) == 1:
|
||||||
bsv_ids.append(pitcher_id)
|
bsv_ids.append(pitcher_id)
|
||||||
|
|
||||||
# Second pass: Fetch Player objects
|
# Second pass: Fetch Player objects
|
||||||
@ -154,9 +155,9 @@ class DecisionService:
|
|||||||
"""
|
"""
|
||||||
error_str = str(error)
|
error_str = str(error)
|
||||||
|
|
||||||
if 'Player ID' in error_str and 'not found' in error_str:
|
if "Player ID" in error_str and "not found" in error_str:
|
||||||
return "Invalid pitcher ID in decision data."
|
return "Invalid pitcher ID in decision data."
|
||||||
elif 'Game ID' in error_str and 'not found' in error_str:
|
elif "Game ID" in error_str and "not found" in error_str:
|
||||||
return "Game not found for decisions."
|
return "Game not found for decisions."
|
||||||
else:
|
else:
|
||||||
return f"Error submitting decisions: {error_str}"
|
return f"Error submitting decisions: {error_str}"
|
||||||
|
|||||||
@ -3,13 +3,14 @@ Draft list service for Discord Bot v2.0
|
|||||||
|
|
||||||
Handles team draft list (auto-draft queue) operations. NO CACHING - lists change frequently.
|
Handles team draft list (auto-draft queue) operations. NO CACHING - lists change frequently.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from services.base_service import BaseService
|
from services.base_service import BaseService
|
||||||
from models.draft_list import DraftList
|
from models.draft_list import DraftList
|
||||||
|
|
||||||
logger = logging.getLogger(f'{__name__}.DraftListService')
|
logger = logging.getLogger(f"{__name__}.DraftListService")
|
||||||
|
|
||||||
|
|
||||||
class DraftListService(BaseService[DraftList]):
|
class DraftListService(BaseService[DraftList]):
|
||||||
@ -32,7 +33,7 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize draft list service."""
|
"""Initialize draft list service."""
|
||||||
super().__init__(DraftList, 'draftlist')
|
super().__init__(DraftList, "draftlist")
|
||||||
logger.debug("DraftListService initialized")
|
logger.debug("DraftListService initialized")
|
||||||
|
|
||||||
def _extract_items_and_count_from_response(self, data):
|
def _extract_items_and_count_from_response(self, data):
|
||||||
@ -54,20 +55,16 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
return [], 0
|
return [], 0
|
||||||
|
|
||||||
# Get count
|
# Get count
|
||||||
count = data.get('count', 0)
|
count = data.get("count", 0)
|
||||||
|
|
||||||
# API returns items under 'picks' key (not 'draftlist')
|
# API returns items under 'picks' key (not 'draftlist')
|
||||||
if 'picks' in data and isinstance(data['picks'], list):
|
if "picks" in data and isinstance(data["picks"], list):
|
||||||
return data['picks'], count or len(data['picks'])
|
return data["picks"], count or len(data["picks"])
|
||||||
|
|
||||||
# Fallback to standard extraction
|
# Fallback to standard extraction
|
||||||
return super()._extract_items_and_count_from_response(data)
|
return super()._extract_items_and_count_from_response(data)
|
||||||
|
|
||||||
async def get_team_list(
|
async def get_team_list(self, season: int, team_id: int) -> List[DraftList]:
|
||||||
self,
|
|
||||||
season: int,
|
|
||||||
team_id: int
|
|
||||||
) -> List[DraftList]:
|
|
||||||
"""
|
"""
|
||||||
Get team's draft list ordered by rank.
|
Get team's draft list ordered by rank.
|
||||||
|
|
||||||
@ -82,8 +79,8 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
params = [
|
params = [
|
||||||
('season', str(season)),
|
("season", str(season)),
|
||||||
('team_id', str(team_id))
|
("team_id", str(team_id)),
|
||||||
# NOTE: API does not support 'sort' param - results must be sorted client-side
|
# NOTE: API does not support 'sort' param - results must be sorted client-side
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -100,11 +97,7 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def add_to_list(
|
async def add_to_list(
|
||||||
self,
|
self, season: int, team_id: int, player_id: int, rank: Optional[int] = None
|
||||||
season: int,
|
|
||||||
team_id: int,
|
|
||||||
player_id: int,
|
|
||||||
rank: Optional[int] = None
|
|
||||||
) -> Optional[List[DraftList]]:
|
) -> Optional[List[DraftList]]:
|
||||||
"""
|
"""
|
||||||
Add player to team's draft list.
|
Add player to team's draft list.
|
||||||
@ -133,10 +126,10 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
|
|
||||||
# Create new entry data
|
# Create new entry data
|
||||||
new_entry_data = {
|
new_entry_data = {
|
||||||
'season': season,
|
"season": season,
|
||||||
'team_id': team_id,
|
"team_id": team_id,
|
||||||
'player_id': player_id,
|
"player_id": player_id,
|
||||||
'rank': rank
|
"rank": rank,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Build complete list for bulk replacement
|
# Build complete list for bulk replacement
|
||||||
@ -146,36 +139,42 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
for entry in current_list:
|
for entry in current_list:
|
||||||
if entry.rank >= rank:
|
if entry.rank >= rank:
|
||||||
# Shift down entries at or after insertion point
|
# Shift down entries at or after insertion point
|
||||||
draft_list_entries.append({
|
draft_list_entries.append(
|
||||||
'season': entry.season,
|
{
|
||||||
'team_id': entry.team_id,
|
"season": entry.season,
|
||||||
'player_id': entry.player_id,
|
"team_id": entry.team_id,
|
||||||
'rank': entry.rank + 1
|
"player_id": entry.player_id,
|
||||||
})
|
"rank": entry.rank + 1,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Keep existing rank for entries before insertion point
|
# Keep existing rank for entries before insertion point
|
||||||
draft_list_entries.append({
|
draft_list_entries.append(
|
||||||
'season': entry.season,
|
{
|
||||||
'team_id': entry.team_id,
|
"season": entry.season,
|
||||||
'player_id': entry.player_id,
|
"team_id": entry.team_id,
|
||||||
'rank': entry.rank
|
"player_id": entry.player_id,
|
||||||
})
|
"rank": entry.rank,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Add new entry
|
# Add new entry
|
||||||
draft_list_entries.append(new_entry_data)
|
draft_list_entries.append(new_entry_data)
|
||||||
|
|
||||||
# Sort by rank for consistency
|
# Sort by rank for consistency
|
||||||
draft_list_entries.sort(key=lambda x: x['rank'])
|
draft_list_entries.sort(key=lambda x: x["rank"])
|
||||||
|
|
||||||
# POST entire list (bulk replacement)
|
# POST entire list (bulk replacement)
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
payload = {
|
payload = {
|
||||||
'count': len(draft_list_entries),
|
"count": len(draft_list_entries),
|
||||||
'draft_list': draft_list_entries
|
"draft_list": draft_list_entries,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries")
|
logger.debug(
|
||||||
response = await client.post(self.endpoint, payload)
|
f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries"
|
||||||
|
)
|
||||||
|
response = await client.post(f"{self.endpoint}/", payload)
|
||||||
logger.debug(f"POST response: {response}")
|
logger.debug(f"POST response: {response}")
|
||||||
|
|
||||||
# Verify by fetching the list back (API returns full objects)
|
# Verify by fetching the list back (API returns full objects)
|
||||||
@ -184,20 +183,21 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
|
|
||||||
# Verify the player was added
|
# Verify the player was added
|
||||||
if not any(entry.player_id == player_id for entry in verification):
|
if not any(entry.player_id == player_id for entry in verification):
|
||||||
logger.error(f"Player {player_id} not found in list after POST - operation may have failed")
|
logger.error(
|
||||||
|
f"Player {player_id} not found in list after POST - operation may have failed"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logger.info(f"Added player {player_id} to team {team_id} draft list at rank {rank}")
|
logger.info(
|
||||||
|
f"Added player {player_id} to team {team_id} draft list at rank {rank}"
|
||||||
|
)
|
||||||
return verification # Return full updated list
|
return verification # Return full updated list
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding player {player_id} to draft list: {e}")
|
logger.error(f"Error adding player {player_id} to draft list: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def remove_from_list(
|
async def remove_from_list(self, entry_id: int) -> bool:
|
||||||
self,
|
|
||||||
entry_id: int
|
|
||||||
) -> bool:
|
|
||||||
"""
|
"""
|
||||||
Remove entry from draft list by ID.
|
Remove entry from draft list by ID.
|
||||||
|
|
||||||
@ -209,14 +209,13 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
Returns:
|
Returns:
|
||||||
True if deletion succeeded
|
True if deletion succeeded
|
||||||
"""
|
"""
|
||||||
logger.warning("remove_from_list() called with entry_id - use remove_player_from_list() instead")
|
logger.warning(
|
||||||
|
"remove_from_list() called with entry_id - use remove_player_from_list() instead"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def remove_player_from_list(
|
async def remove_player_from_list(
|
||||||
self,
|
self, season: int, team_id: int, player_id: int
|
||||||
season: int,
|
|
||||||
team_id: int,
|
|
||||||
player_id: int
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Remove specific player from team's draft list.
|
Remove specific player from team's draft list.
|
||||||
@ -238,7 +237,9 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
# Check if player is in list
|
# Check if player is in list
|
||||||
player_found = any(entry.player_id == player_id for entry in current_list)
|
player_found = any(entry.player_id == player_id for entry in current_list)
|
||||||
if not player_found:
|
if not player_found:
|
||||||
logger.warning(f"Player {player_id} not found in team {team_id} draft list")
|
logger.warning(
|
||||||
|
f"Player {player_id} not found in team {team_id} draft list"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Build new list without the player, adjusting ranks
|
# Build new list without the player, adjusting ranks
|
||||||
@ -246,22 +247,24 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
new_rank = 1
|
new_rank = 1
|
||||||
for entry in current_list:
|
for entry in current_list:
|
||||||
if entry.player_id != player_id:
|
if entry.player_id != player_id:
|
||||||
draft_list_entries.append({
|
draft_list_entries.append(
|
||||||
'season': entry.season,
|
{
|
||||||
'team_id': entry.team_id,
|
"season": entry.season,
|
||||||
'player_id': entry.player_id,
|
"team_id": entry.team_id,
|
||||||
'rank': new_rank
|
"player_id": entry.player_id,
|
||||||
})
|
"rank": new_rank,
|
||||||
|
}
|
||||||
|
)
|
||||||
new_rank += 1
|
new_rank += 1
|
||||||
|
|
||||||
# POST updated list (bulk replacement)
|
# POST updated list (bulk replacement)
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
payload = {
|
payload = {
|
||||||
'count': len(draft_list_entries),
|
"count": len(draft_list_entries),
|
||||||
'draft_list': draft_list_entries
|
"draft_list": draft_list_entries,
|
||||||
}
|
}
|
||||||
|
|
||||||
await client.post(self.endpoint, payload)
|
await client.post(f"{self.endpoint}/", payload)
|
||||||
logger.info(f"Removed player {player_id} from team {team_id} draft list")
|
logger.info(f"Removed player {player_id} from team {team_id} draft list")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -270,11 +273,7 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
logger.error(f"Error removing player {player_id} from draft list: {e}")
|
logger.error(f"Error removing player {player_id} from draft list: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def clear_list(
|
async def clear_list(self, season: int, team_id: int) -> bool:
|
||||||
self,
|
|
||||||
season: int,
|
|
||||||
team_id: int
|
|
||||||
) -> bool:
|
|
||||||
"""
|
"""
|
||||||
Clear entire draft list for team.
|
Clear entire draft list for team.
|
||||||
|
|
||||||
@ -309,10 +308,7 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def reorder_list(
|
async def reorder_list(
|
||||||
self,
|
self, season: int, team_id: int, new_order: List[int]
|
||||||
season: int,
|
|
||||||
team_id: int,
|
|
||||||
new_order: List[int]
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Reorder team's draft list.
|
Reorder team's draft list.
|
||||||
@ -342,21 +338,23 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
entry = entry_map[player_id]
|
entry = entry_map[player_id]
|
||||||
draft_list_entries.append({
|
draft_list_entries.append(
|
||||||
'season': entry.season,
|
{
|
||||||
'team_id': entry.team_id,
|
"season": entry.season,
|
||||||
'player_id': entry.player_id,
|
"team_id": entry.team_id,
|
||||||
'rank': new_rank
|
"player_id": entry.player_id,
|
||||||
})
|
"rank": new_rank,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# POST reordered list (bulk replacement)
|
# POST reordered list (bulk replacement)
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
payload = {
|
payload = {
|
||||||
'count': len(draft_list_entries),
|
"count": len(draft_list_entries),
|
||||||
'draft_list': draft_list_entries
|
"draft_list": draft_list_entries,
|
||||||
}
|
}
|
||||||
|
|
||||||
await client.post(self.endpoint, payload)
|
await client.post(f"{self.endpoint}/", payload)
|
||||||
logger.info(f"Reordered draft list for team {team_id}")
|
logger.info(f"Reordered draft list for team {team_id}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -365,12 +363,7 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
logger.error(f"Error reordering draft list for team {team_id}: {e}")
|
logger.error(f"Error reordering draft list for team {team_id}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def move_entry_up(
|
async def move_entry_up(self, season: int, team_id: int, player_id: int) -> bool:
|
||||||
self,
|
|
||||||
season: int,
|
|
||||||
team_id: int,
|
|
||||||
player_id: int
|
|
||||||
) -> bool:
|
|
||||||
"""
|
"""
|
||||||
Move player up one position in draft list (higher priority).
|
Move player up one position in draft list (higher priority).
|
||||||
|
|
||||||
@ -403,7 +396,9 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Find entry above (rank - 1)
|
# Find entry above (rank - 1)
|
||||||
above_entry = next((e for e in entries if e.rank == current_entry.rank - 1), None)
|
above_entry = next(
|
||||||
|
(e for e in entries if e.rank == current_entry.rank - 1), None
|
||||||
|
)
|
||||||
if not above_entry:
|
if not above_entry:
|
||||||
logger.error(f"Could not find entry above rank {current_entry.rank}")
|
logger.error(f"Could not find entry above rank {current_entry.rank}")
|
||||||
return False
|
return False
|
||||||
@ -421,24 +416,26 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
# Keep existing rank
|
# Keep existing rank
|
||||||
new_rank = entry.rank
|
new_rank = entry.rank
|
||||||
|
|
||||||
draft_list_entries.append({
|
draft_list_entries.append(
|
||||||
'season': entry.season,
|
{
|
||||||
'team_id': entry.team_id,
|
"season": entry.season,
|
||||||
'player_id': entry.player_id,
|
"team_id": entry.team_id,
|
||||||
'rank': new_rank
|
"player_id": entry.player_id,
|
||||||
})
|
"rank": new_rank,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Sort by rank
|
# Sort by rank
|
||||||
draft_list_entries.sort(key=lambda x: x['rank'])
|
draft_list_entries.sort(key=lambda x: x["rank"])
|
||||||
|
|
||||||
# POST updated list (bulk replacement)
|
# POST updated list (bulk replacement)
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
payload = {
|
payload = {
|
||||||
'count': len(draft_list_entries),
|
"count": len(draft_list_entries),
|
||||||
'draft_list': draft_list_entries
|
"draft_list": draft_list_entries,
|
||||||
}
|
}
|
||||||
|
|
||||||
await client.post(self.endpoint, payload)
|
await client.post(f"{self.endpoint}/", payload)
|
||||||
logger.info(f"Moved player {player_id} up to rank {current_entry.rank - 1}")
|
logger.info(f"Moved player {player_id} up to rank {current_entry.rank - 1}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -447,12 +444,7 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
logger.error(f"Error moving player {player_id} up in draft list: {e}")
|
logger.error(f"Error moving player {player_id} up in draft list: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def move_entry_down(
|
async def move_entry_down(self, season: int, team_id: int, player_id: int) -> bool:
|
||||||
self,
|
|
||||||
season: int,
|
|
||||||
team_id: int,
|
|
||||||
player_id: int
|
|
||||||
) -> bool:
|
|
||||||
"""
|
"""
|
||||||
Move player down one position in draft list (lower priority).
|
Move player down one position in draft list (lower priority).
|
||||||
|
|
||||||
@ -485,7 +477,9 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Find entry below (rank + 1)
|
# Find entry below (rank + 1)
|
||||||
below_entry = next((e for e in entries if e.rank == current_entry.rank + 1), None)
|
below_entry = next(
|
||||||
|
(e for e in entries if e.rank == current_entry.rank + 1), None
|
||||||
|
)
|
||||||
if not below_entry:
|
if not below_entry:
|
||||||
logger.error(f"Could not find entry below rank {current_entry.rank}")
|
logger.error(f"Could not find entry below rank {current_entry.rank}")
|
||||||
return False
|
return False
|
||||||
@ -503,25 +497,29 @@ class DraftListService(BaseService[DraftList]):
|
|||||||
# Keep existing rank
|
# Keep existing rank
|
||||||
new_rank = entry.rank
|
new_rank = entry.rank
|
||||||
|
|
||||||
draft_list_entries.append({
|
draft_list_entries.append(
|
||||||
'season': entry.season,
|
{
|
||||||
'team_id': entry.team_id,
|
"season": entry.season,
|
||||||
'player_id': entry.player_id,
|
"team_id": entry.team_id,
|
||||||
'rank': new_rank
|
"player_id": entry.player_id,
|
||||||
})
|
"rank": new_rank,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Sort by rank
|
# Sort by rank
|
||||||
draft_list_entries.sort(key=lambda x: x['rank'])
|
draft_list_entries.sort(key=lambda x: x["rank"])
|
||||||
|
|
||||||
# POST updated list (bulk replacement)
|
# POST updated list (bulk replacement)
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
payload = {
|
payload = {
|
||||||
'count': len(draft_list_entries),
|
"count": len(draft_list_entries),
|
||||||
'draft_list': draft_list_entries
|
"draft_list": draft_list_entries,
|
||||||
}
|
}
|
||||||
|
|
||||||
await client.post(self.endpoint, payload)
|
await client.post(f"{self.endpoint}/", payload)
|
||||||
logger.info(f"Moved player {player_id} down to rank {current_entry.rank + 1}")
|
logger.info(
|
||||||
|
f"Moved player {player_id} down to rank {current_entry.rank + 1}"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -3,13 +3,14 @@ Injury service for Discord Bot v2.0
|
|||||||
|
|
||||||
Handles injury-related operations including checking, creating, and clearing injuries.
|
Handles injury-related operations including checking, creating, and clearing injuries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from services.base_service import BaseService
|
from services.base_service import BaseService
|
||||||
from models.injury import Injury
|
from models.injury import Injury
|
||||||
|
|
||||||
logger = logging.getLogger(f'{__name__}.InjuryService')
|
logger = logging.getLogger(f"{__name__}.InjuryService")
|
||||||
|
|
||||||
|
|
||||||
class InjuryService(BaseService[Injury]):
|
class InjuryService(BaseService[Injury]):
|
||||||
@ -25,7 +26,7 @@ class InjuryService(BaseService[Injury]):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize injury service."""
|
"""Initialize injury service."""
|
||||||
super().__init__(Injury, 'injuries')
|
super().__init__(Injury, "injuries")
|
||||||
logger.debug("InjuryService initialized")
|
logger.debug("InjuryService initialized")
|
||||||
|
|
||||||
async def get_active_injury(self, player_id: int, season: int) -> Optional[Injury]:
|
async def get_active_injury(self, player_id: int, season: int) -> Optional[Injury]:
|
||||||
@ -41,25 +42,31 @@ class InjuryService(BaseService[Injury]):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
params = [
|
params = [
|
||||||
('player_id', str(player_id)),
|
("player_id", str(player_id)),
|
||||||
('season', str(season)),
|
("season", str(season)),
|
||||||
('is_active', 'true')
|
("is_active", "true"),
|
||||||
]
|
]
|
||||||
|
|
||||||
injuries = await self.get_all_items(params=params)
|
injuries = await self.get_all_items(params=params)
|
||||||
|
|
||||||
if injuries:
|
if injuries:
|
||||||
logger.debug(f"Found active injury for player {player_id} in season {season}")
|
logger.debug(
|
||||||
|
f"Found active injury for player {player_id} in season {season}"
|
||||||
|
)
|
||||||
return injuries[0]
|
return injuries[0]
|
||||||
|
|
||||||
logger.debug(f"No active injury found for player {player_id} in season {season}")
|
logger.debug(
|
||||||
|
f"No active injury found for player {player_id} in season {season}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting active injury for player {player_id}: {e}")
|
logger.error(f"Error getting active injury for player {player_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_injuries_by_player(self, player_id: int, season: int, active_only: bool = False) -> List[Injury]:
|
async def get_injuries_by_player(
|
||||||
|
self, player_id: int, season: int, active_only: bool = False
|
||||||
|
) -> List[Injury]:
|
||||||
"""
|
"""
|
||||||
Get all injuries for a player in a specific season.
|
Get all injuries for a player in a specific season.
|
||||||
|
|
||||||
@ -72,13 +79,10 @@ class InjuryService(BaseService[Injury]):
|
|||||||
List of injuries for the player
|
List of injuries for the player
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
params = [
|
params = [("player_id", str(player_id)), ("season", str(season))]
|
||||||
('player_id', str(player_id)),
|
|
||||||
('season', str(season))
|
|
||||||
]
|
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
params.append(('is_active', 'true'))
|
params.append(("is_active", "true"))
|
||||||
|
|
||||||
injuries = await self.get_all_items(params=params)
|
injuries = await self.get_all_items(params=params)
|
||||||
logger.debug(f"Retrieved {len(injuries)} injuries for player {player_id}")
|
logger.debug(f"Retrieved {len(injuries)} injuries for player {player_id}")
|
||||||
@ -88,7 +92,9 @@ class InjuryService(BaseService[Injury]):
|
|||||||
logger.error(f"Error getting injuries for player {player_id}: {e}")
|
logger.error(f"Error getting injuries for player {player_id}: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_injuries_by_team(self, team_id: int, season: int, active_only: bool = True) -> List[Injury]:
|
async def get_injuries_by_team(
|
||||||
|
self, team_id: int, season: int, active_only: bool = True
|
||||||
|
) -> List[Injury]:
|
||||||
"""
|
"""
|
||||||
Get all injuries for a team in a specific season.
|
Get all injuries for a team in a specific season.
|
||||||
|
|
||||||
@ -101,13 +107,10 @@ class InjuryService(BaseService[Injury]):
|
|||||||
List of injuries for the team
|
List of injuries for the team
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
params = [
|
params = [("team_id", str(team_id)), ("season", str(season))]
|
||||||
('team_id', str(team_id)),
|
|
||||||
('season', str(season))
|
|
||||||
]
|
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
params.append(('is_active', 'true'))
|
params.append(("is_active", "true"))
|
||||||
|
|
||||||
injuries = await self.get_all_items(params=params)
|
injuries = await self.get_all_items(params=params)
|
||||||
logger.debug(f"Retrieved {len(injuries)} injuries for team {team_id}")
|
logger.debug(f"Retrieved {len(injuries)} injuries for team {team_id}")
|
||||||
@ -125,7 +128,7 @@ class InjuryService(BaseService[Injury]):
|
|||||||
start_week: int,
|
start_week: int,
|
||||||
start_game: int,
|
start_game: int,
|
||||||
end_week: int,
|
end_week: int,
|
||||||
end_game: int
|
end_game: int,
|
||||||
) -> Optional[Injury]:
|
) -> Optional[Injury]:
|
||||||
"""
|
"""
|
||||||
Create a new injury record.
|
Create a new injury record.
|
||||||
@ -144,22 +147,24 @@ class InjuryService(BaseService[Injury]):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
injury_data = {
|
injury_data = {
|
||||||
'season': season,
|
"season": season,
|
||||||
'player_id': player_id,
|
"player_id": player_id,
|
||||||
'total_games': total_games,
|
"total_games": total_games,
|
||||||
'start_week': start_week,
|
"start_week": start_week,
|
||||||
'start_game': start_game,
|
"start_game": start_game,
|
||||||
'end_week': end_week,
|
"end_week": end_week,
|
||||||
'end_game': end_game,
|
"end_game": end_game,
|
||||||
'is_active': True
|
"is_active": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Call the API to create the injury
|
# Call the API to create the injury
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
response = await client.post(self.endpoint, injury_data)
|
response = await client.post(f"{self.endpoint}/", injury_data)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
logger.error(f"Failed to create injury for player {player_id}: No response from API")
|
logger.error(
|
||||||
|
f"Failed to create injury for player {player_id}: No response from API"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Merge the request data with the response to ensure all required fields are present
|
# Merge the request data with the response to ensure all required fields are present
|
||||||
@ -187,7 +192,9 @@ class InjuryService(BaseService[Injury]):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Note: API expects is_active as query parameter, not JSON body
|
# Note: API expects is_active as query parameter, not JSON body
|
||||||
updated_injury = await self.patch(injury_id, {'is_active': False}, use_query_params=True)
|
updated_injury = await self.patch(
|
||||||
|
injury_id, {"is_active": False}, use_query_params=True
|
||||||
|
)
|
||||||
|
|
||||||
if updated_injury:
|
if updated_injury:
|
||||||
logger.info(f"Cleared injury {injury_id}")
|
logger.info(f"Cleared injury {injury_id}")
|
||||||
@ -216,16 +223,18 @@ class InjuryService(BaseService[Injury]):
|
|||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
params = [
|
params = [
|
||||||
('season', str(season)),
|
("season", str(season)),
|
||||||
('is_active', 'true'),
|
("is_active", "true"),
|
||||||
('sort', 'return-asc')
|
("sort", "return-asc"),
|
||||||
]
|
]
|
||||||
|
|
||||||
response = await client.get(self.endpoint, params=params)
|
response = await client.get(self.endpoint, params=params)
|
||||||
|
|
||||||
if response and 'injuries' in response:
|
if response and "injuries" in response:
|
||||||
logger.debug(f"Retrieved {len(response['injuries'])} active injuries for season {season}")
|
logger.debug(
|
||||||
return response['injuries']
|
f"Retrieved {len(response['injuries'])} active injuries for season {season}"
|
||||||
|
)
|
||||||
|
return response["injuries"]
|
||||||
|
|
||||||
logger.debug(f"No active injuries found for season {season}")
|
logger.debug(f"No active injuries found for season {season}")
|
||||||
return []
|
return []
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Play Service
|
|||||||
|
|
||||||
Manages play-by-play data operations for game submission.
|
Manages play-by-play data operations for game submission.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
from utils.logging import get_contextual_logger
|
from utils.logging import get_contextual_logger
|
||||||
@ -16,7 +17,7 @@ class PlayService:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize play service."""
|
"""Initialize play service."""
|
||||||
self.logger = get_contextual_logger(f'{__name__}.PlayService')
|
self.logger = get_contextual_logger(f"{__name__}.PlayService")
|
||||||
self._get_client = get_global_client
|
self._get_client = get_global_client
|
||||||
|
|
||||||
async def get_client(self):
|
async def get_client(self):
|
||||||
@ -39,8 +40,10 @@ class PlayService:
|
|||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
|
|
||||||
payload = {'plays': plays}
|
payload = {"plays": plays}
|
||||||
response = await client.post('plays', payload)
|
# Trailing slash required: without it, the server returns a 307 redirect
|
||||||
|
# and aiohttp drops the POST body when following the redirect
|
||||||
|
response = await client.post("plays/", payload)
|
||||||
|
|
||||||
self.logger.info(f"Created {len(plays)} plays")
|
self.logger.info(f"Created {len(plays)} plays")
|
||||||
return True
|
return True
|
||||||
@ -68,7 +71,7 @@ class PlayService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
response = await client.delete(f'plays/game/{game_id}')
|
response = await client.delete(f"plays/game/{game_id}")
|
||||||
|
|
||||||
self.logger.info(f"Deleted plays for game {game_id}")
|
self.logger.info(f"Deleted plays for game {game_id}")
|
||||||
return True
|
return True
|
||||||
@ -77,11 +80,7 @@ class PlayService:
|
|||||||
self.logger.error(f"Failed to delete plays for game {game_id}: {e}")
|
self.logger.error(f"Failed to delete plays for game {game_id}: {e}")
|
||||||
raise APIException(f"Failed to delete plays: {e}")
|
raise APIException(f"Failed to delete plays: {e}")
|
||||||
|
|
||||||
async def get_top_plays_by_wpa(
|
async def get_top_plays_by_wpa(self, game_id: int, limit: int = 3) -> List[Play]:
|
||||||
self,
|
|
||||||
game_id: int,
|
|
||||||
limit: int = 3
|
|
||||||
) -> List[Play]:
|
|
||||||
"""
|
"""
|
||||||
Get top plays by WPA (absolute value) for key plays display.
|
Get top plays by WPA (absolute value) for key plays display.
|
||||||
|
|
||||||
@ -95,19 +94,15 @@ class PlayService:
|
|||||||
try:
|
try:
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
|
|
||||||
params = [
|
params = [("game_id", game_id), ("sort", "wpa-desc"), ("limit", limit)]
|
||||||
('game_id', game_id),
|
|
||||||
('sort', 'wpa-desc'),
|
|
||||||
('limit', limit)
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await client.get('plays', params=params)
|
response = await client.get("plays", params=params)
|
||||||
|
|
||||||
if not response or 'plays' not in response:
|
if not response or "plays" not in response:
|
||||||
self.logger.info(f'No plays found for game ID {game_id}')
|
self.logger.info(f"No plays found for game ID {game_id}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
plays = [Play.from_api_data(p) for p in response['plays']]
|
plays = [Play.from_api_data(p) for p in response["plays"]]
|
||||||
|
|
||||||
self.logger.debug(f"Retrieved {len(plays)} top plays for game {game_id}")
|
self.logger.debug(f"Retrieved {len(plays)} top plays for game {game_id}")
|
||||||
return plays
|
return plays
|
||||||
@ -129,11 +124,11 @@ class PlayService:
|
|||||||
error_str = str(error)
|
error_str = str(error)
|
||||||
|
|
||||||
# Common error patterns
|
# Common error patterns
|
||||||
if 'Player ID' in error_str and 'not found' in error_str:
|
if "Player ID" in error_str and "not found" in error_str:
|
||||||
return "Invalid player ID in scorecard data. Please check player IDs."
|
return "Invalid player ID in scorecard data. Please check player IDs."
|
||||||
elif 'Game ID' in error_str and 'not found' in error_str:
|
elif "Game ID" in error_str and "not found" in error_str:
|
||||||
return "Game not found in database. Please contact an admin."
|
return "Game not found in database. Please contact an admin."
|
||||||
elif 'validation' in error_str.lower():
|
elif "validation" in error_str.lower():
|
||||||
return f"Data validation error: {error_str}"
|
return f"Data validation error: {error_str}"
|
||||||
else:
|
else:
|
||||||
return f"Error submitting plays: {error_str}"
|
return f"Error submitting plays: {error_str}"
|
||||||
|
|||||||
@ -416,6 +416,8 @@ class SheetsService:
|
|||||||
self.logger.info(f"Read {len(pit_data)} valid pitching decisions")
|
self.logger.info(f"Read {len(pit_data)} valid pitching decisions")
|
||||||
return pit_data
|
return pit_data
|
||||||
|
|
||||||
|
except SheetsException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Failed to read pitching decisions: {e}")
|
self.logger.error(f"Failed to read pitching decisions: {e}")
|
||||||
raise SheetsException("Unable to read pitching decisions") from e
|
raise SheetsException("Unable to read pitching decisions") from e
|
||||||
@ -458,6 +460,8 @@ class SheetsService:
|
|||||||
"home": [int(x) for x in score_table[1]], # [R, H, E]
|
"home": [int(x) for x in score_table[1]], # [R, H, E]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except SheetsException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Failed to read box score: {e}")
|
self.logger.error(f"Failed to read box score: {e}")
|
||||||
raise SheetsException("Unable to read box score") from e
|
raise SheetsException("Unable to read box score") from e
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Trade Builder Service
|
|||||||
|
|
||||||
Extends the TransactionBuilder to support multi-team trades and player exchanges.
|
Extends the TransactionBuilder to support multi-team trades and player exchanges.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Set
|
from typing import Dict, List, Optional, Set
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@ -12,10 +13,14 @@ from config import get_config
|
|||||||
from models.trade import Trade, TradeMove, TradeStatus
|
from models.trade import Trade, TradeMove, TradeStatus
|
||||||
from models.team import Team, RosterType
|
from models.team import Team, RosterType
|
||||||
from models.player import Player
|
from models.player import Player
|
||||||
from services.transaction_builder import TransactionBuilder, RosterValidationResult, TransactionMove
|
from services.transaction_builder import (
|
||||||
|
TransactionBuilder,
|
||||||
|
RosterValidationResult,
|
||||||
|
TransactionMove,
|
||||||
|
)
|
||||||
from services.team_service import team_service
|
from services.team_service import team_service
|
||||||
|
|
||||||
logger = logging.getLogger(f'{__name__}.TradeBuilder')
|
logger = logging.getLogger(f"{__name__}.TradeBuilder")
|
||||||
|
|
||||||
|
|
||||||
class TradeValidationResult:
|
class TradeValidationResult:
|
||||||
@ -52,7 +57,9 @@ class TradeValidationResult:
|
|||||||
suggestions.extend(validation.suggestions)
|
suggestions.extend(validation.suggestions)
|
||||||
return suggestions
|
return suggestions
|
||||||
|
|
||||||
def get_participant_validation(self, team_id: int) -> Optional[RosterValidationResult]:
|
def get_participant_validation(
|
||||||
|
self, team_id: int
|
||||||
|
) -> Optional[RosterValidationResult]:
|
||||||
"""Get validation result for a specific team."""
|
"""Get validation result for a specific team."""
|
||||||
return self.participant_validations.get(team_id)
|
return self.participant_validations.get(team_id)
|
||||||
|
|
||||||
@ -64,7 +71,12 @@ class TradeBuilder:
|
|||||||
Extends the functionality of TransactionBuilder to support trades between teams.
|
Extends the functionality of TransactionBuilder to support trades between teams.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, initiated_by: int, initiating_team: Team, season: int = get_config().sba_season):
|
def __init__(
|
||||||
|
self,
|
||||||
|
initiated_by: int,
|
||||||
|
initiating_team: Team,
|
||||||
|
season: int = get_config().sba_season,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize trade builder.
|
Initialize trade builder.
|
||||||
|
|
||||||
@ -79,7 +91,7 @@ class TradeBuilder:
|
|||||||
status=TradeStatus.DRAFT,
|
status=TradeStatus.DRAFT,
|
||||||
initiated_by=initiated_by,
|
initiated_by=initiated_by,
|
||||||
created_at=datetime.now(timezone.utc).isoformat(),
|
created_at=datetime.now(timezone.utc).isoformat(),
|
||||||
season=season
|
season=season,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the initiating team as first participant
|
# Add the initiating team as first participant
|
||||||
@ -91,7 +103,9 @@ class TradeBuilder:
|
|||||||
# Track which teams have accepted the trade (team_id -> True)
|
# Track which teams have accepted the trade (team_id -> True)
|
||||||
self.accepted_teams: Set[int] = set()
|
self.accepted_teams: Set[int] = set()
|
||||||
|
|
||||||
logger.info(f"TradeBuilder initialized: {self.trade.trade_id} by user {initiated_by} for {initiating_team.abbrev}")
|
logger.info(
|
||||||
|
f"TradeBuilder initialized: {self.trade.trade_id} by user {initiated_by} for {initiating_team.abbrev}"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def trade_id(self) -> str:
|
def trade_id(self) -> str:
|
||||||
@ -127,7 +141,11 @@ class TradeBuilder:
|
|||||||
@property
|
@property
|
||||||
def pending_teams(self) -> List[Team]:
|
def pending_teams(self) -> List[Team]:
|
||||||
"""Get list of teams that haven't accepted yet."""
|
"""Get list of teams that haven't accepted yet."""
|
||||||
return [team for team in self.participating_teams if team.id not in self.accepted_teams]
|
return [
|
||||||
|
team
|
||||||
|
for team in self.participating_teams
|
||||||
|
if team.id not in self.accepted_teams
|
||||||
|
]
|
||||||
|
|
||||||
def accept_trade(self, team_id: int) -> bool:
|
def accept_trade(self, team_id: int) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -140,7 +158,9 @@ class TradeBuilder:
|
|||||||
True if all teams have now accepted, False otherwise
|
True if all teams have now accepted, False otherwise
|
||||||
"""
|
"""
|
||||||
self.accepted_teams.add(team_id)
|
self.accepted_teams.add(team_id)
|
||||||
logger.info(f"Team {team_id} accepted trade {self.trade_id}. Accepted: {len(self.accepted_teams)}/{self.team_count}")
|
logger.info(
|
||||||
|
f"Team {team_id} accepted trade {self.trade_id}. Accepted: {len(self.accepted_teams)}/{self.team_count}"
|
||||||
|
)
|
||||||
return self.all_teams_accepted
|
return self.all_teams_accepted
|
||||||
|
|
||||||
def reject_trade(self) -> None:
|
def reject_trade(self) -> None:
|
||||||
@ -160,7 +180,9 @@ class TradeBuilder:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict mapping team_id to acceptance status (True/False)
|
Dict mapping team_id to acceptance status (True/False)
|
||||||
"""
|
"""
|
||||||
return {team.id: team.id in self.accepted_teams for team in self.participating_teams}
|
return {
|
||||||
|
team.id: team.id in self.accepted_teams for team in self.participating_teams
|
||||||
|
}
|
||||||
|
|
||||||
def has_team_accepted(self, team_id: int) -> bool:
|
def has_team_accepted(self, team_id: int) -> bool:
|
||||||
"""Check if a specific team has accepted."""
|
"""Check if a specific team has accepted."""
|
||||||
@ -184,7 +206,9 @@ class TradeBuilder:
|
|||||||
participant = self.trade.add_participant(team)
|
participant = self.trade.add_participant(team)
|
||||||
|
|
||||||
# Create transaction builder for this team
|
# Create transaction builder for this team
|
||||||
self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season)
|
self._team_builders[team.id] = TransactionBuilder(
|
||||||
|
team, self.trade.initiated_by, self.trade.season
|
||||||
|
)
|
||||||
|
|
||||||
# Register team in secondary index for multi-GM access
|
# Register team in secondary index for multi-GM access
|
||||||
trade_key = f"{self.trade.initiated_by}:trade"
|
trade_key = f"{self.trade.initiated_by}:trade"
|
||||||
@ -209,7 +233,10 @@ class TradeBuilder:
|
|||||||
|
|
||||||
# Check if team has moves - prevent removal if they do
|
# Check if team has moves - prevent removal if they do
|
||||||
if participant.all_moves:
|
if participant.all_moves:
|
||||||
return False, f"{participant.team.abbrev} has moves in this trade and cannot be removed"
|
return (
|
||||||
|
False,
|
||||||
|
f"{participant.team.abbrev} has moves in this trade and cannot be removed",
|
||||||
|
)
|
||||||
|
|
||||||
# Remove team
|
# Remove team
|
||||||
removed = self.trade.remove_participant(team_id)
|
removed = self.trade.remove_participant(team_id)
|
||||||
@ -229,7 +256,7 @@ class TradeBuilder:
|
|||||||
from_team: Team,
|
from_team: Team,
|
||||||
to_team: Team,
|
to_team: Team,
|
||||||
from_roster: RosterType,
|
from_roster: RosterType,
|
||||||
to_roster: RosterType
|
to_roster: RosterType,
|
||||||
) -> tuple[bool, str]:
|
) -> tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Add a player move to the trade.
|
Add a player move to the trade.
|
||||||
@ -246,7 +273,10 @@ class TradeBuilder:
|
|||||||
"""
|
"""
|
||||||
# Validate player is not from Free Agency
|
# Validate player is not from Free Agency
|
||||||
if player.team_id == get_config().free_agent_team_id:
|
if player.team_id == get_config().free_agent_team_id:
|
||||||
return False, f"Cannot add {player.name} from Free Agency. Players must be traded from teams within the organizations involved in the trade."
|
return (
|
||||||
|
False,
|
||||||
|
f"Cannot add {player.name} from Free Agency. Players must be traded from teams within the organizations involved in the trade.",
|
||||||
|
)
|
||||||
|
|
||||||
# Validate player has a valid team assignment
|
# Validate player has a valid team assignment
|
||||||
if not player.team_id:
|
if not player.team_id:
|
||||||
@ -259,7 +289,10 @@ class TradeBuilder:
|
|||||||
|
|
||||||
# Check if player's team is in the same organization as from_team
|
# Check if player's team is in the same organization as from_team
|
||||||
if not player_team.is_same_organization(from_team):
|
if not player_team.is_same_organization(from_team):
|
||||||
return False, f"{player.name} is on {player_team.abbrev}, they are not eligible to be added to the trade."
|
return (
|
||||||
|
False,
|
||||||
|
f"{player.name} is on {player_team.abbrev}, they are not eligible to be added to the trade.",
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure both teams are participating (check by organization for ML authority)
|
# Ensure both teams are participating (check by organization for ML authority)
|
||||||
from_participant = self.trade.get_participant_by_organization(from_team)
|
from_participant = self.trade.get_participant_by_organization(from_team)
|
||||||
@ -274,7 +307,10 @@ class TradeBuilder:
|
|||||||
for participant in self.trade.participants:
|
for participant in self.trade.participants:
|
||||||
for existing_move in participant.all_moves:
|
for existing_move in participant.all_moves:
|
||||||
if existing_move.player.id == player.id:
|
if existing_move.player.id == player.id:
|
||||||
return False, f"{player.name} is already involved in a move in this trade"
|
return (
|
||||||
|
False,
|
||||||
|
f"{player.name} is already involved in a move in this trade",
|
||||||
|
)
|
||||||
|
|
||||||
# Create trade move
|
# Create trade move
|
||||||
trade_move = TradeMove(
|
trade_move = TradeMove(
|
||||||
@ -284,7 +320,7 @@ class TradeBuilder:
|
|||||||
from_team=from_team,
|
from_team=from_team,
|
||||||
to_team=to_team,
|
to_team=to_team,
|
||||||
source_team=from_team,
|
source_team=from_team,
|
||||||
destination_team=to_team
|
destination_team=to_team,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to giving team's moves
|
# Add to giving team's moves
|
||||||
@ -303,7 +339,7 @@ class TradeBuilder:
|
|||||||
from_roster=from_roster,
|
from_roster=from_roster,
|
||||||
to_roster=RosterType.FREE_AGENCY, # Conceptually leaving the org
|
to_roster=RosterType.FREE_AGENCY, # Conceptually leaving the org
|
||||||
from_team=from_team,
|
from_team=from_team,
|
||||||
to_team=None
|
to_team=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move for receiving team (player joining)
|
# Move for receiving team (player joining)
|
||||||
@ -312,19 +348,23 @@ class TradeBuilder:
|
|||||||
from_roster=RosterType.FREE_AGENCY, # Conceptually joining from outside
|
from_roster=RosterType.FREE_AGENCY, # Conceptually joining from outside
|
||||||
to_roster=to_roster,
|
to_roster=to_roster,
|
||||||
from_team=None,
|
from_team=None,
|
||||||
to_team=to_team
|
to_team=to_team,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add moves to respective builders
|
# Add moves to respective builders
|
||||||
# Skip pending transaction check for trades - they have their own validation workflow
|
# Skip pending transaction check for trades - they have their own validation workflow
|
||||||
from_success, from_error = await from_builder.add_move(from_move, check_pending_transactions=False)
|
from_success, from_error = await from_builder.add_move(
|
||||||
|
from_move, check_pending_transactions=False
|
||||||
|
)
|
||||||
if not from_success:
|
if not from_success:
|
||||||
# Remove from trade if builder failed
|
# Remove from trade if builder failed
|
||||||
from_participant.moves_giving.remove(trade_move)
|
from_participant.moves_giving.remove(trade_move)
|
||||||
to_participant.moves_receiving.remove(trade_move)
|
to_participant.moves_receiving.remove(trade_move)
|
||||||
return False, f"Error adding move to {from_team.abbrev}: {from_error}"
|
return False, f"Error adding move to {from_team.abbrev}: {from_error}"
|
||||||
|
|
||||||
to_success, to_error = await to_builder.add_move(to_move, check_pending_transactions=False)
|
to_success, to_error = await to_builder.add_move(
|
||||||
|
to_move, check_pending_transactions=False
|
||||||
|
)
|
||||||
if not to_success:
|
if not to_success:
|
||||||
# Rollback both if second failed
|
# Rollback both if second failed
|
||||||
from_builder.remove_move(player.id)
|
from_builder.remove_move(player.id)
|
||||||
@ -332,15 +372,13 @@ class TradeBuilder:
|
|||||||
to_participant.moves_receiving.remove(trade_move)
|
to_participant.moves_receiving.remove(trade_move)
|
||||||
return False, f"Error adding move to {to_team.abbrev}: {to_error}"
|
return False, f"Error adding move to {to_team.abbrev}: {to_error}"
|
||||||
|
|
||||||
logger.info(f"Added player move to trade {self.trade_id}: {trade_move.description}")
|
logger.info(
|
||||||
|
f"Added player move to trade {self.trade_id}: {trade_move.description}"
|
||||||
|
)
|
||||||
return True, ""
|
return True, ""
|
||||||
|
|
||||||
async def add_supplementary_move(
|
async def add_supplementary_move(
|
||||||
self,
|
self, team: Team, player: Player, from_roster: RosterType, to_roster: RosterType
|
||||||
team: Team,
|
|
||||||
player: Player,
|
|
||||||
from_roster: RosterType,
|
|
||||||
to_roster: RosterType
|
|
||||||
) -> tuple[bool, str]:
|
) -> tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Add a supplementary move (internal organizational move) for roster legality.
|
Add a supplementary move (internal organizational move) for roster legality.
|
||||||
@ -366,7 +404,7 @@ class TradeBuilder:
|
|||||||
from_team=team,
|
from_team=team,
|
||||||
to_team=team,
|
to_team=team,
|
||||||
source_team=team,
|
source_team=team,
|
||||||
destination_team=team
|
destination_team=team,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to participant's supplementary moves
|
# Add to participant's supplementary moves
|
||||||
@ -379,16 +417,20 @@ class TradeBuilder:
|
|||||||
from_roster=from_roster,
|
from_roster=from_roster,
|
||||||
to_roster=to_roster,
|
to_roster=to_roster,
|
||||||
from_team=team,
|
from_team=team,
|
||||||
to_team=team
|
to_team=team,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Skip pending transaction check for trade supplementary moves
|
# Skip pending transaction check for trade supplementary moves
|
||||||
success, error = await builder.add_move(trans_move, check_pending_transactions=False)
|
success, error = await builder.add_move(
|
||||||
|
trans_move, check_pending_transactions=False
|
||||||
|
)
|
||||||
if not success:
|
if not success:
|
||||||
participant.supplementary_moves.remove(supp_move)
|
participant.supplementary_moves.remove(supp_move)
|
||||||
return False, error
|
return False, error
|
||||||
|
|
||||||
logger.info(f"Added supplementary move for {team.abbrev}: {supp_move.description}")
|
logger.info(
|
||||||
|
f"Added supplementary move for {team.abbrev}: {supp_move.description}"
|
||||||
|
)
|
||||||
return True, ""
|
return True, ""
|
||||||
|
|
||||||
async def remove_move(self, player_id: int) -> tuple[bool, str]:
|
async def remove_move(self, player_id: int) -> tuple[bool, str]:
|
||||||
@ -432,21 +474,41 @@ class TradeBuilder:
|
|||||||
for builder in self._team_builders.values():
|
for builder in self._team_builders.values():
|
||||||
builder.remove_move(player_id)
|
builder.remove_move(player_id)
|
||||||
|
|
||||||
logger.info(f"Removed move from trade {self.trade_id}: {removed_move.description}")
|
logger.info(
|
||||||
|
f"Removed move from trade {self.trade_id}: {removed_move.description}"
|
||||||
|
)
|
||||||
return True, ""
|
return True, ""
|
||||||
|
|
||||||
async def validate_trade(self, next_week: Optional[int] = None) -> TradeValidationResult:
|
async def validate_trade(
|
||||||
|
self, next_week: Optional[int] = None
|
||||||
|
) -> TradeValidationResult:
|
||||||
"""
|
"""
|
||||||
Validate the entire trade including all teams' roster legality.
|
Validate the entire trade including all teams' roster legality.
|
||||||
|
|
||||||
|
Validates against next week's projected roster (current roster + pending
|
||||||
|
transactions), matching the behavior of /dropadd validation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
next_week: Week to validate for (optional)
|
next_week: Week to validate for (auto-fetched if not provided)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TradeValidationResult with comprehensive validation
|
TradeValidationResult with comprehensive validation
|
||||||
"""
|
"""
|
||||||
result = TradeValidationResult()
|
result = TradeValidationResult()
|
||||||
|
|
||||||
|
# Auto-fetch next week so validation includes pending transactions
|
||||||
|
if next_week is None:
|
||||||
|
try:
|
||||||
|
from services.league_service import league_service
|
||||||
|
|
||||||
|
current_state = await league_service.get_current_state()
|
||||||
|
next_week = (current_state.week + 1) if current_state else 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not determine next week for trade validation: {e}"
|
||||||
|
)
|
||||||
|
next_week = None
|
||||||
|
|
||||||
# Validate trade structure
|
# Validate trade structure
|
||||||
is_balanced, balance_errors = self.trade.validate_trade_balance()
|
is_balanced, balance_errors = self.trade.validate_trade_balance()
|
||||||
if not is_balanced:
|
if not is_balanced:
|
||||||
@ -472,13 +534,17 @@ class TradeBuilder:
|
|||||||
if self.team_count < 2:
|
if self.team_count < 2:
|
||||||
result.trade_suggestions.append("Add another team to create a trade")
|
result.trade_suggestions.append("Add another team to create a trade")
|
||||||
|
|
||||||
logger.debug(f"Trade validation for {self.trade_id}: Legal={result.is_legal}, Errors={len(result.all_errors)}")
|
logger.debug(
|
||||||
|
f"Trade validation for {self.trade_id}: Legal={result.is_legal}, Errors={len(result.all_errors)}"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_or_create_builder(self, team: Team) -> TransactionBuilder:
|
def _get_or_create_builder(self, team: Team) -> TransactionBuilder:
|
||||||
"""Get or create a transaction builder for a team."""
|
"""Get or create a transaction builder for a team."""
|
||||||
if team.id not in self._team_builders:
|
if team.id not in self._team_builders:
|
||||||
self._team_builders[team.id] = TransactionBuilder(team, self.trade.initiated_by, self.trade.season)
|
self._team_builders[team.id] = TransactionBuilder(
|
||||||
|
team, self.trade.initiated_by, self.trade.season
|
||||||
|
)
|
||||||
return self._team_builders[team.id]
|
return self._team_builders[team.id]
|
||||||
|
|
||||||
def clear_trade(self) -> None:
|
def clear_trade(self) -> None:
|
||||||
@ -592,4 +658,4 @@ def clear_trade_builder_by_team(team_id: int) -> bool:
|
|||||||
|
|
||||||
def get_active_trades() -> Dict[str, TradeBuilder]:
|
def get_active_trades() -> Dict[str, TradeBuilder]:
|
||||||
"""Get all active trade builders (for debugging/admin purposes)."""
|
"""Get all active trade builders (for debugging/admin purposes)."""
|
||||||
return _active_trade_builders.copy()
|
return _active_trade_builders.copy()
|
||||||
|
|||||||
@ -248,7 +248,7 @@ class TransactionService(BaseService[Transaction]):
|
|||||||
|
|
||||||
# POST batch to API
|
# POST batch to API
|
||||||
client = await self.get_client()
|
client = await self.get_client()
|
||||||
response = await client.post(self.endpoint, data=batch_data)
|
response = await client.post(f"{self.endpoint}/", data=batch_data)
|
||||||
|
|
||||||
# API returns a string like "2 transactions have been added"
|
# API returns a string like "2 transactions have been added"
|
||||||
# We need to return the original Transaction objects (they won't have IDs assigned by API)
|
# We need to return the original Transaction objects (they won't have IDs assigned by API)
|
||||||
|
|||||||
@ -1,18 +1,24 @@
|
|||||||
"""
|
"""
|
||||||
API client tests using aioresponses for clean HTTP mocking
|
API client tests using aioresponses for clean HTTP mocking
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
from aioresponses import aioresponses
|
from aioresponses import aioresponses
|
||||||
|
|
||||||
from api.client import APIClient, get_api_client, get_global_client, cleanup_global_client
|
from api.client import (
|
||||||
|
APIClient,
|
||||||
|
get_api_client,
|
||||||
|
get_global_client,
|
||||||
|
cleanup_global_client,
|
||||||
|
)
|
||||||
from exceptions import APIException
|
from exceptions import APIException
|
||||||
|
|
||||||
|
|
||||||
class TestAPIClientWithAioresponses:
|
class TestAPIClientWithAioresponses:
|
||||||
"""Test API client with aioresponses for HTTP mocking."""
|
"""Test API client with aioresponses for HTTP mocking."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(self):
|
def mock_config(self):
|
||||||
"""Mock configuration for testing."""
|
"""Mock configuration for testing."""
|
||||||
@ -20,66 +26,57 @@ class TestAPIClientWithAioresponses:
|
|||||||
config.db_url = "https://api.example.com"
|
config.db_url = "https://api.example.com"
|
||||||
config.api_token = "test-token"
|
config.api_token = "test-token"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def api_client(self, mock_config):
|
def api_client(self, mock_config):
|
||||||
"""Create API client with mocked config."""
|
"""Create API client with mocked config."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
return APIClient()
|
return APIClient()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_request_success(self, api_client):
|
async def test_get_request_success(self, api_client):
|
||||||
"""Test successful GET request."""
|
"""Test successful GET request."""
|
||||||
expected_data = {"id": 1, "name": "Test Player"}
|
expected_data = {"id": 1, "name": "Test Player"}
|
||||||
|
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players/1",
|
"https://api.example.com/v3/players/1",
|
||||||
payload=expected_data,
|
payload=expected_data,
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await api_client.get("players", object_id=1)
|
result = await api_client.get("players", object_id=1)
|
||||||
|
|
||||||
assert result == expected_data
|
assert result == expected_data
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_request_404(self, api_client):
|
async def test_get_request_404(self, api_client):
|
||||||
"""Test GET request returning 404."""
|
"""Test GET request returning 404."""
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get("https://api.example.com/v3/players/999", status=404)
|
||||||
"https://api.example.com/v3/players/999",
|
|
||||||
status=404
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await api_client.get("players", object_id=999)
|
result = await api_client.get("players", object_id=999)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_request_401_auth_error(self, api_client):
|
async def test_get_request_401_auth_error(self, api_client):
|
||||||
"""Test GET request with authentication error."""
|
"""Test GET request with authentication error."""
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get("https://api.example.com/v3/players", status=401)
|
||||||
"https://api.example.com/v3/players",
|
|
||||||
status=401
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(APIException, match="Authentication failed"):
|
with pytest.raises(APIException, match="Authentication failed"):
|
||||||
await api_client.get("players")
|
await api_client.get("players")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_request_403_forbidden(self, api_client):
|
async def test_get_request_403_forbidden(self, api_client):
|
||||||
"""Test GET request with forbidden error."""
|
"""Test GET request with forbidden error."""
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get("https://api.example.com/v3/players", status=403)
|
||||||
"https://api.example.com/v3/players",
|
|
||||||
status=403
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(APIException, match="Access forbidden"):
|
with pytest.raises(APIException, match="Access forbidden"):
|
||||||
await api_client.get("players")
|
await api_client.get("players")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_request_500_server_error(self, api_client):
|
async def test_get_request_500_server_error(self, api_client):
|
||||||
"""Test GET request with server error."""
|
"""Test GET request with server error."""
|
||||||
@ -87,135 +84,127 @@ class TestAPIClientWithAioresponses:
|
|||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players",
|
"https://api.example.com/v3/players",
|
||||||
status=500,
|
status=500,
|
||||||
body="Internal Server Error"
|
body="Internal Server Error",
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(APIException, match="API request failed with status 500"):
|
with pytest.raises(
|
||||||
|
APIException, match="API request failed with status 500"
|
||||||
|
):
|
||||||
await api_client.get("players")
|
await api_client.get("players")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_request_with_params(self, api_client):
|
async def test_get_request_with_params(self, api_client):
|
||||||
"""Test GET request with query parameters."""
|
"""Test GET request with query parameters."""
|
||||||
expected_data = {"count": 2, "players": [{"id": 1}, {"id": 2}]}
|
expected_data = {"count": 2, "players": [{"id": 1}, {"id": 2}]}
|
||||||
|
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players?team_id=5&season=12",
|
"https://api.example.com/v3/players?team_id=5&season=12",
|
||||||
payload=expected_data,
|
payload=expected_data,
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await api_client.get("players", params=[("team_id", "5"), ("season", "12")])
|
result = await api_client.get(
|
||||||
|
"players", params=[("team_id", "5"), ("season", "12")]
|
||||||
|
)
|
||||||
|
|
||||||
assert result == expected_data
|
assert result == expected_data
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_post_request_success(self, api_client):
|
async def test_post_request_success(self, api_client):
|
||||||
"""Test successful POST request."""
|
"""Test successful POST request."""
|
||||||
input_data = {"name": "New Player", "position": "C"}
|
input_data = {"name": "New Player", "position": "C"}
|
||||||
expected_response = {"id": 1, "name": "New Player", "position": "C"}
|
expected_response = {"id": 1, "name": "New Player", "position": "C"}
|
||||||
|
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.post(
|
m.post(
|
||||||
"https://api.example.com/v3/players",
|
"https://api.example.com/v3/players",
|
||||||
payload=expected_response,
|
payload=expected_response,
|
||||||
status=201
|
status=201,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await api_client.post("players", input_data)
|
result = await api_client.post("players", input_data)
|
||||||
|
|
||||||
assert result == expected_response
|
assert result == expected_response
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_post_request_400_error(self, api_client):
|
async def test_post_request_400_error(self, api_client):
|
||||||
"""Test POST request with validation error."""
|
"""Test POST request with validation error."""
|
||||||
input_data = {"invalid": "data"}
|
input_data = {"invalid": "data"}
|
||||||
|
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.post(
|
m.post(
|
||||||
"https://api.example.com/v3/players",
|
"https://api.example.com/v3/players", status=400, body="Invalid data"
|
||||||
status=400,
|
|
||||||
body="Invalid data"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(APIException, match="POST request failed with status 400"):
|
with pytest.raises(
|
||||||
|
APIException, match="POST request failed with status 400"
|
||||||
|
):
|
||||||
await api_client.post("players", input_data)
|
await api_client.post("players", input_data)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_put_request_success(self, api_client):
|
async def test_put_request_success(self, api_client):
|
||||||
"""Test successful PUT request."""
|
"""Test successful PUT request."""
|
||||||
update_data = {"name": "Updated Player"}
|
update_data = {"name": "Updated Player"}
|
||||||
expected_response = {"id": 1, "name": "Updated Player"}
|
expected_response = {"id": 1, "name": "Updated Player"}
|
||||||
|
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.put(
|
m.put(
|
||||||
"https://api.example.com/v3/players/1",
|
"https://api.example.com/v3/players/1",
|
||||||
payload=expected_response,
|
payload=expected_response,
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await api_client.put("players", update_data, object_id=1)
|
result = await api_client.put("players", update_data, object_id=1)
|
||||||
|
|
||||||
assert result == expected_response
|
assert result == expected_response
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_put_request_404(self, api_client):
|
async def test_put_request_404(self, api_client):
|
||||||
"""Test PUT request with 404."""
|
"""Test PUT request with 404."""
|
||||||
update_data = {"name": "Updated Player"}
|
update_data = {"name": "Updated Player"}
|
||||||
|
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.put(
|
m.put("https://api.example.com/v3/players/999", status=404)
|
||||||
"https://api.example.com/v3/players/999",
|
|
||||||
status=404
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await api_client.put("players", update_data, object_id=999)
|
result = await api_client.put("players", update_data, object_id=999)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_request_success(self, api_client):
|
async def test_delete_request_success(self, api_client):
|
||||||
"""Test successful DELETE request."""
|
"""Test successful DELETE request."""
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.delete(
|
m.delete("https://api.example.com/v3/players/1", status=204)
|
||||||
"https://api.example.com/v3/players/1",
|
|
||||||
status=204
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await api_client.delete("players", object_id=1)
|
result = await api_client.delete("players", object_id=1)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_request_404(self, api_client):
|
async def test_delete_request_404(self, api_client):
|
||||||
"""Test DELETE request with 404."""
|
"""Test DELETE request with 404."""
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.delete(
|
m.delete("https://api.example.com/v3/players/999", status=404)
|
||||||
"https://api.example.com/v3/players/999",
|
|
||||||
status=404
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await api_client.delete("players", object_id=999)
|
result = await api_client.delete("players", object_id=999)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_request_200_success(self, api_client):
|
async def test_delete_request_200_success(self, api_client):
|
||||||
"""Test DELETE request with 200 success."""
|
"""Test DELETE request with 200 success."""
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.delete(
|
m.delete("https://api.example.com/v3/players/1", status=200)
|
||||||
"https://api.example.com/v3/players/1",
|
|
||||||
status=200
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await api_client.delete("players", object_id=1)
|
result = await api_client.delete("players", object_id=1)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
class TestAPIClientHelpers:
|
class TestAPIClientHelpers:
|
||||||
"""Test API client helper functions."""
|
"""Test API client helper functions."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(self):
|
def mock_config(self):
|
||||||
"""Mock configuration for testing."""
|
"""Mock configuration for testing."""
|
||||||
@ -223,49 +212,49 @@ class TestAPIClientHelpers:
|
|||||||
config.db_url = "https://api.example.com"
|
config.db_url = "https://api.example.com"
|
||||||
config.api_token = "test-token"
|
config.api_token = "test-token"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_api_client_context_manager(self, mock_config):
|
async def test_get_api_client_context_manager(self, mock_config):
|
||||||
"""Test get_api_client context manager."""
|
"""Test get_api_client context manager."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/test",
|
"https://api.example.com/v3/test",
|
||||||
payload={"success": True},
|
payload={"success": True},
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with get_api_client() as client:
|
async with get_api_client() as client:
|
||||||
assert isinstance(client, APIClient)
|
assert isinstance(client, APIClient)
|
||||||
result = await client.get("test")
|
result = await client.get("test")
|
||||||
assert result == {"success": True}
|
assert result == {"success": True}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_global_client_management(self, mock_config):
|
async def test_global_client_management(self, mock_config):
|
||||||
"""Test global client getter and cleanup."""
|
"""Test global client getter and cleanup."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
# Get global client
|
# Get global client
|
||||||
client1 = await get_global_client()
|
client1 = await get_global_client()
|
||||||
client2 = await get_global_client()
|
client2 = await get_global_client()
|
||||||
|
|
||||||
# Should return same instance
|
# Should return same instance
|
||||||
assert client1 is client2
|
assert client1 is client2
|
||||||
assert isinstance(client1, APIClient)
|
assert isinstance(client1, APIClient)
|
||||||
|
|
||||||
# Test cleanup
|
# Test cleanup
|
||||||
await cleanup_global_client()
|
await cleanup_global_client()
|
||||||
|
|
||||||
# New client should be different instance
|
# New client should be different instance
|
||||||
client3 = await get_global_client()
|
client3 = await get_global_client()
|
||||||
assert client3 is not client1
|
assert client3 is not client1
|
||||||
|
|
||||||
# Clean up for other tests
|
# Clean up for other tests
|
||||||
await cleanup_global_client()
|
await cleanup_global_client()
|
||||||
|
|
||||||
|
|
||||||
class TestIntegrationScenarios:
|
class TestIntegrationScenarios:
|
||||||
"""Test realistic integration scenarios."""
|
"""Test realistic integration scenarios."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(self):
|
def mock_config(self):
|
||||||
"""Mock configuration for testing."""
|
"""Mock configuration for testing."""
|
||||||
@ -273,11 +262,11 @@ class TestIntegrationScenarios:
|
|||||||
config.db_url = "https://api.example.com"
|
config.db_url = "https://api.example.com"
|
||||||
config.api_token = "test-token"
|
config.api_token = "test-token"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_player_retrieval_with_team_lookup(self, mock_config):
|
async def test_player_retrieval_with_team_lookup(self, mock_config):
|
||||||
"""Test realistic scenario: get player with team data."""
|
"""Test realistic scenario: get player with team data."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
# Mock player data response
|
# Mock player data response
|
||||||
player_data = {
|
player_data = {
|
||||||
@ -287,43 +276,41 @@ class TestIntegrationScenarios:
|
|||||||
"season": 12,
|
"season": 12,
|
||||||
"team_id": 5,
|
"team_id": 5,
|
||||||
"image": "https://example.com/player1.jpg",
|
"image": "https://example.com/player1.jpg",
|
||||||
"pos_1": "C"
|
"pos_1": "C",
|
||||||
}
|
}
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players/1",
|
"https://api.example.com/v3/players/1",
|
||||||
payload=player_data,
|
payload=player_data,
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock team data response
|
# Mock team data response
|
||||||
team_data = {
|
team_data = {
|
||||||
"id": 5,
|
"id": 5,
|
||||||
"abbrev": "TST",
|
"abbrev": "TST",
|
||||||
"sname": "Test Team",
|
"sname": "Test Team",
|
||||||
"lname": "Test Team Full Name",
|
"lname": "Test Team Full Name",
|
||||||
"season": 12
|
"season": 12,
|
||||||
}
|
}
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/teams/5",
|
"https://api.example.com/v3/teams/5", payload=team_data, status=200
|
||||||
payload=team_data,
|
|
||||||
status=200
|
|
||||||
)
|
)
|
||||||
|
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
|
|
||||||
# Get player
|
# Get player
|
||||||
player = await client.get("players", object_id=1)
|
player = await client.get("players", object_id=1)
|
||||||
assert player["name"] == "Test Player"
|
assert player["name"] == "Test Player"
|
||||||
assert player["team_id"] == 5
|
assert player["team_id"] == 5
|
||||||
|
|
||||||
# Get team for player
|
# Get team for player
|
||||||
team = await client.get("teams", object_id=player["team_id"])
|
team = await client.get("teams", object_id=player["team_id"])
|
||||||
assert team["sname"] == "Test Team"
|
assert team["sname"] == "Test Team"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_api_response_format_handling(self, mock_config):
|
async def test_api_response_format_handling(self, mock_config):
|
||||||
"""Test handling of the API's count + list format."""
|
"""Test handling of the API's count + list format."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
# Mock API response with count format
|
# Mock API response with count format
|
||||||
api_response = {
|
api_response = {
|
||||||
@ -336,7 +323,7 @@ class TestIntegrationScenarios:
|
|||||||
"season": 12,
|
"season": 12,
|
||||||
"team_id": 5,
|
"team_id": 5,
|
||||||
"image": "https://example.com/player1.jpg",
|
"image": "https://example.com/player1.jpg",
|
||||||
"pos_1": "C"
|
"pos_1": "C",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2,
|
"id": 2,
|
||||||
@ -345,93 +332,93 @@ class TestIntegrationScenarios:
|
|||||||
"season": 12,
|
"season": 12,
|
||||||
"team_id": 6,
|
"team_id": 6,
|
||||||
"image": "https://example.com/player2.jpg",
|
"image": "https://example.com/player2.jpg",
|
||||||
"pos_1": "1B"
|
"pos_1": "1B",
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players?team_id=5",
|
"https://api.example.com/v3/players?team_id=5",
|
||||||
payload=api_response,
|
payload=api_response,
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
result = await client.get("players", params=[("team_id", "5")])
|
result = await client.get("players", params=[("team_id", "5")])
|
||||||
|
|
||||||
assert result["count"] == 25
|
assert result["count"] == 25
|
||||||
assert len(result["players"]) == 2
|
assert len(result["players"]) == 2
|
||||||
assert result["players"][0]["name"] == "Player 1"
|
assert result["players"][0]["name"] == "Player 1"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_error_recovery_scenarios(self, mock_config):
|
async def test_error_recovery_scenarios(self, mock_config):
|
||||||
"""Test error handling and recovery."""
|
"""Test error handling and recovery."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
# First request fails with 500
|
# First request fails with 500
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players/1",
|
"https://api.example.com/v3/players/1",
|
||||||
status=500,
|
status=500,
|
||||||
body="Internal Server Error"
|
body="Internal Server Error",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Second request succeeds
|
# Second request succeeds
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players/2",
|
"https://api.example.com/v3/players/2",
|
||||||
payload={"id": 2, "name": "Working Player"},
|
payload={"id": 2, "name": "Working Player"},
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
|
|
||||||
# First request should raise exception
|
# First request should raise exception
|
||||||
with pytest.raises(APIException, match="API request failed"):
|
with pytest.raises(APIException, match="API request failed"):
|
||||||
await client.get("players", object_id=1)
|
await client.get("players", object_id=1)
|
||||||
|
|
||||||
# Second request should work fine
|
# Second request should work fine
|
||||||
result = await client.get("players", object_id=2)
|
result = await client.get("players", object_id=2)
|
||||||
assert result["name"] == "Working Player"
|
assert result["name"] == "Working Player"
|
||||||
|
|
||||||
# Client should still be functional
|
# Client should still be functional
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_concurrent_requests(self, mock_config):
|
async def test_concurrent_requests(self, mock_config):
|
||||||
"""Test multiple concurrent requests."""
|
"""Test multiple concurrent requests."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
# Mock multiple endpoints
|
# Mock multiple endpoints
|
||||||
for i in range(1, 4):
|
for i in range(1, 4):
|
||||||
m.get(
|
m.get(
|
||||||
f"https://api.example.com/v3/players/{i}",
|
f"https://api.example.com/v3/players/{i}",
|
||||||
payload={"id": i, "name": f"Player {i}"},
|
payload={"id": i, "name": f"Player {i}"},
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
|
|
||||||
# Make concurrent requests
|
# Make concurrent requests
|
||||||
tasks = [
|
tasks = [
|
||||||
client.get("players", object_id=1),
|
client.get("players", object_id=1),
|
||||||
client.get("players", object_id=2),
|
client.get("players", object_id=2),
|
||||||
client.get("players", object_id=3)
|
client.get("players", object_id=3),
|
||||||
]
|
]
|
||||||
|
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
assert results[0]["name"] == "Player 1"
|
assert results[0]["name"] == "Player 1"
|
||||||
assert results[1]["name"] == "Player 2"
|
assert results[1]["name"] == "Player 2"
|
||||||
assert results[2]["name"] == "Player 3"
|
assert results[2]["name"] == "Player 3"
|
||||||
|
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
class TestAPIClientCoverageExtras:
|
class TestAPIClientCoverageExtras:
|
||||||
"""Additional coverage tests for API client edge cases."""
|
"""Additional coverage tests for API client edge cases."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(self):
|
def mock_config(self):
|
||||||
"""Mock configuration for testing."""
|
"""Mock configuration for testing."""
|
||||||
@ -439,98 +426,104 @@ class TestAPIClientCoverageExtras:
|
|||||||
config.db_url = "https://api.example.com"
|
config.db_url = "https://api.example.com"
|
||||||
config.api_token = "test-token"
|
config.api_token = "test-token"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_global_client_cleanup_when_none(self):
|
async def test_global_client_cleanup_when_none(self):
|
||||||
"""Test cleanup when no global client exists."""
|
"""Test cleanup when no global client exists."""
|
||||||
# Ensure no global client exists
|
# Ensure no global client exists
|
||||||
await cleanup_global_client()
|
await cleanup_global_client()
|
||||||
|
|
||||||
# Should not raise error
|
# Should not raise error
|
||||||
await cleanup_global_client()
|
await cleanup_global_client()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_url_building_edge_cases(self, mock_config):
|
async def test_url_building_edge_cases(self, mock_config):
|
||||||
"""Test URL building with various edge cases."""
|
"""Test URL building with various edge cases."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
|
|
||||||
# Test trailing slash handling
|
# Test trailing slash handling
|
||||||
client.base_url = "https://api.example.com/"
|
client.base_url = "https://api.example.com/"
|
||||||
url = client._build_url("players")
|
url = client._build_url("players")
|
||||||
assert url == "https://api.example.com/v3/players"
|
assert url == "https://api.example.com/v3/players"
|
||||||
assert "//" not in url.replace("https://", "")
|
assert "//" not in url.replace("https://", "")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_parameter_handling_edge_cases(self, mock_config):
|
async def test_parameter_handling_edge_cases(self, mock_config):
|
||||||
"""Test parameter handling with various scenarios."""
|
"""Test parameter handling with various scenarios."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
|
|
||||||
# Test with existing query string
|
# Test with existing query string
|
||||||
url = client._add_params("https://example.com/api?existing=true", [("new", "param")])
|
url = client._add_params(
|
||||||
|
"https://example.com/api?existing=true", [("new", "param")]
|
||||||
|
)
|
||||||
assert url == "https://example.com/api?existing=true&new=param"
|
assert url == "https://example.com/api?existing=true&new=param"
|
||||||
|
|
||||||
# Test with no parameters
|
# Test with no parameters
|
||||||
url = client._add_params("https://example.com/api")
|
url = client._add_params("https://example.com/api")
|
||||||
assert url == "https://example.com/api"
|
assert url == "https://example.com/api"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_timeout_error_handling(self, mock_config):
|
async def test_timeout_error_handling(self, mock_config):
|
||||||
"""Test timeout error handling using aioresponses."""
|
"""Test timeout error handling using aioresponses."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
|
|
||||||
# Test timeout using aioresponses exception parameter
|
# Test timeout using aioresponses exception parameter
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players",
|
"https://api.example.com/v3/players",
|
||||||
exception=asyncio.TimeoutError("Request timed out")
|
exception=asyncio.TimeoutError("Request timed out"),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(APIException, match="API call failed.*Request timed out"):
|
with pytest.raises(
|
||||||
|
APIException, match="API call failed.*Request timed out"
|
||||||
|
):
|
||||||
await client.get("players")
|
await client.get("players")
|
||||||
|
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generic_exception_handling(self, mock_config):
|
async def test_generic_exception_handling(self, mock_config):
|
||||||
"""Test generic exception handling."""
|
"""Test generic exception handling."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
|
|
||||||
# Test generic exception
|
# Test generic exception
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players",
|
"https://api.example.com/v3/players",
|
||||||
exception=Exception("Generic error")
|
exception=Exception("Generic error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(APIException, match="API call failed.*Generic error"):
|
with pytest.raises(
|
||||||
|
APIException, match="API call failed.*Generic error"
|
||||||
|
):
|
||||||
await client.get("players")
|
await client.get("players")
|
||||||
|
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_session_closed_handling(self, mock_config):
|
async def test_session_closed_handling(self, mock_config):
|
||||||
"""Test handling of closed session."""
|
"""Test handling of closed session."""
|
||||||
with patch('api.client.get_config', return_value=mock_config):
|
with patch("api.client.get_config", return_value=mock_config):
|
||||||
# Test that the client recreates session when needed
|
# Test that the client recreates session when needed
|
||||||
with aioresponses() as m:
|
with aioresponses() as m:
|
||||||
m.get(
|
m.get(
|
||||||
"https://api.example.com/v3/players",
|
"https://api.example.com/v3/players",
|
||||||
payload={"success": True},
|
payload={"success": True},
|
||||||
status=200
|
status=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
client = APIClient()
|
client = APIClient()
|
||||||
|
|
||||||
# Close the session manually
|
# Close the session manually
|
||||||
await client._ensure_session()
|
await client._ensure_session()
|
||||||
await client._session.close()
|
await client._session.close()
|
||||||
|
|
||||||
# Client should recreate session and work fine
|
# Client should recreate session and work fine
|
||||||
result = await client.get("players")
|
result = await client.get("players")
|
||||||
assert result == {"success": True}
|
assert result == {"success": True}
|
||||||
|
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Tests for BaseService functionality
|
Tests for BaseService functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
@ -10,6 +11,7 @@ from models.base import SBABaseModel
|
|||||||
|
|
||||||
class MockModel(SBABaseModel):
|
class MockModel(SBABaseModel):
|
||||||
"""Mock model for testing BaseService."""
|
"""Mock model for testing BaseService."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
value: int = 100
|
value: int = 100
|
||||||
@ -17,240 +19,229 @@ class MockModel(SBABaseModel):
|
|||||||
|
|
||||||
class TestBaseService:
|
class TestBaseService:
|
||||||
"""Test BaseService functionality."""
|
"""Test BaseService functionality."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_client(self):
|
def mock_client(self):
|
||||||
"""Mock API client."""
|
"""Mock API client."""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def base_service(self, mock_client):
|
def base_service(self, mock_client):
|
||||||
"""Create BaseService instance for testing."""
|
"""Create BaseService instance for testing."""
|
||||||
service = BaseService(MockModel, 'mocks', client=mock_client)
|
service = BaseService(MockModel, "mocks", client=mock_client)
|
||||||
return service
|
return service
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_init(self):
|
async def test_init(self):
|
||||||
"""Test service initialization."""
|
"""Test service initialization."""
|
||||||
service = BaseService(MockModel, 'test_endpoint')
|
service = BaseService(MockModel, "test_endpoint")
|
||||||
assert service.model_class == MockModel
|
assert service.model_class == MockModel
|
||||||
assert service.endpoint == 'test_endpoint'
|
assert service.endpoint == "test_endpoint"
|
||||||
assert service._client is None
|
assert service._client is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_by_id_success(self, base_service, mock_client):
|
async def test_get_by_id_success(self, base_service, mock_client):
|
||||||
"""Test successful get_by_id."""
|
"""Test successful get_by_id."""
|
||||||
mock_data = {'id': 1, 'name': 'Test', 'value': 200}
|
mock_data = {"id": 1, "name": "Test", "value": 200}
|
||||||
mock_client.get.return_value = mock_data
|
mock_client.get.return_value = mock_data
|
||||||
|
|
||||||
result = await base_service.get_by_id(1)
|
result = await base_service.get_by_id(1)
|
||||||
|
|
||||||
assert isinstance(result, MockModel)
|
assert isinstance(result, MockModel)
|
||||||
assert result.id == 1
|
assert result.id == 1
|
||||||
assert result.name == 'Test'
|
assert result.name == "Test"
|
||||||
assert result.value == 200
|
assert result.value == 200
|
||||||
mock_client.get.assert_called_once_with('mocks', object_id=1)
|
mock_client.get.assert_called_once_with("mocks", object_id=1)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_by_id_not_found(self, base_service, mock_client):
|
async def test_get_by_id_not_found(self, base_service, mock_client):
|
||||||
"""Test get_by_id when object not found."""
|
"""Test get_by_id when object not found."""
|
||||||
mock_client.get.return_value = None
|
mock_client.get.return_value = None
|
||||||
|
|
||||||
result = await base_service.get_by_id(999)
|
result = await base_service.get_by_id(999)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
mock_client.get.assert_called_once_with('mocks', object_id=999)
|
mock_client.get.assert_called_once_with("mocks", object_id=999)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_all_with_count(self, base_service, mock_client):
|
async def test_get_all_with_count(self, base_service, mock_client):
|
||||||
"""Test get_all with count response format."""
|
"""Test get_all with count response format."""
|
||||||
mock_data = {
|
mock_data = {
|
||||||
'count': 2,
|
"count": 2,
|
||||||
'mocks': [
|
"mocks": [
|
||||||
{'id': 1, 'name': 'Test1', 'value': 100},
|
{"id": 1, "name": "Test1", "value": 100},
|
||||||
{'id': 2, 'name': 'Test2', 'value': 200}
|
{"id": 2, "name": "Test2", "value": 200},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
mock_client.get.return_value = mock_data
|
mock_client.get.return_value = mock_data
|
||||||
|
|
||||||
result, count = await base_service.get_all()
|
result, count = await base_service.get_all()
|
||||||
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert count == 2
|
assert count == 2
|
||||||
assert all(isinstance(item, MockModel) for item in result)
|
assert all(isinstance(item, MockModel) for item in result)
|
||||||
mock_client.get.assert_called_once_with('mocks', params=None)
|
mock_client.get.assert_called_once_with("mocks", params=None)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_all_items_convenience(self, base_service, mock_client):
|
async def test_get_all_items_convenience(self, base_service, mock_client):
|
||||||
"""Test get_all_items convenience method."""
|
"""Test get_all_items convenience method."""
|
||||||
mock_data = {
|
mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]}
|
||||||
'count': 1,
|
|
||||||
'mocks': [{'id': 1, 'name': 'Test', 'value': 100}]
|
|
||||||
}
|
|
||||||
mock_client.get.return_value = mock_data
|
mock_client.get.return_value = mock_data
|
||||||
|
|
||||||
result = await base_service.get_all_items()
|
result = await base_service.get_all_items()
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert isinstance(result[0], MockModel)
|
assert isinstance(result[0], MockModel)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_success(self, base_service, mock_client):
|
async def test_create_success(self, base_service, mock_client):
|
||||||
"""Test successful object creation."""
|
"""Test successful object creation."""
|
||||||
input_data = {'name': 'New Item', 'value': 300}
|
input_data = {"name": "New Item", "value": 300}
|
||||||
response_data = {'id': 3, 'name': 'New Item', 'value': 300}
|
response_data = {"id": 3, "name": "New Item", "value": 300}
|
||||||
mock_client.post.return_value = response_data
|
mock_client.post.return_value = response_data
|
||||||
|
|
||||||
result = await base_service.create(input_data)
|
result = await base_service.create(input_data)
|
||||||
|
|
||||||
assert isinstance(result, MockModel)
|
assert isinstance(result, MockModel)
|
||||||
assert result.id == 3
|
assert result.id == 3
|
||||||
assert result.name == 'New Item'
|
assert result.name == "New Item"
|
||||||
mock_client.post.assert_called_once_with('mocks', input_data)
|
mock_client.post.assert_called_once_with("mocks/", input_data)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_success(self, base_service, mock_client):
|
async def test_update_success(self, base_service, mock_client):
|
||||||
"""Test successful object update."""
|
"""Test successful object update."""
|
||||||
update_data = {'name': 'Updated'}
|
update_data = {"name": "Updated"}
|
||||||
response_data = {'id': 1, 'name': 'Updated', 'value': 100}
|
response_data = {"id": 1, "name": "Updated", "value": 100}
|
||||||
mock_client.put.return_value = response_data
|
mock_client.put.return_value = response_data
|
||||||
|
|
||||||
result = await base_service.update(1, update_data)
|
result = await base_service.update(1, update_data)
|
||||||
|
|
||||||
assert isinstance(result, MockModel)
|
assert isinstance(result, MockModel)
|
||||||
assert result.name == 'Updated'
|
assert result.name == "Updated"
|
||||||
mock_client.put.assert_called_once_with('mocks', update_data, object_id=1)
|
mock_client.put.assert_called_once_with("mocks", update_data, object_id=1)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_success(self, base_service, mock_client):
|
async def test_delete_success(self, base_service, mock_client):
|
||||||
"""Test successful object deletion."""
|
"""Test successful object deletion."""
|
||||||
mock_client.delete.return_value = True
|
mock_client.delete.return_value = True
|
||||||
|
|
||||||
result = await base_service.delete(1)
|
result = await base_service.delete(1)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
mock_client.delete.assert_called_once_with('mocks', object_id=1)
|
mock_client.delete.assert_called_once_with("mocks", object_id=1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_by_field(self, base_service, mock_client):
|
async def test_get_by_field(self, base_service, mock_client):
|
||||||
"""Test get_by_field functionality."""
|
"""Test get_by_field functionality."""
|
||||||
mock_data = {
|
mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]}
|
||||||
'count': 1,
|
|
||||||
'mocks': [{'id': 1, 'name': 'Test', 'value': 100}]
|
|
||||||
}
|
|
||||||
mock_client.get.return_value = mock_data
|
mock_client.get.return_value = mock_data
|
||||||
|
|
||||||
result = await base_service.get_by_field('name', 'Test')
|
result = await base_service.get_by_field("name", "Test")
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
mock_client.get.assert_called_once_with('mocks', params=[('name', 'Test')])
|
mock_client.get.assert_called_once_with("mocks", params=[("name", "Test")])
|
||||||
|
|
||||||
def test_extract_items_and_count_standard_format(self, base_service):
|
def test_extract_items_and_count_standard_format(self, base_service):
|
||||||
"""Test response parsing for standard format."""
|
"""Test response parsing for standard format."""
|
||||||
data = {
|
data = {
|
||||||
'count': 3,
|
"count": 3,
|
||||||
'mocks': [
|
"mocks": [
|
||||||
{'id': 1, 'name': 'Test1'},
|
{"id": 1, "name": "Test1"},
|
||||||
{'id': 2, 'name': 'Test2'},
|
{"id": 2, "name": "Test2"},
|
||||||
{'id': 3, 'name': 'Test3'}
|
{"id": 3, "name": "Test3"},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
items, count = base_service._extract_items_and_count_from_response(data)
|
items, count = base_service._extract_items_and_count_from_response(data)
|
||||||
|
|
||||||
assert len(items) == 3
|
assert len(items) == 3
|
||||||
assert count == 3
|
assert count == 3
|
||||||
assert items[0]['name'] == 'Test1'
|
assert items[0]["name"] == "Test1"
|
||||||
|
|
||||||
def test_extract_items_and_count_single_object(self, base_service):
|
def test_extract_items_and_count_single_object(self, base_service):
|
||||||
"""Test response parsing for single object."""
|
"""Test response parsing for single object."""
|
||||||
data = {'id': 1, 'name': 'Single'}
|
data = {"id": 1, "name": "Single"}
|
||||||
|
|
||||||
items, count = base_service._extract_items_and_count_from_response(data)
|
items, count = base_service._extract_items_and_count_from_response(data)
|
||||||
|
|
||||||
assert len(items) == 1
|
assert len(items) == 1
|
||||||
assert count == 1
|
assert count == 1
|
||||||
assert items[0] == data
|
assert items[0] == data
|
||||||
|
|
||||||
def test_extract_items_and_count_direct_list(self, base_service):
|
def test_extract_items_and_count_direct_list(self, base_service):
|
||||||
"""Test response parsing for direct list."""
|
"""Test response parsing for direct list."""
|
||||||
data = [
|
data = [{"id": 1, "name": "Test1"}, {"id": 2, "name": "Test2"}]
|
||||||
{'id': 1, 'name': 'Test1'},
|
|
||||||
{'id': 2, 'name': 'Test2'}
|
|
||||||
]
|
|
||||||
|
|
||||||
items, count = base_service._extract_items_and_count_from_response(data)
|
items, count = base_service._extract_items_and_count_from_response(data)
|
||||||
|
|
||||||
assert len(items) == 2
|
assert len(items) == 2
|
||||||
assert count == 2
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
class TestBaseServiceExtras:
|
class TestBaseServiceExtras:
|
||||||
"""Additional coverage tests for BaseService edge cases."""
|
"""Additional coverage tests for BaseService edge cases."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_base_service_additional_methods(self):
|
async def test_base_service_additional_methods(self):
|
||||||
"""Test additional BaseService methods for coverage."""
|
"""Test additional BaseService methods for coverage."""
|
||||||
from services.base_service import BaseService
|
from services.base_service import BaseService
|
||||||
from models.base import SBABaseModel
|
from models.base import SBABaseModel
|
||||||
|
|
||||||
class TestModel(SBABaseModel):
|
class TestModel(SBABaseModel):
|
||||||
name: str
|
name: str
|
||||||
value: int = 100
|
value: int = 100
|
||||||
|
|
||||||
mock_client = AsyncMock()
|
mock_client = AsyncMock()
|
||||||
service = BaseService(TestModel, 'test', client=mock_client)
|
service = BaseService(TestModel, "test", client=mock_client)
|
||||||
|
|
||||||
|
|
||||||
# Test count method
|
# Test count method
|
||||||
mock_client.reset_mock()
|
mock_client.reset_mock()
|
||||||
mock_client.get.return_value = {'count': 42, 'test': []}
|
mock_client.get.return_value = {"count": 42, "test": []}
|
||||||
count = await service.count(params=[('active', 'true')])
|
count = await service.count(params=[("active", "true")])
|
||||||
assert count == 42
|
assert count == 42
|
||||||
|
|
||||||
# Test update_from_model with ID
|
# Test update_from_model with ID
|
||||||
mock_client.reset_mock()
|
mock_client.reset_mock()
|
||||||
model = TestModel(id=1, name="Updated", value=300)
|
model = TestModel(id=1, name="Updated", value=300)
|
||||||
mock_client.put.return_value = {"id": 1, "name": "Updated", "value": 300}
|
mock_client.put.return_value = {"id": 1, "name": "Updated", "value": 300}
|
||||||
result = await service.update_from_model(model)
|
result = await service.update_from_model(model)
|
||||||
assert result.name == "Updated"
|
assert result.name == "Updated"
|
||||||
|
|
||||||
# Test update_from_model without ID
|
# Test update_from_model without ID
|
||||||
model_no_id = TestModel(name="Test")
|
model_no_id = TestModel(name="Test")
|
||||||
with pytest.raises(ValueError, match="Cannot update TestModel without ID"):
|
with pytest.raises(ValueError, match="Cannot update TestModel without ID"):
|
||||||
await service.update_from_model(model_no_id)
|
await service.update_from_model(model_no_id)
|
||||||
|
|
||||||
def test_base_service_response_parsing_edge_cases(self):
|
def test_base_service_response_parsing_edge_cases(self):
|
||||||
"""Test edge cases in response parsing."""
|
"""Test edge cases in response parsing."""
|
||||||
from services.base_service import BaseService
|
from services.base_service import BaseService
|
||||||
from models.base import SBABaseModel
|
from models.base import SBABaseModel
|
||||||
|
|
||||||
class TestModel(SBABaseModel):
|
class TestModel(SBABaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
service = BaseService(TestModel, 'test')
|
service = BaseService(TestModel, "test")
|
||||||
|
|
||||||
# Test with 'items' field
|
# Test with 'items' field
|
||||||
data = {'count': 2, 'items': [{'name': 'Item1'}, {'name': 'Item2'}]}
|
data = {"count": 2, "items": [{"name": "Item1"}, {"name": "Item2"}]}
|
||||||
items, count = service._extract_items_and_count_from_response(data)
|
items, count = service._extract_items_and_count_from_response(data)
|
||||||
assert len(items) == 2
|
assert len(items) == 2
|
||||||
assert count == 2
|
assert count == 2
|
||||||
|
|
||||||
# Test with 'data' field
|
# Test with 'data' field
|
||||||
data = {'count': 1, 'data': [{'name': 'DataItem'}]}
|
data = {"count": 1, "data": [{"name": "DataItem"}]}
|
||||||
items, count = service._extract_items_and_count_from_response(data)
|
items, count = service._extract_items_and_count_from_response(data)
|
||||||
assert len(items) == 1
|
assert len(items) == 1
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
# Test with count but no recognizable list field
|
# Test with count but no recognizable list field
|
||||||
data = {'count': 5, 'unknown_field': [{'name': 'Item'}]}
|
data = {"count": 5, "unknown_field": [{"name": "Item"}]}
|
||||||
items, count = service._extract_items_and_count_from_response(data)
|
items, count = service._extract_items_and_count_from_response(data)
|
||||||
assert len(items) == 0
|
assert len(items) == 0
|
||||||
assert count == 5
|
assert count == 5
|
||||||
|
|
||||||
# Test with unexpected data type
|
# Test with unexpected data type
|
||||||
items, count = service._extract_items_and_count_from_response("unexpected")
|
items, count = service._extract_items_and_count_from_response("unexpected")
|
||||||
assert len(items) == 0
|
assert len(items) == 0
|
||||||
assert count == 0
|
assert count == 0
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Interactive Trade Embed Views
|
|||||||
|
|
||||||
Handles the Discord embed and button interfaces for the multi-team trade builder.
|
Handles the Discord embed and button interfaces for the multi-team trade builder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@ -31,60 +32,56 @@ class TradeEmbedView(discord.ui.View):
|
|||||||
"""Check if user has permission to interact with this view."""
|
"""Check if user has permission to interact with this view."""
|
||||||
if interaction.user.id != self.user_id:
|
if interaction.user.id != self.user_id:
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
"❌ You don't have permission to use this trade builder.",
|
"You don't have permission to use this trade builder.",
|
||||||
ephemeral=True
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def on_timeout(self) -> None:
|
async def on_timeout(self) -> None:
|
||||||
"""Handle view timeout."""
|
"""Handle view timeout."""
|
||||||
# Disable all buttons when timeout occurs
|
|
||||||
for item in self.children:
|
for item in self.children:
|
||||||
if isinstance(item, discord.ui.Button):
|
if isinstance(item, discord.ui.Button):
|
||||||
item.disabled = True
|
item.disabled = True
|
||||||
|
|
||||||
@discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red, emoji="➖")
|
@discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red)
|
||||||
async def remove_move_button(self, interaction: discord.Interaction, button: discord.ui.Button):
|
async def remove_move_button(
|
||||||
|
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||||
|
):
|
||||||
"""Handle remove move button click."""
|
"""Handle remove move button click."""
|
||||||
if self.builder.is_empty:
|
if self.builder.is_empty:
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
"❌ No moves to remove. Add some moves first!",
|
"No moves to remove. Add some moves first!", ephemeral=True
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create select menu for move removal
|
|
||||||
select_view = RemoveTradeMovesView(self.builder, self.user_id)
|
select_view = RemoveTradeMovesView(self.builder, self.user_id)
|
||||||
embed = await create_trade_embed(self.builder)
|
embed = await create_trade_embed(self.builder)
|
||||||
|
|
||||||
await interaction.response.edit_message(embed=embed, view=select_view)
|
await interaction.response.edit_message(embed=embed, view=select_view)
|
||||||
|
|
||||||
@discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary, emoji="🔍")
|
@discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary)
|
||||||
async def validate_button(self, interaction: discord.Interaction, button: discord.ui.Button):
|
async def validate_button(
|
||||||
|
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||||
|
):
|
||||||
"""Handle validate trade button click."""
|
"""Handle validate trade button click."""
|
||||||
await interaction.response.defer(ephemeral=True)
|
await interaction.response.defer(ephemeral=True)
|
||||||
|
|
||||||
# Perform detailed validation
|
|
||||||
validation = await self.builder.validate_trade()
|
validation = await self.builder.validate_trade()
|
||||||
|
|
||||||
# Create validation report
|
|
||||||
if validation.is_legal:
|
if validation.is_legal:
|
||||||
status_emoji = "✅"
|
|
||||||
status_text = "**Trade is LEGAL**"
|
status_text = "**Trade is LEGAL**"
|
||||||
color = EmbedColors.SUCCESS
|
color = EmbedColors.SUCCESS
|
||||||
else:
|
else:
|
||||||
status_emoji = "❌"
|
|
||||||
status_text = "**Trade has ERRORS**"
|
status_text = "**Trade has ERRORS**"
|
||||||
color = EmbedColors.ERROR
|
color = EmbedColors.ERROR
|
||||||
|
|
||||||
embed = EmbedTemplate.create_base_embed(
|
embed = EmbedTemplate.create_base_embed(
|
||||||
title=f"{status_emoji} Trade Validation Report",
|
title="Trade Validation Report",
|
||||||
description=status_text,
|
description=status_text,
|
||||||
color=color
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add team-by-team validation
|
|
||||||
for participant in self.builder.trade.participants:
|
for participant in self.builder.trade.participants:
|
||||||
team_validation = validation.get_participant_validation(participant.team.id)
|
team_validation = validation.get_participant_validation(participant.team.id)
|
||||||
if team_validation:
|
if team_validation:
|
||||||
@ -98,72 +95,65 @@ class TradeEmbedView(discord.ui.View):
|
|||||||
team_status.append(team_validation.pre_existing_transactions_note)
|
team_status.append(team_validation.pre_existing_transactions_note)
|
||||||
|
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"🏟️ {participant.team.abbrev} - {participant.team.sname}",
|
name=f"{participant.team.abbrev} - {participant.team.sname}",
|
||||||
value="\n".join(team_status),
|
value="\n".join(team_status),
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add overall errors and suggestions
|
|
||||||
if validation.all_errors:
|
if validation.all_errors:
|
||||||
error_text = "\n".join([f"• {error}" for error in validation.all_errors])
|
error_text = "\n".join([f"- {error}" for error in validation.all_errors])
|
||||||
embed.add_field(
|
embed.add_field(name="Errors", value=error_text, inline=False)
|
||||||
name="❌ Errors",
|
|
||||||
value=error_text,
|
|
||||||
inline=False
|
|
||||||
)
|
|
||||||
|
|
||||||
if validation.all_suggestions:
|
if validation.all_suggestions:
|
||||||
suggestion_text = "\n".join([f"💡 {suggestion}" for suggestion in validation.all_suggestions])
|
suggestion_text = "\n".join(
|
||||||
embed.add_field(
|
[f"- {suggestion}" for suggestion in validation.all_suggestions]
|
||||||
name="💡 Suggestions",
|
|
||||||
value=suggestion_text,
|
|
||||||
inline=False
|
|
||||||
)
|
)
|
||||||
|
embed.add_field(name="Suggestions", value=suggestion_text, inline=False)
|
||||||
|
|
||||||
await interaction.followup.send(embed=embed, ephemeral=True)
|
await interaction.followup.send(embed=embed, ephemeral=True)
|
||||||
|
|
||||||
@discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary, emoji="📤")
|
@discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary)
|
||||||
async def submit_button(self, interaction: discord.Interaction, button: discord.ui.Button):
|
async def submit_button(
|
||||||
|
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||||
|
):
|
||||||
"""Handle submit trade button click."""
|
"""Handle submit trade button click."""
|
||||||
if self.builder.is_empty:
|
if self.builder.is_empty:
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
"❌ Cannot submit empty trade. Add some moves first!",
|
"Cannot submit empty trade. Add some moves first!", ephemeral=True
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Validate before submission
|
|
||||||
validation = await self.builder.validate_trade()
|
validation = await self.builder.validate_trade()
|
||||||
if not validation.is_legal:
|
if not validation.is_legal:
|
||||||
error_msg = "❌ **Cannot submit illegal trade:**\n"
|
error_msg = "**Cannot submit illegal trade:**\n"
|
||||||
error_msg += "\n".join([f"• {error}" for error in validation.all_errors])
|
error_msg += "\n".join([f"- {error}" for error in validation.all_errors])
|
||||||
|
|
||||||
if validation.all_suggestions:
|
if validation.all_suggestions:
|
||||||
error_msg += "\n\n**Suggestions:**\n"
|
error_msg += "\n\n**Suggestions:**\n"
|
||||||
error_msg += "\n".join([f"💡 {suggestion}" for suggestion in validation.all_suggestions])
|
error_msg += "\n".join(
|
||||||
|
[f"- {suggestion}" for suggestion in validation.all_suggestions]
|
||||||
|
)
|
||||||
|
|
||||||
await interaction.response.send_message(error_msg, ephemeral=True)
|
await interaction.response.send_message(error_msg, ephemeral=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Show confirmation modal
|
|
||||||
modal = SubmitTradeConfirmationModal(self.builder)
|
modal = SubmitTradeConfirmationModal(self.builder)
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
|
||||||
@discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary, emoji="❌")
|
@discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary)
|
||||||
async def cancel_button(self, interaction: discord.Interaction, button: discord.ui.Button):
|
async def cancel_button(
|
||||||
|
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||||
|
):
|
||||||
"""Handle cancel trade button click."""
|
"""Handle cancel trade button click."""
|
||||||
self.builder.clear_trade()
|
self.builder.clear_trade()
|
||||||
embed = await create_trade_embed(self.builder)
|
embed = await create_trade_embed(self.builder)
|
||||||
|
|
||||||
# Disable all buttons after cancellation
|
|
||||||
for item in self.children:
|
for item in self.children:
|
||||||
if isinstance(item, discord.ui.Button):
|
if isinstance(item, discord.ui.Button):
|
||||||
item.disabled = True
|
item.disabled = True
|
||||||
|
|
||||||
await interaction.response.edit_message(
|
await interaction.response.edit_message(
|
||||||
content="❌ **Trade cancelled and cleared.**",
|
content="**Trade cancelled and cleared.**", embed=embed, view=self
|
||||||
embed=embed,
|
|
||||||
view=self
|
|
||||||
)
|
)
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
@ -176,12 +166,12 @@ class RemoveTradeMovesView(discord.ui.View):
|
|||||||
self.builder = builder
|
self.builder = builder
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
|
|
||||||
# Create select menu with current moves
|
|
||||||
if not builder.is_empty:
|
if not builder.is_empty:
|
||||||
self.add_item(RemoveTradeMovesSelect(builder))
|
self.add_item(RemoveTradeMovesSelect(builder))
|
||||||
|
|
||||||
# Add back button
|
back_button = discord.ui.Button(
|
||||||
back_button = discord.ui.Button(label="Back", style=discord.ButtonStyle.secondary, emoji="⬅️")
|
label="Back", style=discord.ButtonStyle.secondary
|
||||||
|
)
|
||||||
back_button.callback = self.back_callback
|
back_button.callback = self.back_callback
|
||||||
self.add_item(back_button)
|
self.add_item(back_button)
|
||||||
|
|
||||||
@ -202,35 +192,36 @@ class RemoveTradeMovesSelect(discord.ui.Select):
|
|||||||
def __init__(self, builder: TradeBuilder):
|
def __init__(self, builder: TradeBuilder):
|
||||||
self.builder = builder
|
self.builder = builder
|
||||||
|
|
||||||
# Create options from all moves (cross-team and supplementary)
|
|
||||||
options = []
|
options = []
|
||||||
move_count = 0
|
move_count = 0
|
||||||
|
|
||||||
# Add cross-team moves
|
for move in builder.trade.cross_team_moves[
|
||||||
for move in builder.trade.cross_team_moves[:20]: # Limit to avoid Discord's 25 option limit
|
:20
|
||||||
options.append(discord.SelectOption(
|
]: # Limit to avoid Discord's 25 option limit
|
||||||
label=f"{move.player.name}",
|
options.append(
|
||||||
description=move.description[:100], # Discord description limit
|
discord.SelectOption(
|
||||||
value=str(move.player.id),
|
label=f"{move.player.name}",
|
||||||
emoji="🔄"
|
description=move.description[:100],
|
||||||
))
|
value=str(move.player.id),
|
||||||
|
)
|
||||||
|
)
|
||||||
move_count += 1
|
move_count += 1
|
||||||
|
|
||||||
# Add supplementary moves if there's room
|
|
||||||
remaining_slots = 25 - move_count
|
remaining_slots = 25 - move_count
|
||||||
for move in builder.trade.supplementary_moves[:remaining_slots]:
|
for move in builder.trade.supplementary_moves[:remaining_slots]:
|
||||||
options.append(discord.SelectOption(
|
options.append(
|
||||||
label=f"{move.player.name}",
|
discord.SelectOption(
|
||||||
description=move.description[:100],
|
label=f"{move.player.name}",
|
||||||
value=str(move.player.id),
|
description=move.description[:100],
|
||||||
emoji="⚙️"
|
value=str(move.player.id),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
placeholder="Select a move to remove...",
|
placeholder="Select a move to remove...",
|
||||||
min_values=1,
|
min_values=1,
|
||||||
max_values=1,
|
max_values=1,
|
||||||
options=options
|
options=options,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def callback(self, interaction: discord.Interaction):
|
async def callback(self, interaction: discord.Interaction):
|
||||||
@ -241,27 +232,25 @@ class RemoveTradeMovesSelect(discord.ui.Select):
|
|||||||
|
|
||||||
if success:
|
if success:
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
f"✅ Removed move for player ID {player_id}",
|
f"Removed move for player ID {player_id}", ephemeral=True
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the embed
|
|
||||||
main_view = TradeEmbedView(self.builder, interaction.user.id)
|
main_view = TradeEmbedView(self.builder, interaction.user.id)
|
||||||
embed = await create_trade_embed(self.builder)
|
embed = await create_trade_embed(self.builder)
|
||||||
|
|
||||||
# Edit the original message
|
|
||||||
await interaction.edit_original_response(embed=embed, view=main_view)
|
await interaction.edit_original_response(embed=embed, view=main_view)
|
||||||
else:
|
else:
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
f"❌ Could not remove move: {error_msg}",
|
f"Could not remove move: {error_msg}", ephemeral=True
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SubmitTradeConfirmationModal(discord.ui.Modal):
|
class SubmitTradeConfirmationModal(discord.ui.Modal):
|
||||||
"""Modal for confirming trade submission - posts acceptance request to trade channel."""
|
"""Modal for confirming trade submission - posts acceptance request to trade channel."""
|
||||||
|
|
||||||
def __init__(self, builder: TradeBuilder, trade_channel: Optional[discord.TextChannel] = None):
|
def __init__(
|
||||||
|
self, builder: TradeBuilder, trade_channel: Optional[discord.TextChannel] = None
|
||||||
|
):
|
||||||
super().__init__(title="Confirm Trade Submission")
|
super().__init__(title="Confirm Trade Submission")
|
||||||
self.builder = builder
|
self.builder = builder
|
||||||
self.trade_channel = trade_channel
|
self.trade_channel = trade_channel
|
||||||
@ -270,7 +259,7 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
|
|||||||
label="Type 'CONFIRM' to submit for approval",
|
label="Type 'CONFIRM' to submit for approval",
|
||||||
placeholder="CONFIRM",
|
placeholder="CONFIRM",
|
||||||
required=True,
|
required=True,
|
||||||
max_length=7
|
max_length=7,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.add_item(self.confirmation)
|
self.add_item(self.confirmation)
|
||||||
@ -279,56 +268,52 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
|
|||||||
"""Handle confirmation submission - posts acceptance view to trade channel."""
|
"""Handle confirmation submission - posts acceptance view to trade channel."""
|
||||||
if self.confirmation.value.upper() != "CONFIRM":
|
if self.confirmation.value.upper() != "CONFIRM":
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
"❌ Trade not submitted. You must type 'CONFIRM' exactly.",
|
"Trade not submitted. You must type 'CONFIRM' exactly.",
|
||||||
ephemeral=True
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
await interaction.response.defer(ephemeral=True)
|
await interaction.response.defer(ephemeral=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Update trade status to PROPOSED
|
|
||||||
from models.trade import TradeStatus
|
from models.trade import TradeStatus
|
||||||
|
|
||||||
self.builder.trade.status = TradeStatus.PROPOSED
|
self.builder.trade.status = TradeStatus.PROPOSED
|
||||||
|
|
||||||
# Create acceptance embed and view
|
|
||||||
acceptance_embed = await create_trade_acceptance_embed(self.builder)
|
acceptance_embed = await create_trade_acceptance_embed(self.builder)
|
||||||
acceptance_view = TradeAcceptanceView(self.builder)
|
acceptance_view = TradeAcceptanceView(self.builder)
|
||||||
|
|
||||||
# Find the trade channel to post to
|
|
||||||
channel = self.trade_channel
|
channel = self.trade_channel
|
||||||
if not channel:
|
if not channel:
|
||||||
# Try to find trade channel by name pattern
|
|
||||||
trade_channel_name = f"trade-{'-'.join(t.abbrev.lower() for t in self.builder.participating_teams)}"
|
|
||||||
for ch in interaction.guild.text_channels: # type: ignore
|
for ch in interaction.guild.text_channels: # type: ignore
|
||||||
if ch.name.startswith("trade-") and self.builder.trade_id[:4] in ch.name:
|
if (
|
||||||
|
ch.name.startswith("trade-")
|
||||||
|
and self.builder.trade_id[:4] in ch.name
|
||||||
|
):
|
||||||
channel = ch
|
channel = ch
|
||||||
break
|
break
|
||||||
|
|
||||||
if channel:
|
if channel:
|
||||||
# Post acceptance request to trade channel
|
|
||||||
await channel.send(
|
await channel.send(
|
||||||
content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.",
|
content="**Trade submitted for approval.** All teams must accept to complete the trade.",
|
||||||
embed=acceptance_embed,
|
embed=acceptance_embed,
|
||||||
view=acceptance_view
|
view=acceptance_view,
|
||||||
)
|
)
|
||||||
await interaction.followup.send(
|
await interaction.followup.send(
|
||||||
f"✅ Trade submitted for approval!\n\nThe acceptance request has been posted to {channel.mention}.\n"
|
f"Trade submitted for approval.\n\nThe acceptance request has been posted to {channel.mention}.\n"
|
||||||
f"All participating teams must click **Accept Trade** to finalize.",
|
f"All participating teams must click **Accept Trade** to finalize.",
|
||||||
ephemeral=True
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# No trade channel found, post in current channel
|
|
||||||
await interaction.followup.send(
|
await interaction.followup.send(
|
||||||
content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.",
|
content="**Trade submitted for approval.** All teams must accept to complete the trade.",
|
||||||
embed=acceptance_embed,
|
embed=acceptance_embed,
|
||||||
view=acceptance_view
|
view=acceptance_view,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await interaction.followup.send(
|
await interaction.followup.send(
|
||||||
f"❌ Error submitting trade: {str(e)}",
|
f"Error submitting trade: {str(e)}", ephemeral=True
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -343,8 +328,11 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
"""Get the team owned by the interacting user."""
|
"""Get the team owned by the interacting user."""
|
||||||
from services.team_service import team_service
|
from services.team_service import team_service
|
||||||
from config import get_config
|
from config import get_config
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
return await team_service.get_team_by_owner(interaction.user.id, config.sba_season)
|
return await team_service.get_team_by_owner(
|
||||||
|
interaction.user.id, config.sba_season
|
||||||
|
)
|
||||||
|
|
||||||
async def interaction_check(self, interaction: discord.Interaction) -> bool:
|
async def interaction_check(self, interaction: discord.Interaction) -> bool:
|
||||||
"""Check if user is a GM of a participating team."""
|
"""Check if user is a GM of a participating team."""
|
||||||
@ -352,17 +340,14 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
|
|
||||||
if not user_team:
|
if not user_team:
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
"❌ You don't own a team in the league.",
|
"You don't own a team in the league.", ephemeral=True
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if their team (or organization) is participating
|
|
||||||
participant = self.builder.trade.get_participant_by_organization(user_team)
|
participant = self.builder.trade.get_participant_by_organization(user_team)
|
||||||
if not participant:
|
if not participant:
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
"❌ Your team is not part of this trade.",
|
"Your team is not part of this trade.", ephemeral=True
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -374,47 +359,45 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
if isinstance(item, discord.ui.Button):
|
if isinstance(item, discord.ui.Button):
|
||||||
item.disabled = True
|
item.disabled = True
|
||||||
|
|
||||||
@discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success, emoji="✅")
|
@discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success)
|
||||||
async def accept_button(self, interaction: discord.Interaction, button: discord.ui.Button):
|
async def accept_button(
|
||||||
|
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||||
|
):
|
||||||
"""Handle accept button click."""
|
"""Handle accept button click."""
|
||||||
user_team = await self._get_user_team(interaction)
|
user_team = await self._get_user_team(interaction)
|
||||||
if not user_team:
|
if not user_team:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Find the participating team (could be org affiliate)
|
|
||||||
participant = self.builder.trade.get_participant_by_organization(user_team)
|
participant = self.builder.trade.get_participant_by_organization(user_team)
|
||||||
if not participant:
|
if not participant:
|
||||||
return
|
return
|
||||||
|
|
||||||
team_id = participant.team.id
|
team_id = participant.team.id
|
||||||
|
|
||||||
# Check if already accepted
|
|
||||||
if self.builder.has_team_accepted(team_id):
|
if self.builder.has_team_accepted(team_id):
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
f"✅ {participant.team.abbrev} has already accepted this trade.",
|
f"{participant.team.abbrev} has already accepted this trade.",
|
||||||
ephemeral=True
|
ephemeral=True,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Record acceptance
|
|
||||||
all_accepted = self.builder.accept_trade(team_id)
|
all_accepted = self.builder.accept_trade(team_id)
|
||||||
|
|
||||||
if all_accepted:
|
if all_accepted:
|
||||||
# All teams accepted - finalize the trade
|
|
||||||
await self._finalize_trade(interaction)
|
await self._finalize_trade(interaction)
|
||||||
else:
|
else:
|
||||||
# Update embed to show new acceptance status
|
|
||||||
embed = await create_trade_acceptance_embed(self.builder)
|
embed = await create_trade_acceptance_embed(self.builder)
|
||||||
await interaction.response.edit_message(embed=embed, view=self)
|
await interaction.response.edit_message(embed=embed, view=self)
|
||||||
|
|
||||||
# Send confirmation to channel
|
|
||||||
await interaction.followup.send(
|
await interaction.followup.send(
|
||||||
f"✅ **{participant.team.abbrev}** has accepted the trade! "
|
f"**{participant.team.abbrev}** has accepted the trade. "
|
||||||
f"({len(self.builder.accepted_teams)}/{self.builder.team_count} teams)"
|
f"({len(self.builder.accepted_teams)}/{self.builder.team_count} teams)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger, emoji="❌")
|
@discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger)
|
||||||
async def reject_button(self, interaction: discord.Interaction, button: discord.ui.Button):
|
async def reject_button(
|
||||||
|
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||||
|
):
|
||||||
"""Handle reject button click - moves trade back to DRAFT."""
|
"""Handle reject button click - moves trade back to DRAFT."""
|
||||||
user_team = await self._get_user_team(interaction)
|
user_team = await self._get_user_team(interaction)
|
||||||
if not user_team:
|
if not user_team:
|
||||||
@ -424,20 +407,16 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
if not participant:
|
if not participant:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Reject the trade
|
|
||||||
self.builder.reject_trade()
|
self.builder.reject_trade()
|
||||||
|
|
||||||
# Disable buttons
|
|
||||||
self.accept_button.disabled = True
|
self.accept_button.disabled = True
|
||||||
self.reject_button.disabled = True
|
self.reject_button.disabled = True
|
||||||
|
|
||||||
# Update embed to show rejection
|
|
||||||
embed = await create_trade_rejection_embed(self.builder, participant.team)
|
embed = await create_trade_rejection_embed(self.builder, participant.team)
|
||||||
await interaction.response.edit_message(embed=embed, view=self)
|
await interaction.response.edit_message(embed=embed, view=self)
|
||||||
|
|
||||||
# Notify the channel
|
|
||||||
await interaction.followup.send(
|
await interaction.followup.send(
|
||||||
f"❌ **{participant.team.abbrev}** has rejected the trade.\n\n"
|
f"**{participant.team.abbrev}** has rejected the trade.\n\n"
|
||||||
f"The trade has been moved back to **DRAFT** status. "
|
f"The trade has been moved back to **DRAFT** status. "
|
||||||
f"Teams can continue negotiating using `/trade` commands."
|
f"Teams can continue negotiating using `/trade` commands."
|
||||||
)
|
)
|
||||||
@ -459,41 +438,52 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
# Get next week for transactions
|
|
||||||
current = await league_service.get_current_state()
|
current = await league_service.get_current_state()
|
||||||
next_week = current.week + 1 if current else 1
|
next_week = current.week + 1 if current else 1
|
||||||
|
|
||||||
# Create FA team for reference
|
|
||||||
fa_team = Team(
|
fa_team = Team(
|
||||||
id=config.free_agent_team_id,
|
id=config.free_agent_team_id,
|
||||||
abbrev="FA",
|
abbrev="FA",
|
||||||
sname="Free Agents",
|
sname="Free Agents",
|
||||||
lname="Free Agency",
|
lname="Free Agency",
|
||||||
season=self.builder.trade.season
|
season=self.builder.trade.season,
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
# Create transactions from all moves
|
|
||||||
transactions: List[Transaction] = []
|
transactions: List[Transaction] = []
|
||||||
move_id = f"Trade-{self.builder.trade_id}-{int(datetime.now(timezone.utc).timestamp())}"
|
move_id = f"Trade-{self.builder.trade_id}-{int(datetime.now(timezone.utc).timestamp())}"
|
||||||
|
|
||||||
# Process cross-team moves
|
|
||||||
for move in self.builder.trade.cross_team_moves:
|
for move in self.builder.trade.cross_team_moves:
|
||||||
# Get actual team affiliates for from/to based on roster type
|
|
||||||
if move.from_roster == RosterType.MAJOR_LEAGUE:
|
if move.from_roster == RosterType.MAJOR_LEAGUE:
|
||||||
old_team = move.source_team
|
old_team = move.source_team
|
||||||
elif move.from_roster == RosterType.MINOR_LEAGUE:
|
elif move.from_roster == RosterType.MINOR_LEAGUE:
|
||||||
old_team = await move.source_team.minor_league_affiliate() if move.source_team else None
|
old_team = (
|
||||||
|
await move.source_team.minor_league_affiliate()
|
||||||
|
if move.source_team
|
||||||
|
else None
|
||||||
|
)
|
||||||
elif move.from_roster == RosterType.INJURED_LIST:
|
elif move.from_roster == RosterType.INJURED_LIST:
|
||||||
old_team = await move.source_team.injured_list_affiliate() if move.source_team else None
|
old_team = (
|
||||||
|
await move.source_team.injured_list_affiliate()
|
||||||
|
if move.source_team
|
||||||
|
else None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
old_team = move.source_team
|
old_team = move.source_team
|
||||||
|
|
||||||
if move.to_roster == RosterType.MAJOR_LEAGUE:
|
if move.to_roster == RosterType.MAJOR_LEAGUE:
|
||||||
new_team = move.destination_team
|
new_team = move.destination_team
|
||||||
elif move.to_roster == RosterType.MINOR_LEAGUE:
|
elif move.to_roster == RosterType.MINOR_LEAGUE:
|
||||||
new_team = await move.destination_team.minor_league_affiliate() if move.destination_team else None
|
new_team = (
|
||||||
|
await move.destination_team.minor_league_affiliate()
|
||||||
|
if move.destination_team
|
||||||
|
else None
|
||||||
|
)
|
||||||
elif move.to_roster == RosterType.INJURED_LIST:
|
elif move.to_roster == RosterType.INJURED_LIST:
|
||||||
new_team = await move.destination_team.injured_list_affiliate() if move.destination_team else None
|
new_team = (
|
||||||
|
await move.destination_team.injured_list_affiliate()
|
||||||
|
if move.destination_team
|
||||||
|
else None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
new_team = move.destination_team
|
new_team = move.destination_team
|
||||||
|
|
||||||
@ -507,18 +497,25 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
oldteam=old_team,
|
oldteam=old_team,
|
||||||
newteam=new_team,
|
newteam=new_team,
|
||||||
cancelled=False,
|
cancelled=False,
|
||||||
frozen=False # Trades are NOT frozen - immediately effective
|
frozen=False,
|
||||||
)
|
)
|
||||||
transactions.append(transaction)
|
transactions.append(transaction)
|
||||||
|
|
||||||
# Process supplementary moves
|
|
||||||
for move in self.builder.trade.supplementary_moves:
|
for move in self.builder.trade.supplementary_moves:
|
||||||
if move.from_roster == RosterType.MAJOR_LEAGUE:
|
if move.from_roster == RosterType.MAJOR_LEAGUE:
|
||||||
old_team = move.source_team
|
old_team = move.source_team
|
||||||
elif move.from_roster == RosterType.MINOR_LEAGUE:
|
elif move.from_roster == RosterType.MINOR_LEAGUE:
|
||||||
old_team = await move.source_team.minor_league_affiliate() if move.source_team else None
|
old_team = (
|
||||||
|
await move.source_team.minor_league_affiliate()
|
||||||
|
if move.source_team
|
||||||
|
else None
|
||||||
|
)
|
||||||
elif move.from_roster == RosterType.INJURED_LIST:
|
elif move.from_roster == RosterType.INJURED_LIST:
|
||||||
old_team = await move.source_team.injured_list_affiliate() if move.source_team else None
|
old_team = (
|
||||||
|
await move.source_team.injured_list_affiliate()
|
||||||
|
if move.source_team
|
||||||
|
else None
|
||||||
|
)
|
||||||
elif move.from_roster == RosterType.FREE_AGENCY:
|
elif move.from_roster == RosterType.FREE_AGENCY:
|
||||||
old_team = fa_team
|
old_team = fa_team
|
||||||
else:
|
else:
|
||||||
@ -527,9 +524,17 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
if move.to_roster == RosterType.MAJOR_LEAGUE:
|
if move.to_roster == RosterType.MAJOR_LEAGUE:
|
||||||
new_team = move.destination_team
|
new_team = move.destination_team
|
||||||
elif move.to_roster == RosterType.MINOR_LEAGUE:
|
elif move.to_roster == RosterType.MINOR_LEAGUE:
|
||||||
new_team = await move.destination_team.minor_league_affiliate() if move.destination_team else None
|
new_team = (
|
||||||
|
await move.destination_team.minor_league_affiliate()
|
||||||
|
if move.destination_team
|
||||||
|
else None
|
||||||
|
)
|
||||||
elif move.to_roster == RosterType.INJURED_LIST:
|
elif move.to_roster == RosterType.INJURED_LIST:
|
||||||
new_team = await move.destination_team.injured_list_affiliate() if move.destination_team else None
|
new_team = (
|
||||||
|
await move.destination_team.injured_list_affiliate()
|
||||||
|
if move.destination_team
|
||||||
|
else None
|
||||||
|
)
|
||||||
elif move.to_roster == RosterType.FREE_AGENCY:
|
elif move.to_roster == RosterType.FREE_AGENCY:
|
||||||
new_team = fa_team
|
new_team = fa_team
|
||||||
else:
|
else:
|
||||||
@ -545,45 +550,42 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
oldteam=old_team,
|
oldteam=old_team,
|
||||||
newteam=new_team,
|
newteam=new_team,
|
||||||
cancelled=False,
|
cancelled=False,
|
||||||
frozen=False # Trades are NOT frozen - immediately effective
|
frozen=False,
|
||||||
)
|
)
|
||||||
transactions.append(transaction)
|
transactions.append(transaction)
|
||||||
|
|
||||||
# POST transactions to database
|
|
||||||
if transactions:
|
if transactions:
|
||||||
created_transactions = await transaction_service.create_transaction_batch(transactions)
|
created_transactions = (
|
||||||
|
await transaction_service.create_transaction_batch(transactions)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
created_transactions = []
|
created_transactions = []
|
||||||
|
|
||||||
# Post to #transaction-log channel
|
|
||||||
if created_transactions and interaction.client:
|
if created_transactions and interaction.client:
|
||||||
await post_trade_to_log(
|
await post_trade_to_log(
|
||||||
bot=interaction.client,
|
bot=interaction.client,
|
||||||
builder=self.builder,
|
builder=self.builder,
|
||||||
transactions=created_transactions,
|
transactions=created_transactions,
|
||||||
effective_week=next_week
|
effective_week=next_week,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update trade status
|
|
||||||
self.builder.trade.status = TradeStatus.ACCEPTED
|
self.builder.trade.status = TradeStatus.ACCEPTED
|
||||||
|
|
||||||
# Disable buttons
|
|
||||||
self.accept_button.disabled = True
|
self.accept_button.disabled = True
|
||||||
self.reject_button.disabled = True
|
self.reject_button.disabled = True
|
||||||
|
|
||||||
# Update embed to show completion
|
embed = await create_trade_complete_embed(
|
||||||
embed = await create_trade_complete_embed(self.builder, len(created_transactions), next_week)
|
self.builder, len(created_transactions), next_week
|
||||||
|
)
|
||||||
await interaction.edit_original_response(embed=embed, view=self)
|
await interaction.edit_original_response(embed=embed, view=self)
|
||||||
|
|
||||||
# Send completion message
|
|
||||||
await interaction.followup.send(
|
await interaction.followup.send(
|
||||||
f"🎉 **Trade Complete!**\n\n"
|
f"**Trade Complete!**\n\n"
|
||||||
f"All {self.builder.team_count} teams have accepted the trade.\n"
|
f"All {self.builder.team_count} teams have accepted the trade.\n"
|
||||||
f"**{len(created_transactions)} transactions** have been created for **Week {next_week}**.\n\n"
|
f"**{len(created_transactions)} transactions** have been created for **Week {next_week}**.\n\n"
|
||||||
f"Trade ID: `{self.builder.trade_id}`"
|
f"Trade ID: `{self.builder.trade_id}`"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clear the trade builder
|
|
||||||
for team in self.builder.participating_teams:
|
for team in self.builder.participating_teams:
|
||||||
clear_trade_builder_by_team(team.id)
|
clear_trade_builder_by_team(team.id)
|
||||||
|
|
||||||
@ -591,81 +593,79 @@ class TradeAcceptanceView(discord.ui.View):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await interaction.followup.send(
|
await interaction.followup.send(
|
||||||
f"❌ Error finalizing trade: {str(e)}",
|
f"Error finalizing trade: {str(e)}", ephemeral=True
|
||||||
ephemeral=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def create_trade_acceptance_embed(builder: TradeBuilder) -> discord.Embed:
|
async def create_trade_acceptance_embed(builder: TradeBuilder) -> discord.Embed:
|
||||||
"""Create embed showing trade details and acceptance status."""
|
"""Create embed showing trade details and acceptance status."""
|
||||||
embed = EmbedTemplate.create_base_embed(
|
embed = EmbedTemplate.create_base_embed(
|
||||||
title=f"📋 Trade Pending Acceptance - {builder.trade.get_trade_summary()}",
|
title=f"Trade Pending Acceptance - {builder.trade.get_trade_summary()}",
|
||||||
description="All participating teams must accept to complete the trade.",
|
description="All participating teams must accept to complete the trade.",
|
||||||
color=EmbedColors.WARNING
|
color=EmbedColors.WARNING,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show participating teams
|
team_list = [
|
||||||
team_list = [f"• {team.abbrev} - {team.sname}" for team in builder.participating_teams]
|
f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams
|
||||||
|
]
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"🏟️ Participating Teams ({builder.team_count})",
|
name=f"Participating Teams ({builder.team_count})",
|
||||||
value="\n".join(team_list),
|
value="\n".join(team_list),
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show cross-team moves
|
|
||||||
if builder.trade.cross_team_moves:
|
if builder.trade.cross_team_moves:
|
||||||
moves_text = ""
|
moves_text = ""
|
||||||
for move in builder.trade.cross_team_moves[:10]:
|
for move in builder.trade.cross_team_moves[:10]:
|
||||||
moves_text += f"• {move.description}\n"
|
moves_text += f"- {move.description}\n"
|
||||||
if len(builder.trade.cross_team_moves) > 10:
|
if len(builder.trade.cross_team_moves) > 10:
|
||||||
moves_text += f"... and {len(builder.trade.cross_team_moves) - 10} more"
|
moves_text += f"... and {len(builder.trade.cross_team_moves) - 10} more"
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})",
|
name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})",
|
||||||
value=moves_text,
|
value=moves_text,
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show supplementary moves if any
|
|
||||||
if builder.trade.supplementary_moves:
|
if builder.trade.supplementary_moves:
|
||||||
supp_text = ""
|
supp_text = ""
|
||||||
for move in builder.trade.supplementary_moves[:5]:
|
for move in builder.trade.supplementary_moves[:5]:
|
||||||
supp_text += f"• {move.description}\n"
|
supp_text += f"- {move.description}\n"
|
||||||
if len(builder.trade.supplementary_moves) > 5:
|
if len(builder.trade.supplementary_moves) > 5:
|
||||||
supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more"
|
supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more"
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})",
|
name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})",
|
||||||
value=supp_text,
|
value=supp_text,
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show acceptance status
|
|
||||||
status_lines = []
|
status_lines = []
|
||||||
for team in builder.participating_teams:
|
for team in builder.participating_teams:
|
||||||
if team.id in builder.accepted_teams:
|
if team.id in builder.accepted_teams:
|
||||||
status_lines.append(f"✅ **{team.abbrev}** - Accepted")
|
status_lines.append(f"**{team.abbrev}** - Accepted")
|
||||||
else:
|
else:
|
||||||
status_lines.append(f"⏳ **{team.abbrev}** - Pending")
|
status_lines.append(f"**{team.abbrev}** - Pending")
|
||||||
|
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="📊 Acceptance Status",
|
name="Acceptance Status", value="\n".join(status_lines), inline=False
|
||||||
value="\n".join(status_lines),
|
|
||||||
inline=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add footer
|
embed.set_footer(
|
||||||
embed.set_footer(text=f"Trade ID: {builder.trade_id} • {len(builder.accepted_teams)}/{builder.team_count} teams accepted")
|
text=f"Trade ID: {builder.trade_id} | {len(builder.accepted_teams)}/{builder.team_count} teams accepted"
|
||||||
|
)
|
||||||
|
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
async def create_trade_rejection_embed(builder: TradeBuilder, rejecting_team: Team) -> discord.Embed:
|
async def create_trade_rejection_embed(
|
||||||
|
builder: TradeBuilder, rejecting_team: Team
|
||||||
|
) -> discord.Embed:
|
||||||
"""Create embed showing trade was rejected."""
|
"""Create embed showing trade was rejected."""
|
||||||
embed = EmbedTemplate.create_base_embed(
|
embed = EmbedTemplate.create_base_embed(
|
||||||
title=f"❌ Trade Rejected - {builder.trade.get_trade_summary()}",
|
title=f"Trade Rejected - {builder.trade.get_trade_summary()}",
|
||||||
description=f"**{rejecting_team.abbrev}** has rejected the trade.\n\n"
|
description=f"**{rejecting_team.abbrev}** has rejected the trade.\n\n"
|
||||||
f"The trade has been moved back to **DRAFT** status.\n"
|
f"The trade has been moved back to **DRAFT** status.\n"
|
||||||
f"Teams can continue negotiating using `/trade` commands.",
|
f"Teams can continue negotiating using `/trade` commands.",
|
||||||
color=EmbedColors.ERROR
|
color=EmbedColors.ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
embed.set_footer(text=f"Trade ID: {builder.trade_id}")
|
embed.set_footer(text=f"Trade ID: {builder.trade_id}")
|
||||||
@ -673,37 +673,33 @@ async def create_trade_rejection_embed(builder: TradeBuilder, rejecting_team: Te
|
|||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
async def create_trade_complete_embed(builder: TradeBuilder, transaction_count: int, effective_week: int) -> discord.Embed:
|
async def create_trade_complete_embed(
|
||||||
|
builder: TradeBuilder, transaction_count: int, effective_week: int
|
||||||
|
) -> discord.Embed:
|
||||||
"""Create embed showing trade was completed."""
|
"""Create embed showing trade was completed."""
|
||||||
embed = EmbedTemplate.create_base_embed(
|
embed = EmbedTemplate.create_base_embed(
|
||||||
title=f"🎉 Trade Complete! - {builder.trade.get_trade_summary()}",
|
title=f"Trade Complete - {builder.trade.get_trade_summary()}",
|
||||||
description=f"All {builder.team_count} teams have accepted the trade!\n\n"
|
description=f"All {builder.team_count} teams have accepted the trade.\n\n"
|
||||||
f"**{transaction_count} transactions** created for **Week {effective_week}**.",
|
f"**{transaction_count} transactions** created for **Week {effective_week}**.",
|
||||||
color=EmbedColors.SUCCESS
|
color=EmbedColors.SUCCESS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show final acceptance status (all green)
|
status_lines = [
|
||||||
status_lines = [f"✅ **{team.abbrev}** - Accepted" for team in builder.participating_teams]
|
f"**{team.abbrev}** - Accepted" for team in builder.participating_teams
|
||||||
embed.add_field(
|
]
|
||||||
name="📊 Final Status",
|
embed.add_field(name="Final Status", value="\n".join(status_lines), inline=False)
|
||||||
value="\n".join(status_lines),
|
|
||||||
inline=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show cross-team moves
|
|
||||||
if builder.trade.cross_team_moves:
|
if builder.trade.cross_team_moves:
|
||||||
moves_text = ""
|
moves_text = ""
|
||||||
for move in builder.trade.cross_team_moves[:8]:
|
for move in builder.trade.cross_team_moves[:8]:
|
||||||
moves_text += f"• {move.description}\n"
|
moves_text += f"- {move.description}\n"
|
||||||
if len(builder.trade.cross_team_moves) > 8:
|
if len(builder.trade.cross_team_moves) > 8:
|
||||||
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
|
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
|
||||||
embed.add_field(
|
embed.add_field(name="Player Exchanges", value=moves_text, inline=False)
|
||||||
name=f"🔄 Player Exchanges",
|
|
||||||
value=moves_text,
|
|
||||||
inline=False
|
|
||||||
)
|
|
||||||
|
|
||||||
embed.set_footer(text=f"Trade ID: {builder.trade_id} • Effective: Week {effective_week}")
|
embed.set_footer(
|
||||||
|
text=f"Trade ID: {builder.trade_id} | Effective: Week {effective_week}"
|
||||||
|
)
|
||||||
|
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
@ -718,7 +714,6 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
|
|||||||
Returns:
|
Returns:
|
||||||
Discord embed with current trade state
|
Discord embed with current trade state
|
||||||
"""
|
"""
|
||||||
# Determine embed color based on trade status
|
|
||||||
if builder.is_empty:
|
if builder.is_empty:
|
||||||
color = EmbedColors.SECONDARY
|
color = EmbedColors.SECONDARY
|
||||||
else:
|
else:
|
||||||
@ -726,79 +721,79 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
|
|||||||
color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING
|
color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING
|
||||||
|
|
||||||
embed = EmbedTemplate.create_base_embed(
|
embed = EmbedTemplate.create_base_embed(
|
||||||
title=f"📋 Trade Builder - {builder.trade.get_trade_summary()}",
|
title=f"Trade Builder - {builder.trade.get_trade_summary()}",
|
||||||
description=f"Build your multi-team trade",
|
description="Build your multi-team trade",
|
||||||
color=color
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add participating teams section
|
team_list = [
|
||||||
team_list = [f"• {team.abbrev} - {team.sname}" for team in builder.participating_teams]
|
f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams
|
||||||
|
]
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"🏟️ Participating Teams ({builder.team_count})",
|
name=f"Participating Teams ({builder.team_count})",
|
||||||
value="\n".join(team_list) if team_list else "*No teams yet*",
|
value="\n".join(team_list) if team_list else "*No teams yet*",
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add current moves section
|
|
||||||
if builder.is_empty:
|
if builder.is_empty:
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Current Moves",
|
name="Current Moves",
|
||||||
value="*No moves yet. Use the `/trade` commands to build your trade.*",
|
value="*No moves yet. Use the `/trade` commands to build your trade.*",
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Show cross-team moves
|
|
||||||
if builder.trade.cross_team_moves:
|
if builder.trade.cross_team_moves:
|
||||||
moves_text = ""
|
moves_text = ""
|
||||||
for i, move in enumerate(builder.trade.cross_team_moves[:8], 1): # Limit display
|
for i, move in enumerate(builder.trade.cross_team_moves[:8], 1):
|
||||||
moves_text += f"{i}. {move.description}\n"
|
moves_text += f"{i}. {move.description}\n"
|
||||||
|
|
||||||
if len(builder.trade.cross_team_moves) > 8:
|
if len(builder.trade.cross_team_moves) > 8:
|
||||||
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
|
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
|
||||||
|
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})",
|
name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})",
|
||||||
value=moves_text,
|
value=moves_text,
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show supplementary moves
|
|
||||||
if builder.trade.supplementary_moves:
|
if builder.trade.supplementary_moves:
|
||||||
supp_text = ""
|
supp_text = ""
|
||||||
for i, move in enumerate(builder.trade.supplementary_moves[:5], 1): # Limit display
|
for i, move in enumerate(builder.trade.supplementary_moves[:5], 1):
|
||||||
supp_text += f"{i}. {move.description}\n"
|
supp_text += f"{i}. {move.description}\n"
|
||||||
|
|
||||||
if len(builder.trade.supplementary_moves) > 5:
|
if len(builder.trade.supplementary_moves) > 5:
|
||||||
supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more"
|
supp_text += (
|
||||||
|
f"... and {len(builder.trade.supplementary_moves) - 5} more"
|
||||||
|
)
|
||||||
|
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})",
|
name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})",
|
||||||
value=supp_text,
|
value=supp_text,
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add quick validation summary
|
|
||||||
validation = await builder.validate_trade()
|
validation = await builder.validate_trade()
|
||||||
if validation.is_legal:
|
if validation.is_legal:
|
||||||
status_text = "✅ Trade appears legal"
|
status_text = "Trade appears legal"
|
||||||
else:
|
else:
|
||||||
error_count = len(validation.all_errors)
|
error_count = len(validation.all_errors)
|
||||||
status_text = f"❌ {error_count} error{'s' if error_count != 1 else ''} found"
|
status_text = f"{error_count} error{'s' if error_count != 1 else ''} found\n"
|
||||||
|
status_text += "\n".join(f"- {error}" for error in validation.all_errors)
|
||||||
|
if validation.all_suggestions:
|
||||||
|
status_text += "\n" + "\n".join(
|
||||||
|
f"- {s}" for s in validation.all_suggestions
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(name="Quick Status", value=status_text, inline=False)
|
||||||
|
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="🔍 Quick Status",
|
name="Build Your Trade",
|
||||||
value=status_text,
|
value="- `/trade add-player` - Add player exchanges\n- `/trade supplementary` - Add internal moves\n- `/trade add-team` - Add more teams",
|
||||||
inline=False
|
inline=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add instructions for adding more moves
|
embed.set_footer(
|
||||||
embed.add_field(
|
text=f"Trade ID: {builder.trade_id} | Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}"
|
||||||
name="➕ Build Your Trade",
|
|
||||||
value="• `/trade add-player` - Add player exchanges\n• `/trade supplementary` - Add internal moves\n• `/trade add-team` - Add more teams",
|
|
||||||
inline=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add footer with trade ID and timestamp
|
return embed
|
||||||
embed.set_footer(text=f"Trade ID: {builder.trade_id} • Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}")
|
|
||||||
|
|
||||||
return embed
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user