Merge main into next-release
All checks were successful
Build Docker Image / build (push) Successful in 3m26s
Build Docker Image / build (pull_request) Successful in 58s

Resolve conflict in views/trade_embed.py: keep main's hotfix
(emoji-stripped UI, inline validation errors) and apply next-release's
import refactor (lazy imports hoisted to top-level).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-03-17 11:28:05 -05:00
commit 910a27e356
13 changed files with 729 additions and 786 deletions

View File

@ -86,6 +86,23 @@ class MyCog(commands.Cog):
- API errors → verify `DB_URL` points to correct database API and `API_TOKEN` matches
- Redis errors are non-fatal (graceful fallback when `REDIS_URL` is empty)
## Dependencies
### Pinning Policy
All dependencies are pinned to exact versions (`==`). This ensures every Docker build
produces an identical image — a `git revert` actually rolls back to the previous working state.
- **`requirements.txt`** — production runtime deps only (used by Dockerfile)
- **`requirements-dev.txt`** — includes `-r requirements.txt` plus dev/test tools
When installing for local development or running tests:
```bash
pip install -r requirements-dev.txt
```
When upgrading a dependency, update BOTH the `==` pin and (if applicable) the comment in
the file. Test before committing. Never use `>=` or `~=` constraints.
## API Reference
- OpenAPI spec: https://sba.manticorum.com/api/openapi.json (use WebFetch for current endpoints)

9
requirements-dev.txt Normal file
View File

@ -0,0 +1,9 @@
-r requirements.txt
# Development & Testing
pytest==8.4.1
pytest-asyncio==1.0.0
pytest-mock==3.15.1
aioresponses==0.7.8
black==26.1.0
ruff==0.15.0

View File

@ -6,15 +6,7 @@ aiohttp==3.12.13
# Utilities
python-dotenv==1.1.1
redis>=5.0.0 # For optional API response caching (not currently installed)
# Development & Testing
pytest==8.4.1
pytest-asyncio==1.0.0
pytest-mock>=3.10.0 # Not currently installed
aioresponses==0.7.8
black>=23.0.0 # Not currently installed
ruff>=0.1.0 # Not currently installed
redis==7.3.0
# Optional Dependencies
pygsheets==2.0.6 # For Google Sheets integration (scorecard submission)
pygsheets==2.0.6 # For Google Sheets integration (scorecard submission)

View File

@ -3,6 +3,7 @@ Base service class for Discord Bot v2.0
Provides common CRUD operations and error handling for all data services.
"""
import logging
import hashlib
from typing import Optional, Type, TypeVar, Generic, Dict, Any, List, Tuple
@ -12,15 +13,15 @@ from models.base import SBABaseModel
from exceptions import APIException
from utils.cache import CacheManager
logger = logging.getLogger(f'{__name__}.BaseService')
logger = logging.getLogger(f"{__name__}.BaseService")
T = TypeVar('T', bound=SBABaseModel)
T = TypeVar("T", bound=SBABaseModel)
class BaseService(Generic[T]):
"""
Base service class providing common CRUD operations for SBA models.
Features:
- Generic type support for any SBABaseModel subclass
- Automatic model validation and conversion
@ -28,15 +29,17 @@ class BaseService(Generic[T]):
- API response format handling (count + list format)
- Connection management via global client
"""
def __init__(self,
model_class: Type[T],
endpoint: str,
client: Optional[APIClient] = None,
cache_manager: Optional[CacheManager] = None):
def __init__(
self,
model_class: Type[T],
endpoint: str,
client: Optional[APIClient] = None,
cache_manager: Optional[CacheManager] = None,
):
"""
Initialize base service.
Args:
model_class: Pydantic model class for this service
endpoint: API endpoint path (e.g., 'players', 'teams')
@ -48,40 +51,44 @@ class BaseService(Generic[T]):
self._client = client
self._cached_client: Optional[APIClient] = None
self.cache = cache_manager or CacheManager()
logger.debug(f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'")
def _generate_cache_key(self, method: str, params: Optional[List[Tuple[str, Any]]] = None) -> str:
logger.debug(
f"Initialized {self.__class__.__name__} for {model_class.__name__} at endpoint '{endpoint}'"
)
def _generate_cache_key(
self, method: str, params: Optional[List[Tuple[str, Any]]] = None
) -> str:
"""
Generate consistent cache key for API calls.
Args:
method: API method name
params: Query parameters as list of tuples
Returns:
SHA256-hashed cache key
"""
key_parts = [self.endpoint, method]
if params:
# Sort parameters for consistent key generation
sorted_params = sorted(params, key=lambda x: str(x[0]))
param_str = "&".join([f"{k}={v}" for k, v in sorted_params])
key_parts.append(param_str)
key_data = ":".join(key_parts)
key_hash = hashlib.sha256(key_data.encode()).hexdigest()[:16] # First 16 chars
return self.cache.cache_key("sba", f"{self.endpoint}_{key_hash}")
async def _get_cached_items(self, cache_key: str) -> Optional[List[T]]:
"""
Get cached list of model items.
Args:
cache_key: Cache key to lookup
Returns:
List of model instances or None if not cached
"""
@ -91,13 +98,15 @@ class BaseService(Generic[T]):
return [self.model_class.from_api_data(item) for item in cached_data]
except Exception as e:
logger.warning(f"Error deserializing cached data for {cache_key}: {e}")
return None
async def _cache_items(self, cache_key: str, items: List[T], ttl: Optional[int] = None) -> None:
async def _cache_items(
self, cache_key: str, items: List[T], ttl: Optional[int] = None
) -> None:
"""
Cache list of model items.
Args:
cache_key: Cache key to store under
items: List of model instances to cache
@ -105,40 +114,40 @@ class BaseService(Generic[T]):
"""
if not items:
return
try:
# Convert to JSON-serializable format
cache_data = [item.model_dump() for item in items]
await self.cache.set(cache_key, cache_data, ttl)
except Exception as e:
logger.warning(f"Error caching items for {cache_key}: {e}")
async def get_client(self) -> APIClient:
"""
Get API client instance with caching to reduce async overhead.
Returns:
APIClient instance (cached after first access)
"""
if self._client:
return self._client
# Cache the global client to avoid repeated async calls
if self._cached_client is None:
self._cached_client = await get_global_client()
return self._cached_client
async def get_by_id(self, object_id: int) -> Optional[T]:
"""
Get single object by ID.
Args:
object_id: Unique identifier for the object
Returns:
Model instance or None if not found
Raises:
APIException: For API errors
ValueError: For invalid data
@ -146,167 +155,181 @@ class BaseService(Generic[T]):
try:
client = await self.get_client()
data = await client.get(self.endpoint, object_id=object_id)
if not data:
logger.debug(f"{self.model_class.__name__} {object_id} not found")
return None
model = self.model_class.from_api_data(data)
logger.debug(f"Retrieved {self.model_class.__name__} {object_id}: {model}")
return model
except APIException:
logger.error(f"API error retrieving {self.model_class.__name__} {object_id}")
logger.error(
f"API error retrieving {self.model_class.__name__} {object_id}"
)
raise
except Exception as e:
logger.error(f"Error retrieving {self.model_class.__name__} {object_id}: {e}")
logger.error(
f"Error retrieving {self.model_class.__name__} {object_id}: {e}"
)
raise APIException(f"Failed to retrieve {self.model_class.__name__}: {e}")
async def get_all(self, params: Optional[List[tuple]] = None) -> Tuple[List[T], int]:
async def get_all(
self, params: Optional[List[tuple]] = None
) -> Tuple[List[T], int]:
"""
Get all objects with optional query parameters.
Args:
params: Query parameters as list of (key, value) tuples
Returns:
Tuple of (list of model instances, total count)
Raises:
APIException: For API errors
"""
try:
client = await self.get_client()
data = await client.get(self.endpoint, params=params)
if not data:
logger.debug(f"No {self.model_class.__name__} objects found")
return [], 0
# Handle API response format: {'count': int, '<endpoint>': [...]}
items, count = self._extract_items_and_count_from_response(data)
models = [self.model_class.from_api_data(item) for item in items]
logger.debug(f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects")
logger.debug(
f"Retrieved {len(models)} of {count} {self.model_class.__name__} objects"
)
return models, count
except APIException:
logger.error(f"API error retrieving {self.model_class.__name__} list")
raise
except Exception as e:
logger.error(f"Error retrieving {self.model_class.__name__} list: {e}")
raise APIException(f"Failed to retrieve {self.model_class.__name__} list: {e}")
raise APIException(
f"Failed to retrieve {self.model_class.__name__} list: {e}"
)
async def get_all_items(self, params: Optional[List[tuple]] = None) -> List[T]:
"""
Get all objects (convenience method that only returns the list).
Args:
params: Query parameters as list of (key, value) tuples
Returns:
List of model instances
"""
items, _ = await self.get_all(params=params)
return items
async def create(self, model_data: Dict[str, Any]) -> Optional[T]:
"""
Create new object from data dictionary.
Args:
model_data: Dictionary of model fields
Returns:
Created model instance or None
Raises:
APIException: For API errors
ValueError: For invalid data
"""
try:
client = await self.get_client()
response = await client.post(self.endpoint, model_data)
response = await client.post(f"{self.endpoint}/", model_data)
if not response:
logger.warning(f"No response from {self.model_class.__name__} creation")
return None
model = self.model_class.from_api_data(response)
logger.debug(f"Created {self.model_class.__name__}: {model}")
return model
except APIException:
logger.error(f"API error creating {self.model_class.__name__}")
raise
except Exception as e:
logger.error(f"Error creating {self.model_class.__name__}: {e}")
raise APIException(f"Failed to create {self.model_class.__name__}: {e}")
async def create_from_model(self, model: T) -> Optional[T]:
"""
Create new object from model instance.
Args:
model: Model instance to create
Returns:
Created model instance or None
"""
return await self.create(model.to_dict(exclude_none=True))
async def update(self, object_id: int, model_data: Dict[str, Any]) -> Optional[T]:
"""
Update existing object.
Args:
object_id: ID of object to update
model_data: Dictionary of fields to update
Returns:
Updated model instance or None if not found
Raises:
APIException: For API errors
"""
try:
client = await self.get_client()
response = await client.put(self.endpoint, model_data, object_id=object_id)
if not response:
logger.debug(f"{self.model_class.__name__} {object_id} not found for update")
logger.debug(
f"{self.model_class.__name__} {object_id} not found for update"
)
return None
model = self.model_class.from_api_data(response)
logger.debug(f"Updated {self.model_class.__name__} {object_id}: {model}")
return model
except APIException:
logger.error(f"API error updating {self.model_class.__name__} {object_id}")
raise
except Exception as e:
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
async def update_from_model(self, model: T) -> Optional[T]:
"""
Update object from model instance.
Args:
model: Model instance to update (must have ID)
Returns:
Updated model instance or None
Raises:
ValueError: If model has no ID
"""
if not model.id:
raise ValueError(f"Cannot update {self.model_class.__name__} without ID")
return await self.update(model.id, model.to_dict(exclude_none=True))
async def patch(self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False) -> Optional[T]:
async def patch(
self, object_id: int, model_data: Dict[str, Any], use_query_params: bool = False
) -> Optional[T]:
"""
Update existing object with HTTP PATCH.
@ -323,10 +346,14 @@ class BaseService(Generic[T]):
"""
try:
client = await self.get_client()
response = await client.patch(self.endpoint, model_data, object_id, use_query_params=use_query_params)
response = await client.patch(
self.endpoint, model_data, object_id, use_query_params=use_query_params
)
if not response:
logger.debug(f"{self.model_class.__name__} {object_id} not found for update")
logger.debug(
f"{self.model_class.__name__} {object_id} not found for update"
)
return None
model = self.model_class.from_api_data(response)
@ -340,134 +367,142 @@ class BaseService(Generic[T]):
logger.error(f"Error updating {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to update {self.model_class.__name__}: {e}")
async def delete(self, object_id: int) -> bool:
"""
Delete object by ID.
Args:
object_id: ID of object to delete
Returns:
True if deleted, False if not found
Raises:
APIException: For API errors
"""
try:
client = await self.get_client()
success = await client.delete(self.endpoint, object_id=object_id)
if success:
logger.debug(f"Deleted {self.model_class.__name__} {object_id}")
else:
logger.debug(f"{self.model_class.__name__} {object_id} not found for deletion")
logger.debug(
f"{self.model_class.__name__} {object_id} not found for deletion"
)
return success
except APIException:
logger.error(f"API error deleting {self.model_class.__name__} {object_id}")
raise
except Exception as e:
logger.error(f"Error deleting {self.model_class.__name__} {object_id}: {e}")
raise APIException(f"Failed to delete {self.model_class.__name__}: {e}")
async def get_by_field(self, field: str, value: Any) -> List[T]:
"""
Get objects by specific field value.
Args:
field: Field name to search
value: Field value to match
Returns:
List of matching model instances
"""
params = [(field, str(value))]
return await self.get_all_items(params=params)
async def count(self, params: Optional[List[tuple]] = None) -> int:
"""
Get count of objects matching parameters.
Args:
params: Query parameters
Returns:
Number of matching objects (from API count field)
"""
_, count = await self.get_all(params=params)
return count
def _extract_items_and_count_from_response(self, data: Any) -> Tuple[List[Dict[str, Any]], int]:
def _extract_items_and_count_from_response(
self, data: Any
) -> Tuple[List[Dict[str, Any]], int]:
"""
Extract items list and count from API response with optimized parsing.
Expected format: {'count': int, '<endpoint>': [...]}
Single object format: {'id': 1, 'name': '...'}
Args:
data: API response data
Returns:
Tuple of (items list, total count)
"""
if isinstance(data, list):
return data, len(data)
if not isinstance(data, dict):
logger.warning(f"Unexpected response format for {self.model_class.__name__}: {type(data)}")
logger.warning(
f"Unexpected response format for {self.model_class.__name__}: {type(data)}"
)
return [], 0
# Single pass through the response dict - get count first
count = data.get('count', 0)
count = data.get("count", 0)
# Priority order for finding items list (most common first)
field_candidates = [self.endpoint, 'items', 'data', 'results']
field_candidates = [self.endpoint, "items", "data", "results"]
for field_name in field_candidates:
if field_name in data and isinstance(data[field_name], list):
return data[field_name], count or len(data[field_name])
# Single object response (check for common identifying fields)
if any(key in data for key in ['id', 'name', 'abbrev']):
if any(key in data for key in ["id", "name", "abbrev"]):
return [data], 1
return [], count
async def get_items_with_params(self, params: Optional[List[tuple]] = None) -> List[T]:
async def get_items_with_params(
self, params: Optional[List[tuple]] = None
) -> List[T]:
"""
Get all items with parameters (alias for get_all_items for compatibility).
Args:
params: Query parameters as list of (key, value) tuples
Returns:
List of model instances
"""
return await self.get_all_items(params=params)
async def create_item(self, model_data: Dict[str, Any]) -> Optional[T]:
"""
Create item (alias for create for compatibility).
Args:
model_data: Dictionary of model fields
Returns:
Created model instance or None
"""
return await self.create(model_data)
async def update_item_by_field(self, field: str, value: Any, update_data: Dict[str, Any]) -> Optional[T]:
async def update_item_by_field(
self, field: str, value: Any, update_data: Dict[str, Any]
) -> Optional[T]:
"""
Update item by field value.
Args:
field: Field name to search by
value: Field value to match
update_data: Data to update
Returns:
Updated model instance or None if not found
"""
@ -475,22 +510,22 @@ class BaseService(Generic[T]):
items = await self.get_by_field(field, value)
if not items:
return None
# Update the first matching item
item = items[0]
if not item.id:
return None
return await self.update(item.id, update_data)
async def delete_item_by_field(self, field: str, value: Any) -> bool:
"""
Delete item by field value.
Args:
field: Field name to search by
value: Field value to match
Returns:
True if deleted, False if not found
"""
@ -498,62 +533,41 @@ class BaseService(Generic[T]):
items = await self.get_by_field(field, value)
if not items:
return False
# Delete the first matching item
item = items[0]
if not item.id:
return False
return await self.delete(item.id)
async def create_item_in_table(self, table_name: str, item_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Create item in a specific table (simplified for custom commands service).
This is a placeholder - real implementation would need table-specific endpoints.
Args:
table_name: Name of the table
item_data: Data to create
Returns:
Created item data or None
"""
# For now, use the main endpoint - this would need proper implementation
# for different tables like 'custom_command_creators'
try:
client = await self.get_client()
# Use table name as endpoint for now
response = await client.post(table_name, item_data)
return response
except Exception as e:
logger.error(f"Error creating item in table {table_name}: {e}")
return None
async def get_items_from_table_with_params(self, table_name: str, params: List[tuple]) -> List[Dict[str, Any]]:
async def get_items_from_table_with_params(
self, table_name: str, params: List[tuple]
) -> List[Dict[str, Any]]:
"""
Get items from a specific table with parameters.
Args:
table_name: Name of the table
params: Query parameters
Returns:
List of item dictionaries
"""
try:
client = await self.get_client()
data = await client.get(table_name, params=params)
if not data:
return []
# Handle response format
items, _ = self._extract_items_and_count_from_response(data)
return items
except Exception as e:
logger.error(f"Error getting items from table {table_name}: {e}")
return []
def __repr__(self) -> str:
return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')"
return f"{self.__class__.__name__}(model={self.model_class.__name__}, endpoint='{self.endpoint}')"

View File

@ -552,9 +552,8 @@ class CustomCommandsService(BaseService[CustomCommand]):
"active_commands": 0,
}
result = await self.create_item_in_table(
"custom_commands/creators", creator_data
)
client = await self.get_client()
result = await client.post("custom_commands/creators", creator_data)
if not result:
raise BotException("Failed to create command creator")

View File

@ -3,6 +3,7 @@ Decision Service
Manages pitching decision operations for game submission.
"""
from typing import List, Dict, Any, Optional, Tuple
from utils.logging import get_contextual_logger
@ -16,17 +17,14 @@ class DecisionService:
def __init__(self):
"""Initialize decision service."""
self.logger = get_contextual_logger(f'{__name__}.DecisionService')
self.logger = get_contextual_logger(f"{__name__}.DecisionService")
self._get_client = get_global_client
async def get_client(self):
"""Get the API client."""
return await self._get_client()
async def create_decisions_batch(
self,
decisions: List[Dict[str, Any]]
) -> bool:
async def create_decisions_batch(self, decisions: List[Dict[str, Any]]) -> bool:
"""
POST batch of decisions to /decisions endpoint.
@ -42,8 +40,10 @@ class DecisionService:
try:
client = await self.get_client()
payload = {'decisions': decisions}
await client.post('decisions', payload)
payload = {"decisions": decisions}
# Trailing slash required: without it, the server returns a 307 redirect
# and aiohttp drops the POST body when following the redirect
await client.post("decisions/", payload)
self.logger.info(f"Created {len(decisions)} decisions")
return True
@ -70,7 +70,7 @@ class DecisionService:
"""
try:
client = await self.get_client()
await client.delete(f'decisions/game/{game_id}')
await client.delete(f"decisions/game/{game_id}")
self.logger.info(f"Deleted decisions for game {game_id}")
return True
@ -80,9 +80,10 @@ class DecisionService:
raise APIException(f"Failed to delete decisions: {e}")
async def find_winning_losing_pitchers(
self,
decisions_data: List[Dict[str, Any]]
) -> Tuple[Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]]:
self, decisions_data: List[Dict[str, Any]]
) -> Tuple[
Optional[Player], Optional[Player], Optional[Player], List[Player], List[Player]
]:
"""
Extract WP, LP, SV, Holds, Blown Saves from decisions list and fetch Player objects.
@ -110,17 +111,17 @@ class DecisionService:
# First pass: Extract IDs
for decision in decisions_data:
pitcher_id = int(decision.get('pitcher_id', 0))
pitcher_id = int(decision.get("pitcher_id", 0))
if int(decision.get('win', 0)) == 1:
if int(decision.get("win", 0)) == 1:
wp_id = pitcher_id
if int(decision.get('loss', 0)) == 1:
if int(decision.get("loss", 0)) == 1:
lp_id = pitcher_id
if int(decision.get('is_save', 0)) == 1:
if int(decision.get("is_save", 0)) == 1:
sv_id = pitcher_id
if int(decision.get('hold', 0)) == 1:
if int(decision.get("hold", 0)) == 1:
hold_ids.append(pitcher_id)
if int(decision.get('b_save', 0)) == 1:
if int(decision.get("b_save", 0)) == 1:
bsv_ids.append(pitcher_id)
# Second pass: Fetch Player objects
@ -154,9 +155,9 @@ class DecisionService:
"""
error_str = str(error)
if 'Player ID' in error_str and 'not found' in error_str:
if "Player ID" in error_str and "not found" in error_str:
return "Invalid pitcher ID in decision data."
elif 'Game ID' in error_str and 'not found' in error_str:
elif "Game ID" in error_str and "not found" in error_str:
return "Game not found for decisions."
else:
return f"Error submitting decisions: {error_str}"

View File

@ -3,13 +3,14 @@ Draft list service for Discord Bot v2.0
Handles team draft list (auto-draft queue) operations. NO CACHING - lists change frequently.
"""
import logging
from typing import Optional, List
from services.base_service import BaseService
from models.draft_list import DraftList
logger = logging.getLogger(f'{__name__}.DraftListService')
logger = logging.getLogger(f"{__name__}.DraftListService")
class DraftListService(BaseService[DraftList]):
@ -32,7 +33,7 @@ class DraftListService(BaseService[DraftList]):
def __init__(self):
"""Initialize draft list service."""
super().__init__(DraftList, 'draftlist')
super().__init__(DraftList, "draftlist")
logger.debug("DraftListService initialized")
def _extract_items_and_count_from_response(self, data):
@ -54,20 +55,16 @@ class DraftListService(BaseService[DraftList]):
return [], 0
# Get count
count = data.get('count', 0)
count = data.get("count", 0)
# API returns items under 'picks' key (not 'draftlist')
if 'picks' in data and isinstance(data['picks'], list):
return data['picks'], count or len(data['picks'])
if "picks" in data and isinstance(data["picks"], list):
return data["picks"], count or len(data["picks"])
# Fallback to standard extraction
return super()._extract_items_and_count_from_response(data)
async def get_team_list(
self,
season: int,
team_id: int
) -> List[DraftList]:
async def get_team_list(self, season: int, team_id: int) -> List[DraftList]:
"""
Get team's draft list ordered by rank.
@ -82,8 +79,8 @@ class DraftListService(BaseService[DraftList]):
"""
try:
params = [
('season', str(season)),
('team_id', str(team_id))
("season", str(season)),
("team_id", str(team_id)),
# NOTE: API does not support 'sort' param - results must be sorted client-side
]
@ -100,11 +97,7 @@ class DraftListService(BaseService[DraftList]):
return []
async def add_to_list(
self,
season: int,
team_id: int,
player_id: int,
rank: Optional[int] = None
self, season: int, team_id: int, player_id: int, rank: Optional[int] = None
) -> Optional[List[DraftList]]:
"""
Add player to team's draft list.
@ -133,10 +126,10 @@ class DraftListService(BaseService[DraftList]):
# Create new entry data
new_entry_data = {
'season': season,
'team_id': team_id,
'player_id': player_id,
'rank': rank
"season": season,
"team_id": team_id,
"player_id": player_id,
"rank": rank,
}
# Build complete list for bulk replacement
@ -146,36 +139,42 @@ class DraftListService(BaseService[DraftList]):
for entry in current_list:
if entry.rank >= rank:
# Shift down entries at or after insertion point
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': entry.rank + 1
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": entry.rank + 1,
}
)
else:
# Keep existing rank for entries before insertion point
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': entry.rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": entry.rank,
}
)
# Add new entry
draft_list_entries.append(new_entry_data)
# Sort by rank for consistency
draft_list_entries.sort(key=lambda x: x['rank'])
draft_list_entries.sort(key=lambda x: x["rank"])
# POST entire list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
logger.debug(f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries")
response = await client.post(self.endpoint, payload)
logger.debug(
f"Posting draft list for team {team_id}: {len(draft_list_entries)} entries"
)
response = await client.post(f"{self.endpoint}/", payload)
logger.debug(f"POST response: {response}")
# Verify by fetching the list back (API returns full objects)
@ -184,20 +183,21 @@ class DraftListService(BaseService[DraftList]):
# Verify the player was added
if not any(entry.player_id == player_id for entry in verification):
logger.error(f"Player {player_id} not found in list after POST - operation may have failed")
logger.error(
f"Player {player_id} not found in list after POST - operation may have failed"
)
return None
logger.info(f"Added player {player_id} to team {team_id} draft list at rank {rank}")
logger.info(
f"Added player {player_id} to team {team_id} draft list at rank {rank}"
)
return verification # Return full updated list
except Exception as e:
logger.error(f"Error adding player {player_id} to draft list: {e}")
return None
async def remove_from_list(
self,
entry_id: int
) -> bool:
async def remove_from_list(self, entry_id: int) -> bool:
"""
Remove entry from draft list by ID.
@ -209,14 +209,13 @@ class DraftListService(BaseService[DraftList]):
Returns:
True if deletion succeeded
"""
logger.warning("remove_from_list() called with entry_id - use remove_player_from_list() instead")
logger.warning(
"remove_from_list() called with entry_id - use remove_player_from_list() instead"
)
return False
async def remove_player_from_list(
self,
season: int,
team_id: int,
player_id: int
self, season: int, team_id: int, player_id: int
) -> bool:
"""
Remove specific player from team's draft list.
@ -238,7 +237,9 @@ class DraftListService(BaseService[DraftList]):
# Check if player is in list
player_found = any(entry.player_id == player_id for entry in current_list)
if not player_found:
logger.warning(f"Player {player_id} not found in team {team_id} draft list")
logger.warning(
f"Player {player_id} not found in team {team_id} draft list"
)
return False
# Build new list without the player, adjusting ranks
@ -246,22 +247,24 @@ class DraftListService(BaseService[DraftList]):
new_rank = 1
for entry in current_list:
if entry.player_id != player_id:
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': new_rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": new_rank,
}
)
new_rank += 1
# POST updated list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
await client.post(self.endpoint, payload)
await client.post(f"{self.endpoint}/", payload)
logger.info(f"Removed player {player_id} from team {team_id} draft list")
return True
@ -270,11 +273,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error removing player {player_id} from draft list: {e}")
return False
async def clear_list(
self,
season: int,
team_id: int
) -> bool:
async def clear_list(self, season: int, team_id: int) -> bool:
"""
Clear entire draft list for team.
@ -309,10 +308,7 @@ class DraftListService(BaseService[DraftList]):
return False
async def reorder_list(
self,
season: int,
team_id: int,
new_order: List[int]
self, season: int, team_id: int, new_order: List[int]
) -> bool:
"""
Reorder team's draft list.
@ -342,21 +338,23 @@ class DraftListService(BaseService[DraftList]):
continue
entry = entry_map[player_id]
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': new_rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": new_rank,
}
)
# POST reordered list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
await client.post(self.endpoint, payload)
await client.post(f"{self.endpoint}/", payload)
logger.info(f"Reordered draft list for team {team_id}")
return True
@ -365,12 +363,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error reordering draft list for team {team_id}: {e}")
return False
async def move_entry_up(
self,
season: int,
team_id: int,
player_id: int
) -> bool:
async def move_entry_up(self, season: int, team_id: int, player_id: int) -> bool:
"""
Move player up one position in draft list (higher priority).
@ -403,7 +396,9 @@ class DraftListService(BaseService[DraftList]):
return False
# Find entry above (rank - 1)
above_entry = next((e for e in entries if e.rank == current_entry.rank - 1), None)
above_entry = next(
(e for e in entries if e.rank == current_entry.rank - 1), None
)
if not above_entry:
logger.error(f"Could not find entry above rank {current_entry.rank}")
return False
@ -421,24 +416,26 @@ class DraftListService(BaseService[DraftList]):
# Keep existing rank
new_rank = entry.rank
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': new_rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": new_rank,
}
)
# Sort by rank
draft_list_entries.sort(key=lambda x: x['rank'])
draft_list_entries.sort(key=lambda x: x["rank"])
# POST updated list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
await client.post(self.endpoint, payload)
await client.post(f"{self.endpoint}/", payload)
logger.info(f"Moved player {player_id} up to rank {current_entry.rank - 1}")
return True
@ -447,12 +444,7 @@ class DraftListService(BaseService[DraftList]):
logger.error(f"Error moving player {player_id} up in draft list: {e}")
return False
async def move_entry_down(
self,
season: int,
team_id: int,
player_id: int
) -> bool:
async def move_entry_down(self, season: int, team_id: int, player_id: int) -> bool:
"""
Move player down one position in draft list (lower priority).
@ -485,7 +477,9 @@ class DraftListService(BaseService[DraftList]):
return False
# Find entry below (rank + 1)
below_entry = next((e for e in entries if e.rank == current_entry.rank + 1), None)
below_entry = next(
(e for e in entries if e.rank == current_entry.rank + 1), None
)
if not below_entry:
logger.error(f"Could not find entry below rank {current_entry.rank}")
return False
@ -503,25 +497,29 @@ class DraftListService(BaseService[DraftList]):
# Keep existing rank
new_rank = entry.rank
draft_list_entries.append({
'season': entry.season,
'team_id': entry.team_id,
'player_id': entry.player_id,
'rank': new_rank
})
draft_list_entries.append(
{
"season": entry.season,
"team_id": entry.team_id,
"player_id": entry.player_id,
"rank": new_rank,
}
)
# Sort by rank
draft_list_entries.sort(key=lambda x: x['rank'])
draft_list_entries.sort(key=lambda x: x["rank"])
# POST updated list (bulk replacement)
client = await self.get_client()
payload = {
'count': len(draft_list_entries),
'draft_list': draft_list_entries
"count": len(draft_list_entries),
"draft_list": draft_list_entries,
}
await client.post(self.endpoint, payload)
logger.info(f"Moved player {player_id} down to rank {current_entry.rank + 1}")
await client.post(f"{self.endpoint}/", payload)
logger.info(
f"Moved player {player_id} down to rank {current_entry.rank + 1}"
)
return True

View File

@ -3,13 +3,14 @@ Injury service for Discord Bot v2.0
Handles injury-related operations including checking, creating, and clearing injuries.
"""
import logging
from typing import Optional, List
from services.base_service import BaseService
from models.injury import Injury
logger = logging.getLogger(f'{__name__}.InjuryService')
logger = logging.getLogger(f"{__name__}.InjuryService")
class InjuryService(BaseService[Injury]):
@ -25,7 +26,7 @@ class InjuryService(BaseService[Injury]):
def __init__(self):
"""Initialize injury service."""
super().__init__(Injury, 'injuries')
super().__init__(Injury, "injuries")
logger.debug("InjuryService initialized")
async def get_active_injury(self, player_id: int, season: int) -> Optional[Injury]:
@ -41,25 +42,31 @@ class InjuryService(BaseService[Injury]):
"""
try:
params = [
('player_id', str(player_id)),
('season', str(season)),
('is_active', 'true')
("player_id", str(player_id)),
("season", str(season)),
("is_active", "true"),
]
injuries = await self.get_all_items(params=params)
if injuries:
logger.debug(f"Found active injury for player {player_id} in season {season}")
logger.debug(
f"Found active injury for player {player_id} in season {season}"
)
return injuries[0]
logger.debug(f"No active injury found for player {player_id} in season {season}")
logger.debug(
f"No active injury found for player {player_id} in season {season}"
)
return None
except Exception as e:
logger.error(f"Error getting active injury for player {player_id}: {e}")
return None
async def get_injuries_by_player(self, player_id: int, season: int, active_only: bool = False) -> List[Injury]:
async def get_injuries_by_player(
self, player_id: int, season: int, active_only: bool = False
) -> List[Injury]:
"""
Get all injuries for a player in a specific season.
@ -72,13 +79,10 @@ class InjuryService(BaseService[Injury]):
List of injuries for the player
"""
try:
params = [
('player_id', str(player_id)),
('season', str(season))
]
params = [("player_id", str(player_id)), ("season", str(season))]
if active_only:
params.append(('is_active', 'true'))
params.append(("is_active", "true"))
injuries = await self.get_all_items(params=params)
logger.debug(f"Retrieved {len(injuries)} injuries for player {player_id}")
@ -88,7 +92,9 @@ class InjuryService(BaseService[Injury]):
logger.error(f"Error getting injuries for player {player_id}: {e}")
return []
async def get_injuries_by_team(self, team_id: int, season: int, active_only: bool = True) -> List[Injury]:
async def get_injuries_by_team(
self, team_id: int, season: int, active_only: bool = True
) -> List[Injury]:
"""
Get all injuries for a team in a specific season.
@ -101,13 +107,10 @@ class InjuryService(BaseService[Injury]):
List of injuries for the team
"""
try:
params = [
('team_id', str(team_id)),
('season', str(season))
]
params = [("team_id", str(team_id)), ("season", str(season))]
if active_only:
params.append(('is_active', 'true'))
params.append(("is_active", "true"))
injuries = await self.get_all_items(params=params)
logger.debug(f"Retrieved {len(injuries)} injuries for team {team_id}")
@ -125,7 +128,7 @@ class InjuryService(BaseService[Injury]):
start_week: int,
start_game: int,
end_week: int,
end_game: int
end_game: int,
) -> Optional[Injury]:
"""
Create a new injury record.
@ -144,22 +147,24 @@ class InjuryService(BaseService[Injury]):
"""
try:
injury_data = {
'season': season,
'player_id': player_id,
'total_games': total_games,
'start_week': start_week,
'start_game': start_game,
'end_week': end_week,
'end_game': end_game,
'is_active': True
"season": season,
"player_id": player_id,
"total_games": total_games,
"start_week": start_week,
"start_game": start_game,
"end_week": end_week,
"end_game": end_game,
"is_active": True,
}
# Call the API to create the injury
client = await self.get_client()
response = await client.post(self.endpoint, injury_data)
response = await client.post(f"{self.endpoint}/", injury_data)
if not response:
logger.error(f"Failed to create injury for player {player_id}: No response from API")
logger.error(
f"Failed to create injury for player {player_id}: No response from API"
)
return None
# Merge the request data with the response to ensure all required fields are present
@ -187,7 +192,9 @@ class InjuryService(BaseService[Injury]):
"""
try:
# Note: API expects is_active as query parameter, not JSON body
updated_injury = await self.patch(injury_id, {'is_active': False}, use_query_params=True)
updated_injury = await self.patch(
injury_id, {"is_active": False}, use_query_params=True
)
if updated_injury:
logger.info(f"Cleared injury {injury_id}")
@ -216,16 +223,18 @@ class InjuryService(BaseService[Injury]):
try:
client = await self.get_client()
params = [
('season', str(season)),
('is_active', 'true'),
('sort', 'return-asc')
("season", str(season)),
("is_active", "true"),
("sort", "return-asc"),
]
response = await client.get(self.endpoint, params=params)
if response and 'injuries' in response:
logger.debug(f"Retrieved {len(response['injuries'])} active injuries for season {season}")
return response['injuries']
if response and "injuries" in response:
logger.debug(
f"Retrieved {len(response['injuries'])} active injuries for season {season}"
)
return response["injuries"]
logger.debug(f"No active injuries found for season {season}")
return []

View File

@ -3,6 +3,7 @@ Play Service
Manages play-by-play data operations for game submission.
"""
from typing import List, Dict, Any
from utils.logging import get_contextual_logger
@ -16,7 +17,7 @@ class PlayService:
def __init__(self):
"""Initialize play service."""
self.logger = get_contextual_logger(f'{__name__}.PlayService')
self.logger = get_contextual_logger(f"{__name__}.PlayService")
self._get_client = get_global_client
async def get_client(self):
@ -39,8 +40,10 @@ class PlayService:
try:
client = await self.get_client()
payload = {'plays': plays}
response = await client.post('plays', payload)
payload = {"plays": plays}
# Trailing slash required: without it, the server returns a 307 redirect
# and aiohttp drops the POST body when following the redirect
response = await client.post("plays/", payload)
self.logger.info(f"Created {len(plays)} plays")
return True
@ -68,7 +71,7 @@ class PlayService:
"""
try:
client = await self.get_client()
response = await client.delete(f'plays/game/{game_id}')
response = await client.delete(f"plays/game/{game_id}")
self.logger.info(f"Deleted plays for game {game_id}")
return True
@ -77,11 +80,7 @@ class PlayService:
self.logger.error(f"Failed to delete plays for game {game_id}: {e}")
raise APIException(f"Failed to delete plays: {e}")
async def get_top_plays_by_wpa(
self,
game_id: int,
limit: int = 3
) -> List[Play]:
async def get_top_plays_by_wpa(self, game_id: int, limit: int = 3) -> List[Play]:
"""
Get top plays by WPA (absolute value) for key plays display.
@ -95,19 +94,15 @@ class PlayService:
try:
client = await self.get_client()
params = [
('game_id', game_id),
('sort', 'wpa-desc'),
('limit', limit)
]
params = [("game_id", game_id), ("sort", "wpa-desc"), ("limit", limit)]
response = await client.get('plays', params=params)
response = await client.get("plays", params=params)
if not response or 'plays' not in response:
self.logger.info(f'No plays found for game ID {game_id}')
if not response or "plays" not in response:
self.logger.info(f"No plays found for game ID {game_id}")
return []
plays = [Play.from_api_data(p) for p in response['plays']]
plays = [Play.from_api_data(p) for p in response["plays"]]
self.logger.debug(f"Retrieved {len(plays)} top plays for game {game_id}")
return plays
@ -129,11 +124,11 @@ class PlayService:
error_str = str(error)
# Common error patterns
if 'Player ID' in error_str and 'not found' in error_str:
if "Player ID" in error_str and "not found" in error_str:
return "Invalid player ID in scorecard data. Please check player IDs."
elif 'Game ID' in error_str and 'not found' in error_str:
elif "Game ID" in error_str and "not found" in error_str:
return "Game not found in database. Please contact an admin."
elif 'validation' in error_str.lower():
elif "validation" in error_str.lower():
return f"Data validation error: {error_str}"
else:
return f"Error submitting plays: {error_str}"

View File

@ -248,7 +248,7 @@ class TransactionService(BaseService[Transaction]):
# POST batch to API
client = await self.get_client()
response = await client.post(self.endpoint, data=batch_data)
response = await client.post(f"{self.endpoint}/", data=batch_data)
# API returns a string like "2 transactions have been added"
# We need to return the original Transaction objects (they won't have IDs assigned by API)

View File

@ -1,18 +1,24 @@
"""
API client tests using aioresponses for clean HTTP mocking
"""
import pytest
import asyncio
from unittest.mock import MagicMock, patch
from aioresponses import aioresponses
from api.client import APIClient, get_api_client, get_global_client, cleanup_global_client
from api.client import (
APIClient,
get_api_client,
get_global_client,
cleanup_global_client,
)
from exceptions import APIException
class TestAPIClientWithAioresponses:
"""Test API client with aioresponses for HTTP mocking."""
@pytest.fixture
def mock_config(self):
"""Mock configuration for testing."""
@ -20,66 +26,57 @@ class TestAPIClientWithAioresponses:
config.db_url = "https://api.example.com"
config.api_token = "test-token"
return config
@pytest.fixture
def api_client(self, mock_config):
"""Create API client with mocked config."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
return APIClient()
@pytest.mark.asyncio
async def test_get_request_success(self, api_client):
"""Test successful GET request."""
expected_data = {"id": 1, "name": "Test Player"}
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players/1",
payload=expected_data,
status=200
status=200,
)
result = await api_client.get("players", object_id=1)
assert result == expected_data
@pytest.mark.asyncio
async def test_get_request_404(self, api_client):
"""Test GET request returning 404."""
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players/999",
status=404
)
m.get("https://api.example.com/v3/players/999", status=404)
result = await api_client.get("players", object_id=999)
assert result is None
@pytest.mark.asyncio
async def test_get_request_401_auth_error(self, api_client):
"""Test GET request with authentication error."""
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
status=401
)
m.get("https://api.example.com/v3/players", status=401)
with pytest.raises(APIException, match="Authentication failed"):
await api_client.get("players")
@pytest.mark.asyncio
async def test_get_request_403_forbidden(self, api_client):
"""Test GET request with forbidden error."""
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
status=403
)
m.get("https://api.example.com/v3/players", status=403)
with pytest.raises(APIException, match="Access forbidden"):
await api_client.get("players")
@pytest.mark.asyncio
async def test_get_request_500_server_error(self, api_client):
"""Test GET request with server error."""
@ -87,135 +84,127 @@ class TestAPIClientWithAioresponses:
m.get(
"https://api.example.com/v3/players",
status=500,
body="Internal Server Error"
body="Internal Server Error",
)
with pytest.raises(APIException, match="API request failed with status 500"):
with pytest.raises(
APIException, match="API request failed with status 500"
):
await api_client.get("players")
@pytest.mark.asyncio
async def test_get_request_with_params(self, api_client):
"""Test GET request with query parameters."""
expected_data = {"count": 2, "players": [{"id": 1}, {"id": 2}]}
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players?team_id=5&season=12",
payload=expected_data,
status=200
status=200,
)
result = await api_client.get("players", params=[("team_id", "5"), ("season", "12")])
result = await api_client.get(
"players", params=[("team_id", "5"), ("season", "12")]
)
assert result == expected_data
@pytest.mark.asyncio
async def test_post_request_success(self, api_client):
"""Test successful POST request."""
input_data = {"name": "New Player", "position": "C"}
expected_response = {"id": 1, "name": "New Player", "position": "C"}
with aioresponses() as m:
m.post(
"https://api.example.com/v3/players",
payload=expected_response,
status=201
status=201,
)
result = await api_client.post("players", input_data)
assert result == expected_response
@pytest.mark.asyncio
async def test_post_request_400_error(self, api_client):
"""Test POST request with validation error."""
input_data = {"invalid": "data"}
with aioresponses() as m:
m.post(
"https://api.example.com/v3/players",
status=400,
body="Invalid data"
"https://api.example.com/v3/players", status=400, body="Invalid data"
)
with pytest.raises(APIException, match="POST request failed with status 400"):
with pytest.raises(
APIException, match="POST request failed with status 400"
):
await api_client.post("players", input_data)
@pytest.mark.asyncio
async def test_put_request_success(self, api_client):
"""Test successful PUT request."""
update_data = {"name": "Updated Player"}
expected_response = {"id": 1, "name": "Updated Player"}
with aioresponses() as m:
m.put(
"https://api.example.com/v3/players/1",
payload=expected_response,
status=200
status=200,
)
result = await api_client.put("players", update_data, object_id=1)
assert result == expected_response
@pytest.mark.asyncio
async def test_put_request_404(self, api_client):
"""Test PUT request with 404."""
update_data = {"name": "Updated Player"}
with aioresponses() as m:
m.put(
"https://api.example.com/v3/players/999",
status=404
)
m.put("https://api.example.com/v3/players/999", status=404)
result = await api_client.put("players", update_data, object_id=999)
assert result is None
@pytest.mark.asyncio
async def test_delete_request_success(self, api_client):
"""Test successful DELETE request."""
with aioresponses() as m:
m.delete(
"https://api.example.com/v3/players/1",
status=204
)
m.delete("https://api.example.com/v3/players/1", status=204)
result = await api_client.delete("players", object_id=1)
assert result is True
@pytest.mark.asyncio
async def test_delete_request_404(self, api_client):
"""Test DELETE request with 404."""
with aioresponses() as m:
m.delete(
"https://api.example.com/v3/players/999",
status=404
)
m.delete("https://api.example.com/v3/players/999", status=404)
result = await api_client.delete("players", object_id=999)
assert result is False
@pytest.mark.asyncio
async def test_delete_request_200_success(self, api_client):
"""Test DELETE request with 200 success."""
with aioresponses() as m:
m.delete(
"https://api.example.com/v3/players/1",
status=200
)
m.delete("https://api.example.com/v3/players/1", status=200)
result = await api_client.delete("players", object_id=1)
assert result is True
class TestAPIClientHelpers:
"""Test API client helper functions."""
@pytest.fixture
def mock_config(self):
"""Mock configuration for testing."""
@ -223,49 +212,49 @@ class TestAPIClientHelpers:
config.db_url = "https://api.example.com"
config.api_token = "test-token"
return config
@pytest.mark.asyncio
async def test_get_api_client_context_manager(self, mock_config):
"""Test get_api_client context manager."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
m.get(
"https://api.example.com/v3/test",
payload={"success": True},
status=200
status=200,
)
async with get_api_client() as client:
assert isinstance(client, APIClient)
result = await client.get("test")
assert result == {"success": True}
@pytest.mark.asyncio
async def test_global_client_management(self, mock_config):
"""Test global client getter and cleanup."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
# Get global client
client1 = await get_global_client()
client2 = await get_global_client()
# Should return same instance
assert client1 is client2
assert isinstance(client1, APIClient)
# Test cleanup
await cleanup_global_client()
# New client should be different instance
client3 = await get_global_client()
assert client3 is not client1
# Clean up for other tests
await cleanup_global_client()
class TestIntegrationScenarios:
"""Test realistic integration scenarios."""
@pytest.fixture
def mock_config(self):
"""Mock configuration for testing."""
@ -273,11 +262,11 @@ class TestIntegrationScenarios:
config.db_url = "https://api.example.com"
config.api_token = "test-token"
return config
@pytest.mark.asyncio
async def test_player_retrieval_with_team_lookup(self, mock_config):
"""Test realistic scenario: get player with team data."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
# Mock player data response
player_data = {
@ -287,43 +276,41 @@ class TestIntegrationScenarios:
"season": 12,
"team_id": 5,
"image": "https://example.com/player1.jpg",
"pos_1": "C"
"pos_1": "C",
}
m.get(
"https://api.example.com/v3/players/1",
payload=player_data,
status=200
status=200,
)
# Mock team data response
team_data = {
"id": 5,
"abbrev": "TST",
"sname": "Test Team",
"lname": "Test Team Full Name",
"season": 12
"season": 12,
}
m.get(
"https://api.example.com/v3/teams/5",
payload=team_data,
status=200
"https://api.example.com/v3/teams/5", payload=team_data, status=200
)
client = APIClient()
# Get player
player = await client.get("players", object_id=1)
assert player["name"] == "Test Player"
assert player["team_id"] == 5
# Get team for player
team = await client.get("teams", object_id=player["team_id"])
assert team["sname"] == "Test Team"
@pytest.mark.asyncio
async def test_api_response_format_handling(self, mock_config):
"""Test handling of the API's count + list format."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
# Mock API response with count format
api_response = {
@ -336,7 +323,7 @@ class TestIntegrationScenarios:
"season": 12,
"team_id": 5,
"image": "https://example.com/player1.jpg",
"pos_1": "C"
"pos_1": "C",
},
{
"id": 2,
@ -345,93 +332,93 @@ class TestIntegrationScenarios:
"season": 12,
"team_id": 6,
"image": "https://example.com/player2.jpg",
"pos_1": "1B"
}
]
"pos_1": "1B",
},
],
}
m.get(
"https://api.example.com/v3/players?team_id=5",
payload=api_response,
status=200
status=200,
)
client = APIClient()
result = await client.get("players", params=[("team_id", "5")])
assert result["count"] == 25
assert len(result["players"]) == 2
assert result["players"][0]["name"] == "Player 1"
@pytest.mark.asyncio
async def test_error_recovery_scenarios(self, mock_config):
"""Test error handling and recovery."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
# First request fails with 500
m.get(
"https://api.example.com/v3/players/1",
status=500,
body="Internal Server Error"
body="Internal Server Error",
)
# Second request succeeds
m.get(
"https://api.example.com/v3/players/2",
payload={"id": 2, "name": "Working Player"},
status=200
status=200,
)
client = APIClient()
# First request should raise exception
with pytest.raises(APIException, match="API request failed"):
await client.get("players", object_id=1)
# Second request should work fine
result = await client.get("players", object_id=2)
assert result["name"] == "Working Player"
# Client should still be functional
await client.close()
@pytest.mark.asyncio
async def test_concurrent_requests(self, mock_config):
"""Test multiple concurrent requests."""
import asyncio
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
with aioresponses() as m:
# Mock multiple endpoints
for i in range(1, 4):
m.get(
f"https://api.example.com/v3/players/{i}",
payload={"id": i, "name": f"Player {i}"},
status=200
status=200,
)
client = APIClient()
# Make concurrent requests
tasks = [
client.get("players", object_id=1),
client.get("players", object_id=2),
client.get("players", object_id=3)
client.get("players", object_id=3),
]
results = await asyncio.gather(*tasks)
assert len(results) == 3
assert results[0]["name"] == "Player 1"
assert results[1]["name"] == "Player 2"
assert results[2]["name"] == "Player 3"
await client.close()
class TestAPIClientCoverageExtras:
"""Additional coverage tests for API client edge cases."""
@pytest.fixture
def mock_config(self):
"""Mock configuration for testing."""
@ -439,98 +426,104 @@ class TestAPIClientCoverageExtras:
config.db_url = "https://api.example.com"
config.api_token = "test-token"
return config
@pytest.mark.asyncio
async def test_global_client_cleanup_when_none(self):
"""Test cleanup when no global client exists."""
# Ensure no global client exists
await cleanup_global_client()
# Should not raise error
await cleanup_global_client()
@pytest.mark.asyncio
async def test_url_building_edge_cases(self, mock_config):
"""Test URL building with various edge cases."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
client = APIClient()
# Test trailing slash handling
client.base_url = "https://api.example.com/"
url = client._build_url("players")
assert url == "https://api.example.com/v3/players"
assert "//" not in url.replace("https://", "")
@pytest.mark.asyncio
async def test_parameter_handling_edge_cases(self, mock_config):
"""Test parameter handling with various scenarios."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
client = APIClient()
# Test with existing query string
url = client._add_params("https://example.com/api?existing=true", [("new", "param")])
url = client._add_params(
"https://example.com/api?existing=true", [("new", "param")]
)
assert url == "https://example.com/api?existing=true&new=param"
# Test with no parameters
url = client._add_params("https://example.com/api")
assert url == "https://example.com/api"
@pytest.mark.asyncio
async def test_timeout_error_handling(self, mock_config):
"""Test timeout error handling using aioresponses."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
client = APIClient()
# Test timeout using aioresponses exception parameter
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
exception=asyncio.TimeoutError("Request timed out")
exception=asyncio.TimeoutError("Request timed out"),
)
with pytest.raises(APIException, match="API call failed.*Request timed out"):
with pytest.raises(
APIException, match="API call failed.*Request timed out"
):
await client.get("players")
await client.close()
@pytest.mark.asyncio
async def test_generic_exception_handling(self, mock_config):
"""Test generic exception handling."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
client = APIClient()
# Test generic exception
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
exception=Exception("Generic error")
exception=Exception("Generic error"),
)
with pytest.raises(APIException, match="API call failed.*Generic error"):
with pytest.raises(
APIException, match="API call failed.*Generic error"
):
await client.get("players")
await client.close()
@pytest.mark.asyncio
async def test_session_closed_handling(self, mock_config):
"""Test handling of closed session."""
with patch('api.client.get_config', return_value=mock_config):
with patch("api.client.get_config", return_value=mock_config):
# Test that the client recreates session when needed
with aioresponses() as m:
m.get(
"https://api.example.com/v3/players",
payload={"success": True},
status=200
status=200,
)
client = APIClient()
# Close the session manually
await client._ensure_session()
await client._session.close()
# Client should recreate session and work fine
result = await client.get("players")
assert result == {"success": True}
await client.close()
await client.close()

View File

@ -1,6 +1,7 @@
"""
Tests for BaseService functionality
"""
import pytest
from unittest.mock import AsyncMock
@ -10,6 +11,7 @@ from models.base import SBABaseModel
class MockModel(SBABaseModel):
"""Mock model for testing BaseService."""
id: int
name: str
value: int = 100
@ -17,240 +19,229 @@ class MockModel(SBABaseModel):
class TestBaseService:
"""Test BaseService functionality."""
@pytest.fixture
def mock_client(self):
"""Mock API client."""
client = AsyncMock()
return client
@pytest.fixture
def base_service(self, mock_client):
"""Create BaseService instance for testing."""
service = BaseService(MockModel, 'mocks', client=mock_client)
service = BaseService(MockModel, "mocks", client=mock_client)
return service
@pytest.mark.asyncio
async def test_init(self):
"""Test service initialization."""
service = BaseService(MockModel, 'test_endpoint')
service = BaseService(MockModel, "test_endpoint")
assert service.model_class == MockModel
assert service.endpoint == 'test_endpoint'
assert service.endpoint == "test_endpoint"
assert service._client is None
@pytest.mark.asyncio
async def test_get_by_id_success(self, base_service, mock_client):
"""Test successful get_by_id."""
mock_data = {'id': 1, 'name': 'Test', 'value': 200}
mock_data = {"id": 1, "name": "Test", "value": 200}
mock_client.get.return_value = mock_data
result = await base_service.get_by_id(1)
assert isinstance(result, MockModel)
assert result.id == 1
assert result.name == 'Test'
assert result.name == "Test"
assert result.value == 200
mock_client.get.assert_called_once_with('mocks', object_id=1)
mock_client.get.assert_called_once_with("mocks", object_id=1)
@pytest.mark.asyncio
async def test_get_by_id_not_found(self, base_service, mock_client):
"""Test get_by_id when object not found."""
mock_client.get.return_value = None
result = await base_service.get_by_id(999)
assert result is None
mock_client.get.assert_called_once_with('mocks', object_id=999)
mock_client.get.assert_called_once_with("mocks", object_id=999)
@pytest.mark.asyncio
async def test_get_all_with_count(self, base_service, mock_client):
"""Test get_all with count response format."""
mock_data = {
'count': 2,
'mocks': [
{'id': 1, 'name': 'Test1', 'value': 100},
{'id': 2, 'name': 'Test2', 'value': 200}
]
"count": 2,
"mocks": [
{"id": 1, "name": "Test1", "value": 100},
{"id": 2, "name": "Test2", "value": 200},
],
}
mock_client.get.return_value = mock_data
result, count = await base_service.get_all()
assert len(result) == 2
assert count == 2
assert all(isinstance(item, MockModel) for item in result)
mock_client.get.assert_called_once_with('mocks', params=None)
mock_client.get.assert_called_once_with("mocks", params=None)
@pytest.mark.asyncio
async def test_get_all_items_convenience(self, base_service, mock_client):
"""Test get_all_items convenience method."""
mock_data = {
'count': 1,
'mocks': [{'id': 1, 'name': 'Test', 'value': 100}]
}
mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]}
mock_client.get.return_value = mock_data
result = await base_service.get_all_items()
assert len(result) == 1
assert isinstance(result[0], MockModel)
@pytest.mark.asyncio
async def test_create_success(self, base_service, mock_client):
"""Test successful object creation."""
input_data = {'name': 'New Item', 'value': 300}
response_data = {'id': 3, 'name': 'New Item', 'value': 300}
input_data = {"name": "New Item", "value": 300}
response_data = {"id": 3, "name": "New Item", "value": 300}
mock_client.post.return_value = response_data
result = await base_service.create(input_data)
assert isinstance(result, MockModel)
assert result.id == 3
assert result.name == 'New Item'
mock_client.post.assert_called_once_with('mocks', input_data)
assert result.name == "New Item"
mock_client.post.assert_called_once_with("mocks/", input_data)
@pytest.mark.asyncio
async def test_update_success(self, base_service, mock_client):
"""Test successful object update."""
update_data = {'name': 'Updated'}
response_data = {'id': 1, 'name': 'Updated', 'value': 100}
update_data = {"name": "Updated"}
response_data = {"id": 1, "name": "Updated", "value": 100}
mock_client.put.return_value = response_data
result = await base_service.update(1, update_data)
assert isinstance(result, MockModel)
assert result.name == 'Updated'
mock_client.put.assert_called_once_with('mocks', update_data, object_id=1)
assert result.name == "Updated"
mock_client.put.assert_called_once_with("mocks", update_data, object_id=1)
@pytest.mark.asyncio
async def test_delete_success(self, base_service, mock_client):
"""Test successful object deletion."""
mock_client.delete.return_value = True
result = await base_service.delete(1)
assert result is True
mock_client.delete.assert_called_once_with('mocks', object_id=1)
mock_client.delete.assert_called_once_with("mocks", object_id=1)
@pytest.mark.asyncio
async def test_get_by_field(self, base_service, mock_client):
"""Test get_by_field functionality."""
mock_data = {
'count': 1,
'mocks': [{'id': 1, 'name': 'Test', 'value': 100}]
}
mock_data = {"count": 1, "mocks": [{"id": 1, "name": "Test", "value": 100}]}
mock_client.get.return_value = mock_data
result = await base_service.get_by_field('name', 'Test')
result = await base_service.get_by_field("name", "Test")
assert len(result) == 1
mock_client.get.assert_called_once_with('mocks', params=[('name', 'Test')])
mock_client.get.assert_called_once_with("mocks", params=[("name", "Test")])
def test_extract_items_and_count_standard_format(self, base_service):
"""Test response parsing for standard format."""
data = {
'count': 3,
'mocks': [
{'id': 1, 'name': 'Test1'},
{'id': 2, 'name': 'Test2'},
{'id': 3, 'name': 'Test3'}
]
"count": 3,
"mocks": [
{"id": 1, "name": "Test1"},
{"id": 2, "name": "Test2"},
{"id": 3, "name": "Test3"},
],
}
items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 3
assert count == 3
assert items[0]['name'] == 'Test1'
assert items[0]["name"] == "Test1"
def test_extract_items_and_count_single_object(self, base_service):
"""Test response parsing for single object."""
data = {'id': 1, 'name': 'Single'}
data = {"id": 1, "name": "Single"}
items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 1
assert count == 1
assert items[0] == data
def test_extract_items_and_count_direct_list(self, base_service):
"""Test response parsing for direct list."""
data = [
{'id': 1, 'name': 'Test1'},
{'id': 2, 'name': 'Test2'}
]
data = [{"id": 1, "name": "Test1"}, {"id": 2, "name": "Test2"}]
items, count = base_service._extract_items_and_count_from_response(data)
assert len(items) == 2
assert count == 2
class TestBaseServiceExtras:
"""Additional coverage tests for BaseService edge cases."""
@pytest.mark.asyncio
async def test_base_service_additional_methods(self):
"""Test additional BaseService methods for coverage."""
from services.base_service import BaseService
from models.base import SBABaseModel
class TestModel(SBABaseModel):
name: str
value: int = 100
mock_client = AsyncMock()
service = BaseService(TestModel, 'test', client=mock_client)
service = BaseService(TestModel, "test", client=mock_client)
# Test count method
mock_client.reset_mock()
mock_client.get.return_value = {'count': 42, 'test': []}
count = await service.count(params=[('active', 'true')])
mock_client.get.return_value = {"count": 42, "test": []}
count = await service.count(params=[("active", "true")])
assert count == 42
# Test update_from_model with ID
mock_client.reset_mock()
model = TestModel(id=1, name="Updated", value=300)
mock_client.put.return_value = {"id": 1, "name": "Updated", "value": 300}
result = await service.update_from_model(model)
assert result.name == "Updated"
# Test update_from_model without ID
model_no_id = TestModel(name="Test")
with pytest.raises(ValueError, match="Cannot update TestModel without ID"):
await service.update_from_model(model_no_id)
def test_base_service_response_parsing_edge_cases(self):
"""Test edge cases in response parsing."""
from services.base_service import BaseService
from models.base import SBABaseModel
class TestModel(SBABaseModel):
name: str
service = BaseService(TestModel, 'test')
service = BaseService(TestModel, "test")
# Test with 'items' field
data = {'count': 2, 'items': [{'name': 'Item1'}, {'name': 'Item2'}]}
data = {"count": 2, "items": [{"name": "Item1"}, {"name": "Item2"}]}
items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 2
assert count == 2
# Test with 'data' field
data = {'count': 1, 'data': [{'name': 'DataItem'}]}
data = {"count": 1, "data": [{"name": "DataItem"}]}
items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 1
assert count == 1
# Test with count but no recognizable list field
data = {'count': 5, 'unknown_field': [{'name': 'Item'}]}
data = {"count": 5, "unknown_field": [{"name": "Item"}]}
items, count = service._extract_items_and_count_from_response(data)
assert len(items) == 0
assert count == 5
# Test with unexpected data type
items, count = service._extract_items_and_count_from_response("unexpected")
assert len(items) == 0
assert count == 0
assert count == 0

View File

@ -39,7 +39,7 @@ class TradeEmbedView(discord.ui.View):
"""Check if user has permission to interact with this view."""
if interaction.user.id != self.user_id:
await interaction.response.send_message(
"You don't have permission to use this trade builder.",
"You don't have permission to use this trade builder.",
ephemeral=True,
)
return False
@ -47,57 +47,48 @@ class TradeEmbedView(discord.ui.View):
async def on_timeout(self) -> None:
"""Handle view timeout."""
# Disable all buttons when timeout occurs
for item in self.children:
if isinstance(item, discord.ui.Button):
item.disabled = True
@discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red, emoji="")
@discord.ui.button(label="Remove Move", style=discord.ButtonStyle.red)
async def remove_move_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle remove move button click."""
if self.builder.is_empty:
await interaction.response.send_message(
"No moves to remove. Add some moves first!", ephemeral=True
"No moves to remove. Add some moves first!", ephemeral=True
)
return
# Create select menu for move removal
select_view = RemoveTradeMovesView(self.builder, self.user_id)
embed = await create_trade_embed(self.builder)
await interaction.response.edit_message(embed=embed, view=select_view)
@discord.ui.button(
label="Validate Trade", style=discord.ButtonStyle.secondary, emoji="🔍"
)
@discord.ui.button(label="Validate Trade", style=discord.ButtonStyle.secondary)
async def validate_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle validate trade button click."""
await interaction.response.defer(ephemeral=True)
# Perform detailed validation
validation = await self.builder.validate_trade()
# Create validation report
if validation.is_legal:
status_emoji = ""
status_text = "**Trade is LEGAL**"
color = EmbedColors.SUCCESS
else:
status_emoji = ""
status_text = "**Trade has ERRORS**"
color = EmbedColors.ERROR
embed = EmbedTemplate.create_base_embed(
title=f"{status_emoji} Trade Validation Report",
title="Trade Validation Report",
description=status_text,
color=color,
)
# Add team-by-team validation
for participant in self.builder.trade.participants:
team_validation = validation.get_participant_validation(participant.team.id)
if team_validation:
@ -111,59 +102,52 @@ class TradeEmbedView(discord.ui.View):
team_status.append(team_validation.pre_existing_transactions_note)
embed.add_field(
name=f"🏟️ {participant.team.abbrev} - {participant.team.sname}",
name=f"{participant.team.abbrev} - {participant.team.sname}",
value="\n".join(team_status),
inline=False,
)
# Add overall errors and suggestions
if validation.all_errors:
error_text = "\n".join([f" {error}" for error in validation.all_errors])
embed.add_field(name="Errors", value=error_text, inline=False)
error_text = "\n".join([f"- {error}" for error in validation.all_errors])
embed.add_field(name="Errors", value=error_text, inline=False)
if validation.all_suggestions:
suggestion_text = "\n".join(
[f"💡 {suggestion}" for suggestion in validation.all_suggestions]
[f"- {suggestion}" for suggestion in validation.all_suggestions]
)
embed.add_field(name="💡 Suggestions", value=suggestion_text, inline=False)
embed.add_field(name="Suggestions", value=suggestion_text, inline=False)
await interaction.followup.send(embed=embed, ephemeral=True)
@discord.ui.button(
label="Submit Trade", style=discord.ButtonStyle.primary, emoji="📤"
)
@discord.ui.button(label="Submit Trade", style=discord.ButtonStyle.primary)
async def submit_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
"""Handle submit trade button click."""
if self.builder.is_empty:
await interaction.response.send_message(
"Cannot submit empty trade. Add some moves first!", ephemeral=True
"Cannot submit empty trade. Add some moves first!", ephemeral=True
)
return
# Validate before submission
validation = await self.builder.validate_trade()
if not validation.is_legal:
error_msg = "**Cannot submit illegal trade:**\n"
error_msg += "\n".join([f" {error}" for error in validation.all_errors])
error_msg = "**Cannot submit illegal trade:**\n"
error_msg += "\n".join([f"- {error}" for error in validation.all_errors])
if validation.all_suggestions:
error_msg += "\n\n**Suggestions:**\n"
error_msg += "\n".join(
[f"💡 {suggestion}" for suggestion in validation.all_suggestions]
[f"- {suggestion}" for suggestion in validation.all_suggestions]
)
await interaction.response.send_message(error_msg, ephemeral=True)
return
# Show confirmation modal
modal = SubmitTradeConfirmationModal(self.builder)
await interaction.response.send_modal(modal)
@discord.ui.button(
label="Cancel Trade", style=discord.ButtonStyle.secondary, emoji=""
)
@discord.ui.button(label="Cancel Trade", style=discord.ButtonStyle.secondary)
async def cancel_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
@ -171,13 +155,12 @@ class TradeEmbedView(discord.ui.View):
self.builder.clear_trade()
embed = await create_trade_embed(self.builder)
# Disable all buttons after cancellation
for item in self.children:
if isinstance(item, discord.ui.Button):
item.disabled = True
await interaction.response.edit_message(
content="**Trade cancelled and cleared.**", embed=embed, view=self
content="**Trade cancelled and cleared.**", embed=embed, view=self
)
self.stop()
@ -190,13 +173,11 @@ class RemoveTradeMovesView(discord.ui.View):
self.builder = builder
self.user_id = user_id
# Create select menu with current moves
if not builder.is_empty:
self.add_item(RemoveTradeMovesSelect(builder))
# Add back button
back_button = discord.ui.Button(
label="Back", style=discord.ButtonStyle.secondary, emoji="⬅️"
label="Back", style=discord.ButtonStyle.secondary
)
back_button.callback = self.back_callback
self.add_item(back_button)
@ -218,25 +199,21 @@ class RemoveTradeMovesSelect(discord.ui.Select):
def __init__(self, builder: TradeBuilder):
self.builder = builder
# Create options from all moves (cross-team and supplementary)
options = []
move_count = 0
# Add cross-team moves
for move in builder.trade.cross_team_moves[
:20
]: # Limit to avoid Discord's 25 option limit
options.append(
discord.SelectOption(
label=f"{move.player.name}",
description=move.description[:100], # Discord description limit
description=move.description[:100],
value=str(move.player.id),
emoji="🔄",
)
)
move_count += 1
# Add supplementary moves if there's room
remaining_slots = 25 - move_count
for move in builder.trade.supplementary_moves[:remaining_slots]:
options.append(
@ -244,7 +221,6 @@ class RemoveTradeMovesSelect(discord.ui.Select):
label=f"{move.player.name}",
description=move.description[:100],
value=str(move.player.id),
emoji="⚙️",
)
)
@ -263,18 +239,16 @@ class RemoveTradeMovesSelect(discord.ui.Select):
if success:
await interaction.response.send_message(
f"Removed move for player ID {player_id}", ephemeral=True
f"Removed move for player ID {player_id}", ephemeral=True
)
# Update the embed
main_view = TradeEmbedView(self.builder, interaction.user.id)
embed = await create_trade_embed(self.builder)
# Edit the original message
await interaction.edit_original_response(embed=embed, view=main_view)
else:
await interaction.response.send_message(
f"Could not remove move: {error_msg}", ephemeral=True
f"Could not remove move: {error_msg}", ephemeral=True
)
@ -301,7 +275,7 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
"""Handle confirmation submission - posts acceptance view to trade channel."""
if self.confirmation.value.upper() != "CONFIRM":
await interaction.response.send_message(
"Trade not submitted. You must type 'CONFIRM' exactly.",
"Trade not submitted. You must type 'CONFIRM' exactly.",
ephemeral=True,
)
return
@ -309,18 +283,13 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
await interaction.response.defer(ephemeral=True)
try:
# Update trade status to PROPOSED
self.builder.trade.status = TradeStatus.PROPOSED
# Create acceptance embed and view
acceptance_embed = await create_trade_acceptance_embed(self.builder)
acceptance_view = TradeAcceptanceView(self.builder)
# Find the trade channel to post to
channel = self.trade_channel
if not channel:
# Try to find trade channel by name pattern
trade_channel_name = f"trade-{'-'.join(t.abbrev.lower() for t in self.builder.participating_teams)}"
for ch in interaction.guild.text_channels: # type: ignore
if (
ch.name.startswith("trade-")
@ -330,28 +299,26 @@ class SubmitTradeConfirmationModal(discord.ui.Modal):
break
if channel:
# Post acceptance request to trade channel
await channel.send(
content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.",
content="**Trade submitted for approval.** All teams must accept to complete the trade.",
embed=acceptance_embed,
view=acceptance_view,
)
await interaction.followup.send(
f"✅ Trade submitted for approval!\n\nThe acceptance request has been posted to {channel.mention}.\n"
f"Trade submitted for approval.\n\nThe acceptance request has been posted to {channel.mention}.\n"
f"All participating teams must click **Accept Trade** to finalize.",
ephemeral=True,
)
else:
# No trade channel found, post in current channel
await interaction.followup.send(
content="📋 **Trade submitted for approval!** All teams must accept to complete the trade.",
content="**Trade submitted for approval.** All teams must accept to complete the trade.",
embed=acceptance_embed,
view=acceptance_view,
)
except Exception as e:
await interaction.followup.send(
f"Error submitting trade: {str(e)}", ephemeral=True
f"Error submitting trade: {str(e)}", ephemeral=True
)
@ -375,15 +342,14 @@ class TradeAcceptanceView(discord.ui.View):
if not user_team:
await interaction.response.send_message(
"You don't own a team in the league.", ephemeral=True
"You don't own a team in the league.", ephemeral=True
)
return False
# Check if their team (or organization) is participating
participant = self.builder.trade.get_participant_by_organization(user_team)
if not participant:
await interaction.response.send_message(
"Your team is not part of this trade.", ephemeral=True
"Your team is not part of this trade.", ephemeral=True
)
return False
@ -395,9 +361,7 @@ class TradeAcceptanceView(discord.ui.View):
if isinstance(item, discord.ui.Button):
item.disabled = True
@discord.ui.button(
label="Accept Trade", style=discord.ButtonStyle.success, emoji=""
)
@discord.ui.button(label="Accept Trade", style=discord.ButtonStyle.success)
async def accept_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
@ -406,41 +370,33 @@ class TradeAcceptanceView(discord.ui.View):
if not user_team:
return
# Find the participating team (could be org affiliate)
participant = self.builder.trade.get_participant_by_organization(user_team)
if not participant:
return
team_id = participant.team.id
# Check if already accepted
if self.builder.has_team_accepted(team_id):
await interaction.response.send_message(
f"{participant.team.abbrev} has already accepted this trade.",
f"{participant.team.abbrev} has already accepted this trade.",
ephemeral=True,
)
return
# Record acceptance
all_accepted = self.builder.accept_trade(team_id)
if all_accepted:
# All teams accepted - finalize the trade
await self._finalize_trade(interaction)
else:
# Update embed to show new acceptance status
embed = await create_trade_acceptance_embed(self.builder)
await interaction.response.edit_message(embed=embed, view=self)
# Send confirmation to channel
await interaction.followup.send(
f"**{participant.team.abbrev}** has accepted the trade! "
f"**{participant.team.abbrev}** has accepted the trade. "
f"({len(self.builder.accepted_teams)}/{self.builder.team_count} teams)"
)
@discord.ui.button(
label="Reject Trade", style=discord.ButtonStyle.danger, emoji=""
)
@discord.ui.button(label="Reject Trade", style=discord.ButtonStyle.danger)
async def reject_button(
self, interaction: discord.Interaction, button: discord.ui.Button
):
@ -453,20 +409,16 @@ class TradeAcceptanceView(discord.ui.View):
if not participant:
return
# Reject the trade
self.builder.reject_trade()
# Disable buttons
self.accept_button.disabled = True
self.reject_button.disabled = True
# Update embed to show rejection
embed = await create_trade_rejection_embed(self.builder, participant.team)
await interaction.response.edit_message(embed=embed, view=self)
# Notify the channel
await interaction.followup.send(
f"**{participant.team.abbrev}** has rejected the trade.\n\n"
f"**{participant.team.abbrev}** has rejected the trade.\n\n"
f"The trade has been moved back to **DRAFT** status. "
f"Teams can continue negotiating using `/trade` commands."
)
@ -480,11 +432,9 @@ class TradeAcceptanceView(discord.ui.View):
config = get_config()
# Get next week for transactions
current = await league_service.get_current_state()
next_week = current.week + 1 if current else 1
# Create FA team for reference
fa_team = Team(
id=config.free_agent_team_id,
abbrev="FA",
@ -493,13 +443,10 @@ class TradeAcceptanceView(discord.ui.View):
season=self.builder.trade.season,
) # type: ignore
# Create transactions from all moves
transactions: List[Transaction] = []
move_id = f"Trade-{self.builder.trade_id}-{int(datetime.now(timezone.utc).timestamp())}"
# Process cross-team moves
for move in self.builder.trade.cross_team_moves:
# Get actual team affiliates for from/to based on roster type
if move.from_roster == RosterType.MAJOR_LEAGUE:
old_team = move.source_team
elif move.from_roster == RosterType.MINOR_LEAGUE:
@ -544,11 +491,10 @@ class TradeAcceptanceView(discord.ui.View):
oldteam=old_team,
newteam=new_team,
cancelled=False,
frozen=False, # Trades are NOT frozen - immediately effective
frozen=False,
)
transactions.append(transaction)
# Process supplementary moves
for move in self.builder.trade.supplementary_moves:
if move.from_roster == RosterType.MAJOR_LEAGUE:
old_team = move.source_team
@ -598,11 +544,10 @@ class TradeAcceptanceView(discord.ui.View):
oldteam=old_team,
newteam=new_team,
cancelled=False,
frozen=False, # Trades are NOT frozen - immediately effective
frozen=False,
)
transactions.append(transaction)
# POST transactions to database
if transactions:
created_transactions = (
await transaction_service.create_transaction_batch(transactions)
@ -610,7 +555,6 @@ class TradeAcceptanceView(discord.ui.View):
else:
created_transactions = []
# Post to #transaction-log channel
if created_transactions and interaction.client:
await post_trade_to_log(
bot=interaction.client,
@ -619,28 +563,23 @@ class TradeAcceptanceView(discord.ui.View):
effective_week=next_week,
)
# Update trade status
self.builder.trade.status = TradeStatus.ACCEPTED
# Disable buttons
self.accept_button.disabled = True
self.reject_button.disabled = True
# Update embed to show completion
embed = await create_trade_complete_embed(
self.builder, len(created_transactions), next_week
)
await interaction.edit_original_response(embed=embed, view=self)
# Send completion message
await interaction.followup.send(
f"🎉 **Trade Complete!**\n\n"
f"**Trade Complete!**\n\n"
f"All {self.builder.team_count} teams have accepted the trade.\n"
f"**{len(created_transactions)} transactions** have been created for **Week {next_week}**.\n\n"
f"Trade ID: `{self.builder.trade_id}`"
)
# Clear the trade builder
for team in self.builder.participating_teams:
clear_trade_builder_by_team(team.id)
@ -648,69 +587,64 @@ class TradeAcceptanceView(discord.ui.View):
except Exception as e:
await interaction.followup.send(
f"Error finalizing trade: {str(e)}", ephemeral=True
f"Error finalizing trade: {str(e)}", ephemeral=True
)
async def create_trade_acceptance_embed(builder: TradeBuilder) -> discord.Embed:
"""Create embed showing trade details and acceptance status."""
embed = EmbedTemplate.create_base_embed(
title=f"📋 Trade Pending Acceptance - {builder.trade.get_trade_summary()}",
title=f"Trade Pending Acceptance - {builder.trade.get_trade_summary()}",
description="All participating teams must accept to complete the trade.",
color=EmbedColors.WARNING,
)
# Show participating teams
team_list = [
f" {team.abbrev} - {team.sname}" for team in builder.participating_teams
f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams
]
embed.add_field(
name=f"🏟️ Participating Teams ({builder.team_count})",
name=f"Participating Teams ({builder.team_count})",
value="\n".join(team_list),
inline=False,
)
# Show cross-team moves
if builder.trade.cross_team_moves:
moves_text = ""
for move in builder.trade.cross_team_moves[:10]:
moves_text += f" {move.description}\n"
moves_text += f"- {move.description}\n"
if len(builder.trade.cross_team_moves) > 10:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 10} more"
embed.add_field(
name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})",
name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})",
value=moves_text,
inline=False,
)
# Show supplementary moves if any
if builder.trade.supplementary_moves:
supp_text = ""
for move in builder.trade.supplementary_moves[:5]:
supp_text += f" {move.description}\n"
supp_text += f"- {move.description}\n"
if len(builder.trade.supplementary_moves) > 5:
supp_text += f"... and {len(builder.trade.supplementary_moves) - 5} more"
embed.add_field(
name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})",
name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})",
value=supp_text,
inline=False,
)
# Show acceptance status
status_lines = []
for team in builder.participating_teams:
if team.id in builder.accepted_teams:
status_lines.append(f"**{team.abbrev}** - Accepted")
status_lines.append(f"**{team.abbrev}** - Accepted")
else:
status_lines.append(f"**{team.abbrev}** - Pending")
status_lines.append(f"**{team.abbrev}** - Pending")
embed.add_field(
name="📊 Acceptance Status", value="\n".join(status_lines), inline=False
name="Acceptance Status", value="\n".join(status_lines), inline=False
)
# Add footer
embed.set_footer(
text=f"Trade ID: {builder.trade_id} {len(builder.accepted_teams)}/{builder.team_count} teams accepted"
text=f"Trade ID: {builder.trade_id} | {len(builder.accepted_teams)}/{builder.team_count} teams accepted"
)
return embed
@ -721,7 +655,7 @@ async def create_trade_rejection_embed(
) -> discord.Embed:
"""Create embed showing trade was rejected."""
embed = EmbedTemplate.create_base_embed(
title=f"Trade Rejected - {builder.trade.get_trade_summary()}",
title=f"Trade Rejected - {builder.trade.get_trade_summary()}",
description=f"**{rejecting_team.abbrev}** has rejected the trade.\n\n"
f"The trade has been moved back to **DRAFT** status.\n"
f"Teams can continue negotiating using `/trade` commands.",
@ -738,29 +672,27 @@ async def create_trade_complete_embed(
) -> discord.Embed:
"""Create embed showing trade was completed."""
embed = EmbedTemplate.create_base_embed(
title=f"🎉 Trade Complete! - {builder.trade.get_trade_summary()}",
description=f"All {builder.team_count} teams have accepted the trade!\n\n"
title=f"Trade Complete - {builder.trade.get_trade_summary()}",
description=f"All {builder.team_count} teams have accepted the trade.\n\n"
f"**{transaction_count} transactions** created for **Week {effective_week}**.",
color=EmbedColors.SUCCESS,
)
# Show final acceptance status (all green)
status_lines = [
f"**{team.abbrev}** - Accepted" for team in builder.participating_teams
f"**{team.abbrev}** - Accepted" for team in builder.participating_teams
]
embed.add_field(name="📊 Final Status", value="\n".join(status_lines), inline=False)
embed.add_field(name="Final Status", value="\n".join(status_lines), inline=False)
# Show cross-team moves
if builder.trade.cross_team_moves:
moves_text = ""
for move in builder.trade.cross_team_moves[:8]:
moves_text += f" {move.description}\n"
moves_text += f"- {move.description}\n"
if len(builder.trade.cross_team_moves) > 8:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
embed.add_field(name=f"🔄 Player Exchanges", value=moves_text, inline=False)
embed.add_field(name="Player Exchanges", value=moves_text, inline=False)
embed.set_footer(
text=f"Trade ID: {builder.trade_id} Effective: Week {effective_week}"
text=f"Trade ID: {builder.trade_id} | Effective: Week {effective_week}"
)
return embed
@ -776,7 +708,6 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
Returns:
Discord embed with current trade state
"""
# Determine embed color based on trade status
if builder.is_empty:
color = EmbedColors.SECONDARY
else:
@ -784,22 +715,20 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
color = EmbedColors.SUCCESS if validation.is_legal else EmbedColors.WARNING
embed = EmbedTemplate.create_base_embed(
title=f"📋 Trade Builder - {builder.trade.get_trade_summary()}",
description=f"Build your multi-team trade",
title=f"Trade Builder - {builder.trade.get_trade_summary()}",
description="Build your multi-team trade",
color=color,
)
# Add participating teams section
team_list = [
f" {team.abbrev} - {team.sname}" for team in builder.participating_teams
f"- {team.abbrev} - {team.sname}" for team in builder.participating_teams
]
embed.add_field(
name=f"🏟️ Participating Teams ({builder.team_count})",
name=f"Participating Teams ({builder.team_count})",
value="\n".join(team_list) if team_list else "*No teams yet*",
inline=False,
)
# Add current moves section
if builder.is_empty:
embed.add_field(
name="Current Moves",
@ -807,29 +736,23 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
inline=False,
)
else:
# Show cross-team moves
if builder.trade.cross_team_moves:
moves_text = ""
for i, move in enumerate(
builder.trade.cross_team_moves[:8], 1
): # Limit display
for i, move in enumerate(builder.trade.cross_team_moves[:8], 1):
moves_text += f"{i}. {move.description}\n"
if len(builder.trade.cross_team_moves) > 8:
moves_text += f"... and {len(builder.trade.cross_team_moves) - 8} more"
embed.add_field(
name=f"🔄 Player Exchanges ({len(builder.trade.cross_team_moves)})",
name=f"Player Exchanges ({len(builder.trade.cross_team_moves)})",
value=moves_text,
inline=False,
)
# Show supplementary moves
if builder.trade.supplementary_moves:
supp_text = ""
for i, move in enumerate(
builder.trade.supplementary_moves[:5], 1
): # Limit display
for i, move in enumerate(builder.trade.supplementary_moves[:5], 1):
supp_text += f"{i}. {move.description}\n"
if len(builder.trade.supplementary_moves) > 5:
@ -838,31 +761,33 @@ async def create_trade_embed(builder: TradeBuilder) -> discord.Embed:
)
embed.add_field(
name=f"⚙️ Supplementary Moves ({len(builder.trade.supplementary_moves)})",
name=f"Supplementary Moves ({len(builder.trade.supplementary_moves)})",
value=supp_text,
inline=False,
)
# Add quick validation summary
validation = await builder.validate_trade()
if validation.is_legal:
status_text = "Trade appears legal"
status_text = "Trade appears legal"
else:
error_count = len(validation.all_errors)
status_text = f"{error_count} error{'s' if error_count != 1 else ''} found"
status_text = f"{error_count} error{'s' if error_count != 1 else ''} found\n"
status_text += "\n".join(f"- {error}" for error in validation.all_errors)
if validation.all_suggestions:
status_text += "\n" + "\n".join(
f"- {s}" for s in validation.all_suggestions
)
embed.add_field(name="🔍 Quick Status", value=status_text, inline=False)
embed.add_field(name="Quick Status", value=status_text, inline=False)
# Add instructions for adding more moves
embed.add_field(
name=" Build Your Trade",
value="• `/trade add-player` - Add player exchanges\n• `/trade supplementary` - Add internal moves\n `/trade add-team` - Add more teams",
name="Build Your Trade",
value="- `/trade add-player` - Add player exchanges\n- `/trade supplementary` - Add internal moves\n- `/trade add-team` - Add more teams",
inline=False,
)
# Add footer with trade ID and timestamp
embed.set_footer(
text=f"Trade ID: {builder.trade_id} Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}"
text=f"Trade ID: {builder.trade_id} | Created: {datetime.fromisoformat(builder.trade.created_at).strftime('%H:%M:%S')}"
)
return embed