Release: merge next-release into main #64

Merged
cal merged 11 commits from next-release into main 2026-03-17 21:43:38 +00:00
13 changed files with 920 additions and 626 deletions

View File

@ -51,12 +51,6 @@ Per season updates:
""" """
WEEK_NUMS = {
'regular': {
}
}
def model_csv_headers(this_obj, exclude=None) -> List: def model_csv_headers(this_obj, exclude=None) -> List:
data = model_to_dict(this_obj, recurse=False, exclude=exclude) data = model_to_dict(this_obj, recurse=False, exclude=exclude)
@ -458,7 +452,7 @@ class Team(BaseModel):
active_roster['WARa'] -= move.player.wara active_roster['WARa'] -= move.player.wara
try: try:
active_roster['players'].remove(move.player) active_roster['players'].remove(move.player)
except: except Exception:
print(f'I could not drop {move.player.name}') print(f'I could not drop {move.player.name}')
for move in all_adds: for move in all_adds:
@ -519,7 +513,7 @@ class Team(BaseModel):
# print(f'SIL dropping {move.player.name} id ({move.player.get_id()}) for {move.player.wara} WARa') # print(f'SIL dropping {move.player.name} id ({move.player.get_id()}) for {move.player.wara} WARa')
try: try:
short_roster['players'].remove(move.player) short_roster['players'].remove(move.player)
except: except Exception:
print(f'I could not drop {move.player.name}') print(f'I could not drop {move.player.name}')
for move in all_adds: for move in all_adds:
@ -580,7 +574,7 @@ class Team(BaseModel):
# print(f'LIL dropping {move.player.name} id ({move.player.get_id()}) for {move.player.wara} WARa') # print(f'LIL dropping {move.player.name} id ({move.player.get_id()}) for {move.player.wara} WARa')
try: try:
long_roster['players'].remove(move.player) long_roster['players'].remove(move.player)
except: except Exception:
print(f'I could not drop {move.player.name}') print(f'I could not drop {move.player.name}')
for move in all_adds: for move in all_adds:
@ -2351,7 +2345,7 @@ class CustomCommand(BaseModel):
try: try:
import json import json
return json.loads(self.tags) return json.loads(self.tags)
except: except Exception:
return [] return []
def set_tags_list(self, tags_list): def set_tags_list(self, tags_list):

View File

@ -11,8 +11,8 @@ from fastapi import HTTPException, Response
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from redis import Redis from redis import Redis
date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' date = f"{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}"
logger = logging.getLogger('discord_app') logger = logging.getLogger("discord_app")
# date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
# log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN' # log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN'
@ -23,10 +23,10 @@ logger = logging.getLogger('discord_app')
# ) # )
# Redis configuration # Redis configuration
REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost') REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379')) REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
REDIS_DB = int(os.environ.get('REDIS_DB', '0')) REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true' CACHE_ENABLED = os.environ.get("CACHE_ENABLED", "true").lower() == "true"
# Initialize Redis client with connection error handling # Initialize Redis client with connection error handling
if not CACHE_ENABLED: if not CACHE_ENABLED:
@ -40,7 +40,7 @@ else:
db=REDIS_DB, db=REDIS_DB,
decode_responses=True, decode_responses=True,
socket_connect_timeout=5, socket_connect_timeout=5,
socket_timeout=5 socket_timeout=5,
) )
# Test connection # Test connection
redis_client.ping() redis_client.ping()
@ -50,12 +50,16 @@ else:
redis_client = None redis_client = None
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper() priv_help = (
PRIVATE_IN_SCHEMA = True if priv_help == 'TRUE' else False False
if not os.environ.get("PRIVATE_IN_SCHEMA")
else os.environ.get("PRIVATE_IN_SCHEMA").upper()
)
PRIVATE_IN_SCHEMA = True if priv_help == "TRUE" else False
def valid_token(token): def valid_token(token):
return token == os.environ.get('API_TOKEN') return token == os.environ.get("API_TOKEN")
def update_season_batting_stats(player_ids, season, db_connection): def update_season_batting_stats(player_ids, season, db_connection):
@ -63,17 +67,19 @@ def update_season_batting_stats(player_ids, season, db_connection):
Update season batting stats for specific players in a given season. Update season batting stats for specific players in a given season.
Recalculates stats from stratplay data and upserts into seasonbattingstats table. Recalculates stats from stratplay data and upserts into seasonbattingstats table.
""" """
if not player_ids: if not player_ids:
logger.warning("update_season_batting_stats called with empty player_ids list") logger.warning("update_season_batting_stats called with empty player_ids list")
return return
# Convert single player_id to list for consistency # Convert single player_id to list for consistency
if isinstance(player_ids, int): if isinstance(player_ids, int):
player_ids = [player_ids] player_ids = [player_ids]
logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}") logger.info(
f"Updating season batting stats for {len(player_ids)} players in season {season}"
)
try: try:
# SQL query to recalculate and upsert batting stats # SQL query to recalculate and upsert batting stats
query = """ query = """
@ -217,12 +223,14 @@ def update_season_batting_stats(player_ids, season, db_connection):
sb = EXCLUDED.sb, sb = EXCLUDED.sb,
cs = EXCLUDED.cs; cs = EXCLUDED.cs;
""" """
# Execute the query with parameters using the passed database connection # Execute the query with parameters using the passed database connection
db_connection.execute_sql(query, [season, player_ids, season, player_ids]) db_connection.execute_sql(query, [season, player_ids, season, player_ids])
logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}") logger.info(
f"Successfully updated season batting stats for {len(player_ids)} players in season {season}"
)
except Exception as e: except Exception as e:
logger.error(f"Error updating season batting stats: {e}") logger.error(f"Error updating season batting stats: {e}")
raise raise
@ -233,17 +241,19 @@ def update_season_pitching_stats(player_ids, season, db_connection):
Update season pitching stats for specific players in a given season. Update season pitching stats for specific players in a given season.
Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table. Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table.
""" """
if not player_ids: if not player_ids:
logger.warning("update_season_pitching_stats called with empty player_ids list") logger.warning("update_season_pitching_stats called with empty player_ids list")
return return
# Convert single player_id to list for consistency # Convert single player_id to list for consistency
if isinstance(player_ids, int): if isinstance(player_ids, int):
player_ids = [player_ids] player_ids = [player_ids]
logger.info(f"Updating season pitching stats for {len(player_ids)} players in season {season}") logger.info(
f"Updating season pitching stats for {len(player_ids)} players in season {season}"
)
try: try:
# SQL query to recalculate and upsert pitching stats # SQL query to recalculate and upsert pitching stats
query = """ query = """
@ -357,8 +367,28 @@ def update_season_pitching_stats(player_ids, season, db_connection):
WHEN SUM(sp.bb) > 0 WHEN SUM(sp.bb) > 0
THEN ROUND(SUM(sp.so)::DECIMAL / SUM(sp.bb), 2) THEN ROUND(SUM(sp.so)::DECIMAL / SUM(sp.bb), 2)
ELSE 0.0 ELSE 0.0
END AS kperbb END AS kperbb,
-- Runners left on base when pitcher recorded the 3rd out
SUM(CASE WHEN sp.on_first_final IS NOT NULL AND sp.on_first_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_second_final IS NOT NULL AND sp.on_second_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_third_final IS NOT NULL AND sp.on_third_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) AS lob_2outs,
-- RBI allowed (excluding HR) per runner opportunity
CASE
WHEN (SUM(CASE WHEN sp.on_first IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_second IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_third IS NOT NULL THEN 1 ELSE 0 END)) > 0
THEN ROUND(
(SUM(sp.rbi) - SUM(sp.homerun))::DECIMAL /
(SUM(CASE WHEN sp.on_first IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_second IS NOT NULL THEN 1 ELSE 0 END) +
SUM(CASE WHEN sp.on_third IS NOT NULL THEN 1 ELSE 0 END)),
3
)
ELSE 0.000
END AS rbipercent
FROM stratplay sp FROM stratplay sp
JOIN stratgame sg ON sg.id = sp.game_id JOIN stratgame sg ON sg.id = sp.game_id
JOIN player p ON p.id = sp.pitcher_id JOIN player p ON p.id = sp.pitcher_id
@ -402,7 +432,7 @@ def update_season_pitching_stats(player_ids, season, db_connection):
ps.bphr, ps.bpfo, ps.bp1b, ps.bplo, ps.wp, ps.balk, ps.bphr, ps.bpfo, ps.bp1b, ps.bplo, ps.wp, ps.balk,
ps.wpa * -1, ps.era, ps.whip, ps.avg, ps.obp, ps.slg, ps.ops, ps.woba, ps.wpa * -1, ps.era, ps.whip, ps.avg, ps.obp, ps.slg, ps.ops, ps.woba,
ps.hper9, ps.kper9, ps.bbper9, ps.kperbb, ps.hper9, ps.kper9, ps.bbper9, ps.kperbb,
0.0, 0.0, COALESCE(ps.re24 * -1, 0.0) ps.lob_2outs, ps.rbipercent, COALESCE(ps.re24 * -1, 0.0)
FROM pitching_stats ps FROM pitching_stats ps
LEFT JOIN decision_stats ds ON ps.player_id = ds.player_id AND ps.season = ds.season LEFT JOIN decision_stats ds ON ps.player_id = ds.player_id AND ps.season = ds.season
ON CONFLICT (player_id, season) ON CONFLICT (player_id, season)
@ -460,12 +490,14 @@ def update_season_pitching_stats(player_ids, season, db_connection):
rbipercent = EXCLUDED.rbipercent, rbipercent = EXCLUDED.rbipercent,
re24 = EXCLUDED.re24; re24 = EXCLUDED.re24;
""" """
# Execute the query with parameters using the passed database connection # Execute the query with parameters using the passed database connection
db_connection.execute_sql(query, [season, player_ids, season, player_ids]) db_connection.execute_sql(query, [season, player_ids, season, player_ids])
logger.info(f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}") logger.info(
f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}"
)
except Exception as e: except Exception as e:
logger.error(f"Error updating season pitching stats: {e}") logger.error(f"Error updating season pitching stats: {e}")
raise raise
@ -474,26 +506,24 @@ def update_season_pitching_stats(player_ids, season, db_connection):
def send_webhook_message(message: str) -> bool: def send_webhook_message(message: str) -> bool:
""" """
Send a message to Discord via webhook. Send a message to Discord via webhook.
Args: Args:
message: The message content to send message: The message content to send
Returns: Returns:
bool: True if successful, False otherwise bool: True if successful, False otherwise
""" """
webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW" webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW"
try: try:
payload = { payload = {"content": message}
"content": message
}
response = requests.post(webhook_url, json=payload, timeout=10) response = requests.post(webhook_url, json=payload, timeout=10)
response.raise_for_status() response.raise_for_status()
logger.info(f"Webhook message sent successfully: {message[:100]}...") logger.info(f"Webhook message sent successfully: {message[:100]}...")
return True return True
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.error(f"Failed to send webhook message: {e}") logger.error(f"Failed to send webhook message: {e}")
return False return False
@ -502,99 +532,106 @@ def send_webhook_message(message: str) -> bool:
return False return False
def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True): def cache_result(
ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True
):
""" """
Decorator to cache function results in Redis with parameter normalization. Decorator to cache function results in Redis with parameter normalization.
Args: Args:
ttl: Time to live in seconds (default: 5 minutes) ttl: Time to live in seconds (default: 5 minutes)
key_prefix: Prefix for cache keys (default: "api") key_prefix: Prefix for cache keys (default: "api")
normalize_params: Remove None/empty values to reduce cache variations (default: True) normalize_params: Remove None/empty values to reduce cache variations (default: True)
Usage: Usage:
@cache_result(ttl=600, key_prefix="stats") @cache_result(ttl=600, key_prefix="stats")
async def get_player_stats(player_id: int, season: Optional[int] = None): async def get_player_stats(player_id: int, season: Optional[int] = None):
# expensive operation # expensive operation
return stats return stats
# These will use the same cache entry when normalize_params=True: # These will use the same cache entry when normalize_params=True:
# get_player_stats(123, None) and get_player_stats(123) # get_player_stats(123, None) and get_player_stats(123)
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
# Skip caching if Redis is not available # Skip caching if Redis is not available
if redis_client is None: if redis_client is None:
return await func(*args, **kwargs) return await func(*args, **kwargs)
try: try:
# Normalize parameters to reduce cache variations # Normalize parameters to reduce cache variations
normalized_kwargs = kwargs.copy() normalized_kwargs = kwargs.copy()
if normalize_params: if normalize_params:
# Remove None values and empty collections # Remove None values and empty collections
normalized_kwargs = { normalized_kwargs = {
k: v for k, v in kwargs.items() k: v
for k, v in kwargs.items()
if v is not None and v != [] and v != "" and v != {} if v is not None and v != [] and v != "" and v != {}
} }
# Generate more readable cache key # Generate more readable cache key
args_str = "_".join(str(arg) for arg in args if arg is not None) args_str = "_".join(str(arg) for arg in args if arg is not None)
kwargs_str = "_".join([ kwargs_str = "_".join(
f"{k}={v}" for k, v in sorted(normalized_kwargs.items()) [f"{k}={v}" for k, v in sorted(normalized_kwargs.items())]
]) )
# Combine args and kwargs for cache key # Combine args and kwargs for cache key
key_parts = [key_prefix, func.__name__] key_parts = [key_prefix, func.__name__]
if args_str: if args_str:
key_parts.append(args_str) key_parts.append(args_str)
if kwargs_str: if kwargs_str:
key_parts.append(kwargs_str) key_parts.append(kwargs_str)
cache_key = ":".join(key_parts) cache_key = ":".join(key_parts)
# Truncate very long cache keys to prevent Redis key size limits # Truncate very long cache keys to prevent Redis key size limits
if len(cache_key) > 200: if len(cache_key) > 200:
cache_key = f"{key_prefix}:{func.__name__}:{hash(cache_key)}" cache_key = f"{key_prefix}:{func.__name__}:{hash(cache_key)}"
# Try to get from cache # Try to get from cache
cached_result = redis_client.get(cache_key) cached_result = redis_client.get(cache_key)
if cached_result is not None: if cached_result is not None:
logger.debug(f"Cache hit for key: {cache_key}") logger.debug(f"Cache hit for key: {cache_key}")
return json.loads(cached_result) return json.loads(cached_result)
# Cache miss - execute function # Cache miss - execute function
logger.debug(f"Cache miss for key: {cache_key}") logger.debug(f"Cache miss for key: {cache_key}")
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
# Skip caching for Response objects (like CSV downloads) as they can't be properly serialized # Skip caching for Response objects (like CSV downloads) as they can't be properly serialized
if not isinstance(result, Response): if not isinstance(result, Response):
# Store in cache with TTL # Store in cache with TTL
redis_client.setex( redis_client.setex(
cache_key, cache_key,
ttl, ttl,
json.dumps(result, default=str, ensure_ascii=False) json.dumps(result, default=str, ensure_ascii=False),
) )
else: else:
logger.debug(f"Skipping cache for Response object from {func.__name__}") logger.debug(
f"Skipping cache for Response object from {func.__name__}"
)
return result return result
except Exception as e: except Exception as e:
# If caching fails, log error and continue without caching # If caching fails, log error and continue without caching
logger.error(f"Cache error for {func.__name__}: {e}") logger.error(f"Cache error for {func.__name__}: {e}")
return await func(*args, **kwargs) return await func(*args, **kwargs)
return wrapper return wrapper
return decorator return decorator
def invalidate_cache(pattern: str = "*"): def invalidate_cache(pattern: str = "*"):
""" """
Invalidate cache entries matching a pattern. Invalidate cache entries matching a pattern.
Args: Args:
pattern: Redis pattern to match keys (default: "*" for all) pattern: Redis pattern to match keys (default: "*" for all)
Usage: Usage:
invalidate_cache("stats:*") # Clear all stats cache invalidate_cache("stats:*") # Clear all stats cache
invalidate_cache("api:get_player_*") # Clear specific player cache invalidate_cache("api:get_player_*") # Clear specific player cache
@ -602,12 +639,14 @@ def invalidate_cache(pattern: str = "*"):
if redis_client is None: if redis_client is None:
logger.warning("Cannot invalidate cache: Redis not available") logger.warning("Cannot invalidate cache: Redis not available")
return 0 return 0
try: try:
keys = redis_client.keys(pattern) keys = redis_client.keys(pattern)
if keys: if keys:
deleted = redis_client.delete(*keys) deleted = redis_client.delete(*keys)
logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}") logger.info(
f"Invalidated {deleted} cache entries matching pattern: {pattern}"
)
return deleted return deleted
else: else:
logger.debug(f"No cache entries found matching pattern: {pattern}") logger.debug(f"No cache entries found matching pattern: {pattern}")
@ -620,13 +659,13 @@ def invalidate_cache(pattern: str = "*"):
def get_cache_stats() -> dict: def get_cache_stats() -> dict:
""" """
Get Redis cache statistics. Get Redis cache statistics.
Returns: Returns:
dict: Cache statistics including memory usage, key count, etc. dict: Cache statistics including memory usage, key count, etc.
""" """
if redis_client is None: if redis_client is None:
return {"status": "unavailable", "message": "Redis not connected"} return {"status": "unavailable", "message": "Redis not connected"}
try: try:
info = redis_client.info() info = redis_client.info()
return { return {
@ -634,7 +673,7 @@ def get_cache_stats() -> dict:
"memory_used": info.get("used_memory_human", "unknown"), "memory_used": info.get("used_memory_human", "unknown"),
"total_keys": redis_client.dbsize(), "total_keys": redis_client.dbsize(),
"connected_clients": info.get("connected_clients", 0), "connected_clients": info.get("connected_clients", 0),
"uptime_seconds": info.get("uptime_in_seconds", 0) "uptime_seconds": info.get("uptime_in_seconds", 0),
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting cache stats: {e}") logger.error(f"Error getting cache stats: {e}")
@ -642,34 +681,35 @@ def get_cache_stats() -> dict:
def add_cache_headers( def add_cache_headers(
max_age: int = 300, max_age: int = 300,
cache_type: str = "public", cache_type: str = "public",
vary: Optional[str] = None, vary: Optional[str] = None,
etag: bool = False etag: bool = False,
): ):
""" """
Decorator to add HTTP cache headers to FastAPI responses. Decorator to add HTTP cache headers to FastAPI responses.
Args: Args:
max_age: Cache duration in seconds (default: 5 minutes) max_age: Cache duration in seconds (default: 5 minutes)
cache_type: "public", "private", or "no-cache" (default: "public") cache_type: "public", "private", or "no-cache" (default: "public")
vary: Vary header value (e.g., "Accept-Encoding, Authorization") vary: Vary header value (e.g., "Accept-Encoding, Authorization")
etag: Whether to generate ETag based on response content etag: Whether to generate ETag based on response content
Usage: Usage:
@add_cache_headers(max_age=1800, cache_type="public") @add_cache_headers(max_age=1800, cache_type="public")
async def get_static_data(): async def get_static_data():
return {"data": "static content"} return {"data": "static content"}
@add_cache_headers(max_age=60, cache_type="private", vary="Authorization") @add_cache_headers(max_age=60, cache_type="private", vary="Authorization")
async def get_user_data(): async def get_user_data():
return {"data": "user specific"} return {"data": "user specific"}
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
# Handle different response types # Handle different response types
if isinstance(result, Response): if isinstance(result, Response):
response = result response = result
@ -677,38 +717,41 @@ def add_cache_headers(
# Convert to Response with JSON content # Convert to Response with JSON content
response = Response( response = Response(
content=json.dumps(result, default=str, ensure_ascii=False), content=json.dumps(result, default=str, ensure_ascii=False),
media_type="application/json" media_type="application/json",
) )
else: else:
# Handle other response types # Handle other response types
response = Response(content=str(result)) response = Response(content=str(result))
# Build Cache-Control header # Build Cache-Control header
cache_control_parts = [cache_type] cache_control_parts = [cache_type]
if cache_type != "no-cache" and max_age > 0: if cache_type != "no-cache" and max_age > 0:
cache_control_parts.append(f"max-age={max_age}") cache_control_parts.append(f"max-age={max_age}")
response.headers["Cache-Control"] = ", ".join(cache_control_parts) response.headers["Cache-Control"] = ", ".join(cache_control_parts)
# Add Vary header if specified # Add Vary header if specified
if vary: if vary:
response.headers["Vary"] = vary response.headers["Vary"] = vary
# Add ETag if requested # Add ETag if requested
if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))): if etag and (
hasattr(result, "__dict__") or isinstance(result, (dict, list))
):
content_hash = hashlib.md5( content_hash = hashlib.md5(
json.dumps(result, default=str, sort_keys=True).encode() json.dumps(result, default=str, sort_keys=True).encode()
).hexdigest() ).hexdigest()
response.headers["ETag"] = f'"{content_hash}"' response.headers["ETag"] = f'"{content_hash}"'
# Add Last-Modified header with current time for dynamic content # Add Last-Modified header with current time for dynamic content
response.headers["Last-Modified"] = datetime.datetime.now(datetime.timezone.utc).strftime( response.headers["Last-Modified"] = datetime.datetime.now(
"%a, %d %b %Y %H:%M:%S GMT" datetime.timezone.utc
) ).strftime("%a, %d %b %Y %H:%M:%S GMT")
return response return response
return wrapper return wrapper
return decorator return decorator
@ -718,52 +761,59 @@ def handle_db_errors(func):
Ensures proper cleanup of database connections and provides consistent error handling. Ensures proper cleanup of database connections and provides consistent error handling.
Includes comprehensive logging with function context, timing, and stack traces. Includes comprehensive logging with function context, timing, and stack traces.
""" """
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
import time import time
import traceback import traceback
from .db_engine import db # Import here to avoid circular imports from .db_engine import db # Import here to avoid circular imports
start_time = time.time() start_time = time.time()
func_name = f"{func.__module__}.{func.__name__}" func_name = f"{func.__module__}.{func.__name__}"
# Sanitize arguments for logging (exclude sensitive data) # Sanitize arguments for logging (exclude sensitive data)
safe_args = [] safe_args = []
safe_kwargs = {} safe_kwargs = {}
try: try:
# Log sanitized arguments (avoid logging tokens, passwords, etc.) # Log sanitized arguments (avoid logging tokens, passwords, etc.)
for arg in args: for arg in args:
if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object if hasattr(arg, "__dict__") and hasattr(
safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})") arg, "url"
): # FastAPI Request object
safe_args.append(
f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})"
)
else: else:
safe_args.append(str(arg)[:100]) # Truncate long values safe_args.append(str(arg)[:100]) # Truncate long values
for key, value in kwargs.items(): for key, value in kwargs.items():
if key.lower() in ['token', 'password', 'secret', 'key']: if key.lower() in ["token", "password", "secret", "key"]:
safe_kwargs[key] = '[REDACTED]' safe_kwargs[key] = "[REDACTED]"
else: else:
safe_kwargs[key] = str(value)[:100] # Truncate long values safe_kwargs[key] = str(value)[:100] # Truncate long values
logger.info(f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}") logger.info(
f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}"
)
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s") logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s")
return result return result
except Exception as e: except Exception as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
error_trace = traceback.format_exc() error_trace = traceback.format_exc()
logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s") logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s")
logger.error(f"Function args: {safe_args}") logger.error(f"Function args: {safe_args}")
logger.error(f"Function kwargs: {safe_kwargs}") logger.error(f"Function kwargs: {safe_kwargs}")
logger.error(f"Exception: {str(e)}") logger.error(f"Exception: {str(e)}")
logger.error(f"Full traceback:\n{error_trace}") logger.error(f"Full traceback:\n{error_trace}")
try: try:
logger.info(f"Attempting database rollback for {func_name}") logger.info(f"Attempting database rollback for {func_name}")
db.rollback() db.rollback()
@ -775,8 +825,12 @@ def handle_db_errors(func):
db.close() db.close()
logger.info(f"Database connection closed for {func_name}") logger.info(f"Database connection closed for {func_name}")
except Exception as close_error: except Exception as close_error:
logger.error(f"Error closing database connection in {func_name}: {close_error}") logger.error(
f"Error closing database connection in {func_name}: {close_error}"
raise HTTPException(status_code=500, detail=f'Database error in {func_name}: {str(e)}') )
raise HTTPException(
status_code=500, detail=f"Database error in {func_name}: {str(e)}"
)
return wrapper return wrapper

View File

@ -10,38 +10,64 @@ from fastapi.openapi.utils import get_openapi
# from fastapi.openapi.docs import get_swagger_ui_html # from fastapi.openapi.docs import get_swagger_ui_html
# from fastapi.openapi.utils import get_openapi # from fastapi.openapi.utils import get_openapi
from .routers_v3 import current, players, results, schedules, standings, teams, transactions, battingstats, pitchingstats, fieldingstats, draftpicks, draftlist, managers, awards, draftdata, keepers, stratgame, stratplay, injuries, decisions, divisions, sbaplayers, custom_commands, help_commands, views from .routers_v3 import (
current,
players,
results,
schedules,
standings,
teams,
transactions,
battingstats,
pitchingstats,
fieldingstats,
draftpicks,
draftlist,
managers,
awards,
draftdata,
keepers,
stratgame,
stratplay,
injuries,
decisions,
divisions,
sbaplayers,
custom_commands,
help_commands,
views,
)
# date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
log_level = logging.INFO if os.environ.get('LOG_LEVEL') == 'INFO' else logging.WARNING log_level = logging.INFO if os.environ.get("LOG_LEVEL") == "INFO" else logging.WARNING
# logging.basicConfig( # logging.basicConfig(
# filename=f'logs/database/{date}.log', # filename=f'logs/database/{date}.log',
# format='%(asctime)s - sba-database - %(levelname)s - %(message)s', # format='%(asctime)s - sba-database - %(levelname)s - %(message)s',
# level=log_level # level=log_level
# ) # )
logger = logging.getLogger('discord_app') logger = logging.getLogger("discord_app")
logger.setLevel(log_level) logger.setLevel(log_level)
handler = RotatingFileHandler( handler = RotatingFileHandler(
filename='./logs/sba-database.log', filename="./logs/sba-database.log",
# encoding='utf-8', # encoding='utf-8',
maxBytes=8 * 1024 * 1024, # 8 MiB maxBytes=8 * 1024 * 1024, # 8 MiB
backupCount=5, # Rotate through 5 files backupCount=5, # Rotate through 5 files
) )
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
app = FastAPI( app = FastAPI(
# root_path='/api', # root_path='/api',
responses={404: {'description': 'Not found'}}, responses={404: {"description": "Not found"}},
docs_url='/api/docs', docs_url="/api/docs",
redoc_url='/api/redoc' redoc_url="/api/redoc",
) )
logger.info(f'Starting up now...') logger.info(f"Starting up now...")
app.include_router(current.router) app.include_router(current.router)
@ -70,18 +96,20 @@ app.include_router(custom_commands.router)
app.include_router(help_commands.router) app.include_router(help_commands.router)
app.include_router(views.router) app.include_router(views.router)
logger.info(f'Loaded all routers.') logger.info(f"Loaded all routers.")
@app.get("/api/docs", include_in_schema=False) @app.get("/api/docs", include_in_schema=False)
async def get_docs(req: Request): async def get_docs(req: Request):
print(req.scope) logger.debug(req.scope)
return get_swagger_ui_html(openapi_url=req.scope.get('root_path')+'/openapi.json', title='Swagger') return get_swagger_ui_html(
openapi_url=req.scope.get("root_path") + "/openapi.json", title="Swagger"
)
@app.get("/api/openapi.json", include_in_schema=False) @app.get("/api/openapi.json", include_in_schema=False)
async def openapi(): async def openapi():
return get_openapi(title='SBa API Docs', version=f'0.1.1', routes=app.routes) return get_openapi(title="SBa API Docs", version=f"0.1.1", routes=app.routes)
# @app.get("/api") # @app.get("/api")

View File

@ -381,14 +381,25 @@ async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme)
all_stats = [] all_stats = []
all_team_ids = list(set(x.team_id for x in s_list.stats))
all_player_ids = list(set(x.player_id for x in s_list.stats))
found_team_ids = (
set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
if all_team_ids
else set()
)
found_player_ids = (
set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids))
if all_player_ids
else set()
)
for x in s_list.stats: for x in s_list.stats:
team = Team.get_or_none(Team.id == x.team_id) if x.team_id not in found_team_ids:
this_player = Player.get_or_none(Player.id == x.player_id)
if team is None:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Team ID {x.team_id} not found" status_code=404, detail=f"Team ID {x.team_id} not found"
) )
if this_player is None: if x.player_id not in found_player_ids:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Player ID {x.player_id} not found" status_code=404, detail=f"Player ID {x.player_id} not found"
) )

View File

@ -296,9 +296,8 @@ async def get_custom_commands(
if command_dict.get("tags"): if command_dict.get("tags"):
try: try:
command_dict["tags"] = json.loads(command_dict["tags"]) command_dict["tags"] = json.loads(command_dict["tags"])
except: except Exception:
command_dict["tags"] = [] command_dict["tags"] = []
# Get full creator information # Get full creator information
creator_id = command_dict["creator_id"] creator_id = command_dict["creator_id"]
creator_cursor = db.execute_sql( creator_cursor = db.execute_sql(
@ -406,7 +405,7 @@ async def create_custom_command_endpoint(
if command_dict.get("tags"): if command_dict.get("tags"):
try: try:
command_dict["tags"] = json.loads(command_dict["tags"]) command_dict["tags"] = json.loads(command_dict["tags"])
except: except Exception:
command_dict["tags"] = [] command_dict["tags"] = []
creator_created_at = command_dict.pop("creator_created_at") creator_created_at = command_dict.pop("creator_created_at")
@ -467,7 +466,7 @@ async def update_custom_command_endpoint(
if command_dict.get("tags"): if command_dict.get("tags"):
try: try:
command_dict["tags"] = json.loads(command_dict["tags"]) command_dict["tags"] = json.loads(command_dict["tags"])
except: except Exception:
command_dict["tags"] = [] command_dict["tags"] = []
creator_created_at = command_dict.pop("creator_created_at") creator_created_at = command_dict.pop("creator_created_at")
@ -552,7 +551,7 @@ async def patch_custom_command(
if command_dict.get("tags"): if command_dict.get("tags"):
try: try:
command_dict["tags"] = json.loads(command_dict["tags"]) command_dict["tags"] = json.loads(command_dict["tags"])
except: except Exception:
command_dict["tags"] = [] command_dict["tags"] = []
creator_created_at = command_dict.pop("creator_created_at") creator_created_at = command_dict.pop("creator_created_at")
@ -781,7 +780,7 @@ async def get_custom_command_stats():
if command_dict.get("tags"): if command_dict.get("tags"):
try: try:
command_dict["tags"] = json.loads(command_dict["tags"]) command_dict["tags"] = json.loads(command_dict["tags"])
except: except Exception:
command_dict["tags"] = [] command_dict["tags"] = []
command_dict["creator"] = { command_dict["creator"] = {
"discord_id": command_dict.pop("creator_discord_id"), "discord_id": command_dict.pop("creator_discord_id"),
@ -881,7 +880,7 @@ async def get_custom_command_by_name_endpoint(command_name: str):
if command_dict.get("tags"): if command_dict.get("tags"):
try: try:
command_dict["tags"] = json.loads(command_dict["tags"]) command_dict["tags"] = json.loads(command_dict["tags"])
except: except Exception:
command_dict["tags"] = [] command_dict["tags"] = []
# Add creator info - get full creator record # Add creator info - get full creator record
@ -966,7 +965,7 @@ async def execute_custom_command(
if updated_dict.get("tags"): if updated_dict.get("tags"):
try: try:
updated_dict["tags"] = json.loads(updated_dict["tags"]) updated_dict["tags"] = json.loads(updated_dict["tags"])
except: except Exception:
updated_dict["tags"] = [] updated_dict["tags"] = []
# Build creator object from the fields returned by get_custom_command_by_id # Build creator object from the fields returned by get_custom_command_by_id
@ -1053,7 +1052,7 @@ async def get_custom_command(command_id: int):
if command_dict.get("tags"): if command_dict.get("tags"):
try: try:
command_dict["tags"] = json.loads(command_dict["tags"]) command_dict["tags"] = json.loads(command_dict["tags"])
except: except Exception:
command_dict["tags"] = [] command_dict["tags"] = []
creator_created_at = command_dict.pop("creator_created_at") creator_created_at = command_dict.pop("creator_created_at")

View File

@ -1,6 +1,3 @@
import datetime
import os
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional, Literal from typing import List, Optional, Literal
import logging import logging

View File

@ -159,12 +159,23 @@ async def post_results(result_list: ResultList, token: str = Depends(oauth2_sche
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
new_results = [] new_results = []
all_team_ids = list(
set(x.awayteam_id for x in result_list.results)
| set(x.hometeam_id for x in result_list.results)
)
found_team_ids = (
set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
if all_team_ids
else set()
)
for x in result_list.results: for x in result_list.results:
if Team.get_or_none(Team.id == x.awayteam_id) is None: if x.awayteam_id not in found_team_ids:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Team ID {x.awayteam_id} not found" status_code=404, detail=f"Team ID {x.awayteam_id} not found"
) )
if Team.get_or_none(Team.id == x.hometeam_id) is None: if x.hometeam_id not in found_team_ids:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Team ID {x.hometeam_id} not found" status_code=404, detail=f"Team ID {x.hometeam_id} not found"
) )

View File

@ -144,12 +144,23 @@ async def post_schedules(sched_list: ScheduleList, token: str = Depends(oauth2_s
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
new_sched = [] new_sched = []
all_team_ids = list(
set(x.awayteam_id for x in sched_list.schedules)
| set(x.hometeam_id for x in sched_list.schedules)
)
found_team_ids = (
set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
if all_team_ids
else set()
)
for x in sched_list.schedules: for x in sched_list.schedules:
if Team.get_or_none(Team.id == x.awayteam_id) is None: if x.awayteam_id not in found_team_ids:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Team ID {x.awayteam_id} not found" status_code=404, detail=f"Team ID {x.awayteam_id} not found"
) )
if Team.get_or_none(Team.id == x.hometeam_id) is None: if x.hometeam_id not in found_team_ids:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Team ID {x.hometeam_id} not found" status_code=404, detail=f"Team ID {x.hometeam_id} not found"
) )

View File

@ -1,24 +1,29 @@
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional from typing import List, Optional
import logging import logging
import pydantic
from ..db_engine import db, Standings, Team, Division, model_to_dict, chunked, fn from ..db_engine import db, Standings, Team, Division, model_to_dict, chunked, fn
from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors from ..dependencies import (
oauth2_scheme,
logger = logging.getLogger('discord_app') valid_token,
PRIVATE_IN_SCHEMA,
router = APIRouter( handle_db_errors,
prefix='/api/v3/standings',
tags=['standings']
) )
logger = logging.getLogger("discord_app")
@router.get('') router = APIRouter(prefix="/api/v3/standings", tags=["standings"])
@router.get("")
@handle_db_errors @handle_db_errors
async def get_standings( async def get_standings(
season: int, team_id: list = Query(default=None), league_abbrev: Optional[str] = None, season: int,
division_abbrev: Optional[str] = None, short_output: Optional[bool] = False): team_id: list = Query(default=None),
league_abbrev: Optional[str] = None,
division_abbrev: Optional[str] = None,
short_output: Optional[bool] = False,
):
standings = Standings.select_season(season) standings = Standings.select_season(season)
# if standings.count() == 0: # if standings.count() == 0:
@ -30,55 +35,66 @@ async def get_standings(
standings = standings.where(Standings.team << t_query) standings = standings.where(Standings.team << t_query)
if league_abbrev is not None: if league_abbrev is not None:
l_query = Division.select().where(fn.Lower(Division.league_abbrev) == league_abbrev.lower()) l_query = Division.select().where(
fn.Lower(Division.league_abbrev) == league_abbrev.lower()
)
standings = standings.where(Standings.team.division << l_query) standings = standings.where(Standings.team.division << l_query)
if division_abbrev is not None: if division_abbrev is not None:
d_query = Division.select().where(fn.Lower(Division.division_abbrev) == division_abbrev.lower()) d_query = Division.select().where(
fn.Lower(Division.division_abbrev) == division_abbrev.lower()
)
standings = standings.where(Standings.team.division << d_query) standings = standings.where(Standings.team.division << d_query)
def win_pct(this_team_stan): def win_pct(this_team_stan):
if this_team_stan.wins + this_team_stan.losses == 0: if this_team_stan.wins + this_team_stan.losses == 0:
return 0 return 0
else: else:
return (this_team_stan.wins / (this_team_stan.wins + this_team_stan.losses)) + \ return (
(this_team_stan.run_diff * .000001) this_team_stan.wins / (this_team_stan.wins + this_team_stan.losses)
) + (this_team_stan.run_diff * 0.000001)
div_teams = [x for x in standings] div_teams = [x for x in standings]
div_teams.sort(key=lambda team: win_pct(team), reverse=True) div_teams.sort(key=lambda team: win_pct(team), reverse=True)
return_standings = { return_standings = {
'count': len(div_teams), "count": len(div_teams),
'standings': [model_to_dict(x, recurse=not short_output) for x in div_teams] "standings": [model_to_dict(x, recurse=not short_output) for x in div_teams],
} }
db.close() db.close()
return return_standings return return_standings
@router.get('/team/{team_id}') @router.get("/team/{team_id}")
@handle_db_errors @handle_db_errors
async def get_team_standings(team_id: int): async def get_team_standings(team_id: int):
this_stan = Standings.get_or_none(Standings.team_id == team_id) this_stan = Standings.get_or_none(Standings.team_id == team_id)
if this_stan is None: if this_stan is None:
raise HTTPException(status_code=404, detail=f'No standings found for team id {team_id}') raise HTTPException(
status_code=404, detail=f"No standings found for team id {team_id}"
)
return model_to_dict(this_stan) return model_to_dict(this_stan)
@router.patch('/{stan_id}', include_in_schema=PRIVATE_IN_SCHEMA) @router.patch("/{stan_id}", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def patch_standings( async def patch_standings(
stan_id, wins: Optional[int] = None, losses: Optional[int] = None, token: str = Depends(oauth2_scheme)): stan_id,
wins: Optional[int] = None,
losses: Optional[int] = None,
token: str = Depends(oauth2_scheme),
):
if not valid_token(token): if not valid_token(token):
logger.warning(f'patch_standings - Bad Token: {token}') logger.warning(f"patch_standings - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
try: try:
this_stan = Standings.get_by_id(stan_id) this_stan = Standings.get_by_id(stan_id)
except Exception as e: except Exception as e:
db.close() db.close()
raise HTTPException(status_code=404, detail=f'No team found with id {stan_id}') raise HTTPException(status_code=404, detail=f"No team found with id {stan_id}")
if wins: if wins:
this_stan.wins = wins this_stan.wins = wins
@ -91,35 +107,35 @@ async def patch_standings(
return model_to_dict(this_stan) return model_to_dict(this_stan)
@router.post('/s{season}/new', include_in_schema=PRIVATE_IN_SCHEMA) @router.post("/s{season}/new", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def post_standings(season: int, token: str = Depends(oauth2_scheme)): async def post_standings(season: int, token: str = Depends(oauth2_scheme)):
if not valid_token(token): if not valid_token(token):
logger.warning(f'post_standings - Bad Token: {token}') logger.warning(f"post_standings - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
new_teams = [] new_teams = []
all_teams = Team.select().where(Team.season == season) all_teams = Team.select().where(Team.season == season)
for x in all_teams: for x in all_teams:
new_teams.append(Standings({'team_id': x.id})) new_teams.append(Standings({"team_id": x.id}))
with db.atomic(): with db.atomic():
for batch in chunked(new_teams, 16): for batch in chunked(new_teams, 16):
Standings.insert_many(batch).on_conflict_ignore().execute() Standings.insert_many(batch).on_conflict_ignore().execute()
db.close() db.close()
return f'Inserted {len(new_teams)} standings' return f"Inserted {len(new_teams)} standings"
@router.post('/s{season}/recalculate', include_in_schema=PRIVATE_IN_SCHEMA) @router.post("/s{season}/recalculate", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def recalculate_standings(season: int, token: str = Depends(oauth2_scheme)): async def recalculate_standings(season: int, token: str = Depends(oauth2_scheme)):
if not valid_token(token): if not valid_token(token):
logger.warning(f'recalculate_standings - Bad Token: {token}') logger.warning(f"recalculate_standings - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
code = Standings.recalculate(season) code = Standings.recalculate(season)
db.close() db.close()
if code == 69: if code == 69:
raise HTTPException(status_code=500, detail=f'Error recreating Standings rows') raise HTTPException(status_code=500, detail=f"Error recreating Standings rows")
return f'Just recalculated standings for season {season}' return f"Just recalculated standings for season {season}"

View File

@ -143,16 +143,31 @@ async def post_transactions(
all_moves = [] all_moves = []
all_team_ids = list(
set(x.oldteam_id for x in moves.moves) | set(x.newteam_id for x in moves.moves)
)
all_player_ids = list(set(x.player_id for x in moves.moves))
found_team_ids = (
set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids))
if all_team_ids
else set()
)
found_player_ids = (
set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids))
if all_player_ids
else set()
)
for x in moves.moves: for x in moves.moves:
if Team.get_or_none(Team.id == x.oldteam_id) is None: if x.oldteam_id not in found_team_ids:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Team ID {x.oldteam_id} not found" status_code=404, detail=f"Team ID {x.oldteam_id} not found"
) )
if Team.get_or_none(Team.id == x.newteam_id) is None: if x.newteam_id not in found_team_ids:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Team ID {x.newteam_id} not found" status_code=404, detail=f"Team ID {x.newteam_id} not found"
) )
if Player.get_or_none(Player.id == x.player_id) is None: if x.player_id not in found_player_ids:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Player ID {x.player_id} not found" status_code=404, detail=f"Player ID {x.player_id} not found"
) )

View File

@ -3,239 +3,331 @@ from typing import List, Literal, Optional
import logging import logging
import pydantic import pydantic
from ..db_engine import SeasonBattingStats, SeasonPitchingStats, db, Manager, Team, Current, model_to_dict, fn, query_to_csv, StratPlay, StratGame from ..db_engine import (
from ..dependencies import add_cache_headers, cache_result, oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors, update_season_batting_stats, update_season_pitching_stats, get_cache_stats SeasonBattingStats,
SeasonPitchingStats,
logger = logging.getLogger('discord_app') db,
Manager,
router = APIRouter( Team,
prefix='/api/v3/views', Current,
tags=['views'] model_to_dict,
fn,
query_to_csv,
StratPlay,
StratGame,
)
from ..dependencies import (
add_cache_headers,
cache_result,
oauth2_scheme,
valid_token,
PRIVATE_IN_SCHEMA,
handle_db_errors,
update_season_batting_stats,
update_season_pitching_stats,
get_cache_stats,
) )
@router.get('/season-stats/batting') logger = logging.getLogger("discord_app")
router = APIRouter(prefix="/api/v3/views", tags=["views"])
@router.get("/season-stats/batting")
@handle_db_errors @handle_db_errors
@add_cache_headers(max_age=10*60) @add_cache_headers(max_age=10 * 60)
@cache_result(ttl=5*60, key_prefix='season-batting') @cache_result(ttl=5 * 60, key_prefix="season-batting")
async def get_season_batting_stats( async def get_season_batting_stats(
season: Optional[int] = None, season: Optional[int] = None,
team_id: Optional[int] = None, team_id: Optional[int] = None,
player_id: Optional[int] = None, player_id: Optional[int] = None,
sbaplayer_id: Optional[int] = None, sbaplayer_id: Optional[int] = None,
min_pa: Optional[int] = None, # Minimum plate appearances min_pa: Optional[int] = None, # Minimum plate appearances
sort_by: str = "woba", # Default sort field sort_by: Literal[
sort_order: Literal['asc', 'desc'] = 'desc', # asc or desc "pa",
"ab",
"run",
"hit",
"double",
"triple",
"homerun",
"rbi",
"bb",
"so",
"bphr",
"bpfo",
"bp1b",
"bplo",
"gidp",
"hbp",
"sac",
"ibb",
"avg",
"obp",
"slg",
"ops",
"woba",
"k_pct",
"sb",
"cs",
] = "woba", # Sort field
sort_order: Literal["asc", "desc"] = "desc", # asc or desc
limit: Optional[int] = 200, limit: Optional[int] = 200,
offset: int = 0, offset: int = 0,
csv: Optional[bool] = False csv: Optional[bool] = False,
): ):
logger.info(f'Getting season {season} batting stats - team_id: {team_id}, player_id: {player_id}, min_pa: {min_pa}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}') logger.info(
f"Getting season {season} batting stats - team_id: {team_id}, player_id: {player_id}, min_pa: {min_pa}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}"
)
# Use the enhanced get_top_hitters method # Use the enhanced get_top_hitters method
query = SeasonBattingStats.get_top_hitters( query = SeasonBattingStats.get_top_hitters(
season=season, season=season,
stat=sort_by, stat=sort_by,
limit=limit if limit != 0 else None, limit=limit if limit != 0 else None,
desc=(sort_order.lower() == 'desc'), desc=(sort_order.lower() == "desc"),
team_id=team_id, team_id=team_id,
player_id=player_id, player_id=player_id,
sbaplayer_id=sbaplayer_id, sbaplayer_id=sbaplayer_id,
min_pa=min_pa, min_pa=min_pa,
offset=offset offset=offset,
) )
# Build applied filters for response # Build applied filters for response
applied_filters = {} applied_filters = {}
if season is not None: if season is not None:
applied_filters['season'] = season applied_filters["season"] = season
if team_id is not None: if team_id is not None:
applied_filters['team_id'] = team_id applied_filters["team_id"] = team_id
if player_id is not None: if player_id is not None:
applied_filters['player_id'] = player_id applied_filters["player_id"] = player_id
if min_pa is not None: if min_pa is not None:
applied_filters['min_pa'] = min_pa applied_filters["min_pa"] = min_pa
if csv: if csv:
return_val = query_to_csv(query) return_val = query_to_csv(query)
return Response(content=return_val, media_type='text/csv') return Response(content=return_val, media_type="text/csv")
else: else:
stat_list = [model_to_dict(stat) for stat in query] stat_list = [model_to_dict(stat) for stat in query]
return { return {"count": len(stat_list), "filters": applied_filters, "stats": stat_list}
'count': len(stat_list),
'filters': applied_filters,
'stats': stat_list
}
@router.post('/season-stats/batting/refresh', include_in_schema=PRIVATE_IN_SCHEMA) @router.post("/season-stats/batting/refresh", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def refresh_season_batting_stats( async def refresh_season_batting_stats(
season: int, season: int, token: str = Depends(oauth2_scheme)
token: str = Depends(oauth2_scheme)
) -> dict: ) -> dict:
""" """
Refresh batting stats for all players in a specific season. Refresh batting stats for all players in a specific season.
Useful for full season updates. Useful for full season updates.
""" """
if not valid_token(token): if not valid_token(token):
logger.warning(f'refresh_season_batting_stats - Bad Token: {token}') logger.warning(f"refresh_season_batting_stats - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
logger.info(f"Refreshing all batting stats for season {season}")
logger.info(f'Refreshing all batting stats for season {season}')
try: try:
# Get all player IDs who have stratplay records in this season # Get all player IDs who have stratplay records in this season
batter_ids = [row.batter_id for row in batter_ids = [
StratPlay.select(StratPlay.batter_id.distinct()) row.batter_id
.join(StratGame).where(StratGame.season == season)] for row in StratPlay.select(StratPlay.batter_id.distinct())
.join(StratGame)
.where(StratGame.season == season)
]
if batter_ids: if batter_ids:
update_season_batting_stats(batter_ids, season, db) update_season_batting_stats(batter_ids, season, db)
logger.info(f'Successfully refreshed {len(batter_ids)} players for season {season}') logger.info(
f"Successfully refreshed {len(batter_ids)} players for season {season}"
)
return { return {
'message': f'Season {season} batting stats refreshed', "message": f"Season {season} batting stats refreshed",
'players_updated': len(batter_ids) "players_updated": len(batter_ids),
} }
else: else:
logger.warning(f'No batting data found for season {season}') logger.warning(f"No batting data found for season {season}")
return { return {
'message': f'No batting data found for season {season}', "message": f"No batting data found for season {season}",
'players_updated': 0 "players_updated": 0,
} }
except Exception as e: except Exception as e:
logger.error(f'Error refreshing season {season}: {e}') logger.error(f"Error refreshing season {season}: {e}")
raise HTTPException(status_code=500, detail=f'Refresh failed: {str(e)}') raise HTTPException(status_code=500, detail=f"Refresh failed: {str(e)}")
@router.get('/season-stats/pitching') @router.get("/season-stats/pitching")
@handle_db_errors @handle_db_errors
@add_cache_headers(max_age=10*60) @add_cache_headers(max_age=10 * 60)
@cache_result(ttl=5*60, key_prefix='season-pitching') @cache_result(ttl=5 * 60, key_prefix="season-pitching")
async def get_season_pitching_stats( async def get_season_pitching_stats(
season: Optional[int] = None, season: Optional[int] = None,
team_id: Optional[int] = None, team_id: Optional[int] = None,
player_id: Optional[int] = None, player_id: Optional[int] = None,
sbaplayer_id: Optional[int] = None, sbaplayer_id: Optional[int] = None,
min_outs: Optional[int] = None, # Minimum outs pitched min_outs: Optional[int] = None, # Minimum outs pitched
sort_by: str = "era", # Default sort field sort_by: Literal[
sort_order: Literal['asc', 'desc'] = 'asc', # asc or desc (asc default for ERA) "tbf",
"outs",
"games",
"gs",
"win",
"loss",
"hold",
"saves",
"bsave",
"ir",
"irs",
"ab",
"run",
"e_run",
"hits",
"double",
"triple",
"homerun",
"bb",
"so",
"hbp",
"sac",
"ibb",
"gidp",
"sb",
"cs",
"bphr",
"bpfo",
"bp1b",
"bplo",
"wp",
"balk",
"wpa",
"era",
"whip",
"avg",
"obp",
"slg",
"ops",
"woba",
"hper9",
"kper9",
"bbper9",
"kperbb",
"lob_2outs",
"rbipercent",
"re24",
] = "era", # Sort field
sort_order: Literal["asc", "desc"] = "asc", # asc or desc (asc default for ERA)
limit: Optional[int] = 200, limit: Optional[int] = 200,
offset: int = 0, offset: int = 0,
csv: Optional[bool] = False csv: Optional[bool] = False,
): ):
logger.info(f'Getting season {season} pitching stats - team_id: {team_id}, player_id: {player_id}, min_outs: {min_outs}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}') logger.info(
f"Getting season {season} pitching stats - team_id: {team_id}, player_id: {player_id}, min_outs: {min_outs}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}"
)
# Use the get_top_pitchers method # Use the get_top_pitchers method
query = SeasonPitchingStats.get_top_pitchers( query = SeasonPitchingStats.get_top_pitchers(
season=season, season=season,
stat=sort_by, stat=sort_by,
limit=limit if limit != 0 else None, limit=limit if limit != 0 else None,
desc=(sort_order.lower() == 'desc'), desc=(sort_order.lower() == "desc"),
team_id=team_id, team_id=team_id,
player_id=player_id, player_id=player_id,
sbaplayer_id=sbaplayer_id, sbaplayer_id=sbaplayer_id,
min_outs=min_outs, min_outs=min_outs,
offset=offset offset=offset,
) )
# Build applied filters for response # Build applied filters for response
applied_filters = {} applied_filters = {}
if season is not None: if season is not None:
applied_filters['season'] = season applied_filters["season"] = season
if team_id is not None: if team_id is not None:
applied_filters['team_id'] = team_id applied_filters["team_id"] = team_id
if player_id is not None: if player_id is not None:
applied_filters['player_id'] = player_id applied_filters["player_id"] = player_id
if min_outs is not None: if min_outs is not None:
applied_filters['min_outs'] = min_outs applied_filters["min_outs"] = min_outs
if csv: if csv:
return_val = query_to_csv(query) return_val = query_to_csv(query)
return Response(content=return_val, media_type='text/csv') return Response(content=return_val, media_type="text/csv")
else: else:
stat_list = [model_to_dict(stat) for stat in query] stat_list = [model_to_dict(stat) for stat in query]
return { return {"count": len(stat_list), "filters": applied_filters, "stats": stat_list}
'count': len(stat_list),
'filters': applied_filters,
'stats': stat_list
}
@router.post('/season-stats/pitching/refresh', include_in_schema=PRIVATE_IN_SCHEMA) @router.post("/season-stats/pitching/refresh", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def refresh_season_pitching_stats( async def refresh_season_pitching_stats(
season: int, season: int, token: str = Depends(oauth2_scheme)
token: str = Depends(oauth2_scheme)
) -> dict: ) -> dict:
""" """
Refresh pitching statistics for a specific season by aggregating from individual games. Refresh pitching statistics for a specific season by aggregating from individual games.
Private endpoint - not included in public API documentation. Private endpoint - not included in public API documentation.
""" """
if not valid_token(token): if not valid_token(token):
logger.warning(f'refresh_season_batting_stats - Bad Token: {token}') logger.warning(f"refresh_season_batting_stats - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
logger.info(f"Refreshing season {season} pitching stats")
logger.info(f'Refreshing season {season} pitching stats')
try: try:
# Get all pitcher IDs for this season # Get all pitcher IDs for this season
pitcher_query = ( pitcher_query = (
StratPlay StratPlay.select(StratPlay.pitcher_id)
.select(StratPlay.pitcher_id)
.join(StratGame, on=(StratPlay.game_id == StratGame.id)) .join(StratGame, on=(StratPlay.game_id == StratGame.id))
.where((StratGame.season == season) & (StratPlay.pitcher_id.is_null(False))) .where((StratGame.season == season) & (StratPlay.pitcher_id.is_null(False)))
.distinct() .distinct()
) )
pitcher_ids = [row.pitcher_id for row in pitcher_query] pitcher_ids = [row.pitcher_id for row in pitcher_query]
if not pitcher_ids: if not pitcher_ids:
logger.warning(f'No pitchers found for season {season}') logger.warning(f"No pitchers found for season {season}")
return { return {
'status': 'success', "status": "success",
'message': f'No pitchers found for season {season}', "message": f"No pitchers found for season {season}",
'players_updated': 0 "players_updated": 0,
} }
# Use the dependency function to update pitching stats # Use the dependency function to update pitching stats
update_season_pitching_stats(pitcher_ids, season, db) update_season_pitching_stats(pitcher_ids, season, db)
logger.info(f'Season {season} pitching stats refreshed successfully - {len(pitcher_ids)} players updated') logger.info(
f"Season {season} pitching stats refreshed successfully - {len(pitcher_ids)} players updated"
)
return { return {
'status': 'success', "status": "success",
'message': f'Season {season} pitching stats refreshed', "message": f"Season {season} pitching stats refreshed",
'players_updated': len(pitcher_ids) "players_updated": len(pitcher_ids),
} }
except Exception as e: except Exception as e:
logger.error(f'Error refreshing season {season} pitching stats: {e}') logger.error(f"Error refreshing season {season} pitching stats: {e}")
raise HTTPException(status_code=500, detail=f'Refresh failed: {str(e)}') raise HTTPException(status_code=500, detail=f"Refresh failed: {str(e)}")
@router.get('/admin/cache', include_in_schema=PRIVATE_IN_SCHEMA) @router.get("/admin/cache", include_in_schema=PRIVATE_IN_SCHEMA)
@handle_db_errors @handle_db_errors
async def get_admin_cache_stats( async def get_admin_cache_stats(token: str = Depends(oauth2_scheme)) -> dict:
token: str = Depends(oauth2_scheme)
) -> dict:
""" """
Get Redis cache statistics and status. Get Redis cache statistics and status.
Private endpoint - requires authentication. Private endpoint - requires authentication.
""" """
if not valid_token(token): if not valid_token(token):
logger.warning(f'get_admin_cache_stats - Bad Token: {token}') logger.warning(f"get_admin_cache_stats - Bad Token: {token}")
raise HTTPException(status_code=401, detail='Unauthorized') raise HTTPException(status_code=401, detail="Unauthorized")
logger.info("Getting cache statistics")
logger.info('Getting cache statistics')
try: try:
cache_stats = get_cache_stats() cache_stats = get_cache_stats()
logger.info(f'Cache stats retrieved: {cache_stats}') logger.info(f"Cache stats retrieved: {cache_stats}")
return { return {"status": "success", "cache_info": cache_stats}
'status': 'success',
'cache_info': cache_stats
}
except Exception as e: except Exception as e:
logger.error(f'Error getting cache stats: {e}') logger.error(f"Error getting cache stats: {e}")
raise HTTPException(status_code=500, detail=f'Failed to get cache stats: {str(e)}') raise HTTPException(
status_code=500, detail=f"Failed to get cache stats: {str(e)}"
)

View File

@ -39,7 +39,7 @@ class PlayerService(BaseService):
cache_patterns = ["players*", "players-search*", "player*", "team-roster*"] cache_patterns = ["players*", "players-search*", "player*", "team-roster*"]
# Deprecated fields to exclude from player responses # Deprecated fields to exclude from player responses
EXCLUDED_FIELDS = ['pitcher_injury'] EXCLUDED_FIELDS = ["pitcher_injury"]
# Class-level repository for dependency injection # Class-level repository for dependency injection
_injected_repo: Optional[AbstractPlayerRepository] = None _injected_repo: Optional[AbstractPlayerRepository] = None
@ -135,17 +135,21 @@ class PlayerService(BaseService):
# Apply sorting # Apply sorting
query = cls._apply_player_sort(query, sort) query = cls._apply_player_sort(query, sort)
# Convert to list of dicts # Apply pagination at DB level for real queries, Python level for mocks
players_data = cls._query_to_player_dicts(query, short_output) if isinstance(query, InMemoryQueryResult):
total_count = len(query)
# Store total count before pagination players_data = cls._query_to_player_dicts(query, short_output)
total_count = len(players_data) if offset is not None:
players_data = players_data[offset:]
# Apply pagination (offset and limit) if limit is not None:
if offset is not None: players_data = players_data[:limit]
players_data = players_data[offset:] else:
if limit is not None: total_count = query.count()
players_data = players_data[:limit] if offset is not None:
query = query.offset(offset)
if limit is not None:
query = query.limit(limit)
players_data = cls._query_to_player_dicts(query, short_output)
# Return format # Return format
if as_csv: if as_csv:
@ -154,7 +158,7 @@ class PlayerService(BaseService):
return { return {
"count": len(players_data), "count": len(players_data),
"total": total_count, "total": total_count,
"players": players_data "players": players_data,
} }
except Exception as e: except Exception as e:
@ -204,9 +208,9 @@ class PlayerService(BaseService):
p_list = [x.upper() for x in pos] p_list = [x.upper() for x in pos]
# Expand generic "P" to match all pitcher positions # Expand generic "P" to match all pitcher positions
pitcher_positions = ['SP', 'RP', 'CP'] pitcher_positions = ["SP", "RP", "CP"]
if 'P' in p_list: if "P" in p_list:
p_list.remove('P') p_list.remove("P")
p_list.extend(pitcher_positions) p_list.extend(pitcher_positions)
pos_conditions = ( pos_conditions = (
@ -245,9 +249,9 @@ class PlayerService(BaseService):
p_list = [p.upper() for p in pos] p_list = [p.upper() for p in pos]
# Expand generic "P" to match all pitcher positions # Expand generic "P" to match all pitcher positions
pitcher_positions = ['SP', 'RP', 'CP'] pitcher_positions = ["SP", "RP", "CP"]
if 'P' in p_list: if "P" in p_list:
p_list.remove('P') p_list.remove("P")
p_list.extend(pitcher_positions) p_list.extend(pitcher_positions)
player_pos = [ player_pos = [
@ -385,19 +389,23 @@ class PlayerService(BaseService):
# This filters at the database level instead of loading all players # This filters at the database level instead of loading all players
if search_all_seasons: if search_all_seasons:
# Search all seasons, order by season DESC (newest first) # Search all seasons, order by season DESC (newest first)
query = (Player.select() query = (
.where(fn.Lower(Player.name).contains(query_lower)) Player.select()
.order_by(Player.season.desc(), Player.name) .where(fn.Lower(Player.name).contains(query_lower))
.limit(limit * 2)) # Get extra for exact match sorting .order_by(Player.season.desc(), Player.name)
.limit(limit * 2)
) # Get extra for exact match sorting
else: else:
# Search specific season # Search specific season
query = (Player.select() query = (
.where( Player.select()
(Player.season == season) & .where(
(fn.Lower(Player.name).contains(query_lower)) (Player.season == season)
) & (fn.Lower(Player.name).contains(query_lower))
.order_by(Player.name) )
.limit(limit * 2)) # Get extra for exact match sorting .order_by(Player.name)
.limit(limit * 2)
) # Get extra for exact match sorting
# Execute query and convert limited results to dicts # Execute query and convert limited results to dicts
players = list(query) players = list(query)
@ -468,19 +476,29 @@ class PlayerService(BaseService):
# Use backrefs=False to avoid circular reference issues # Use backrefs=False to avoid circular reference issues
player_dict = model_to_dict(player, recurse=recurse, backrefs=False) player_dict = model_to_dict(player, recurse=recurse, backrefs=False)
# Filter out excluded fields # Filter out excluded fields
return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS} return {
k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS
}
except (ImportError, AttributeError, TypeError) as e: except (ImportError, AttributeError, TypeError) as e:
# Log the error and fall back to non-recursive serialization # Log the error and fall back to non-recursive serialization
logger.warning(f"Error in recursive player serialization: {e}, falling back to non-recursive") logger.warning(
f"Error in recursive player serialization: {e}, falling back to non-recursive"
)
try: try:
# Fallback to non-recursive serialization # Fallback to non-recursive serialization
player_dict = model_to_dict(player, recurse=False) player_dict = model_to_dict(player, recurse=False)
return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS} return {
k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS
}
except Exception as fallback_error: except Exception as fallback_error:
# Final fallback to basic dict conversion # Final fallback to basic dict conversion
logger.error(f"Error in non-recursive serialization: {fallback_error}, using basic dict") logger.error(
f"Error in non-recursive serialization: {fallback_error}, using basic dict"
)
player_dict = dict(player) player_dict = dict(player)
return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS} return {
k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS
}
@classmethod @classmethod
def update_player( def update_player(
@ -508,6 +526,8 @@ class PlayerService(BaseService):
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Error updating player {player_id}: {str(e)}" status_code=500, detail=f"Error updating player {player_id}: {str(e)}"
) )
finally:
temp_service.invalidate_related_cache(cls.cache_patterns)
@classmethod @classmethod
def patch_player( def patch_player(
@ -535,6 +555,8 @@ class PlayerService(BaseService):
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Error patching player {player_id}: {str(e)}" status_code=500, detail=f"Error patching player {player_id}: {str(e)}"
) )
finally:
temp_service.invalidate_related_cache(cls.cache_patterns)
@classmethod @classmethod
def create_players( def create_players(
@ -567,6 +589,8 @@ class PlayerService(BaseService):
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Error creating players: {str(e)}" status_code=500, detail=f"Error creating players: {str(e)}"
) )
finally:
temp_service.invalidate_related_cache(cls.cache_patterns)
@classmethod @classmethod
def delete_player(cls, player_id: int, token: str) -> Dict[str, str]: def delete_player(cls, player_id: int, token: str) -> Dict[str, str]:
@ -590,6 +614,8 @@ class PlayerService(BaseService):
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Error deleting player {player_id}: {str(e)}" status_code=500, detail=f"Error deleting player {player_id}: {str(e)}"
) )
finally:
temp_service.invalidate_related_cache(cls.cache_patterns)
@classmethod @classmethod
def _format_player_csv(cls, players: List[Dict]) -> str: def _format_player_csv(cls, players: List[Dict]) -> str:
@ -603,12 +629,12 @@ class PlayerService(BaseService):
flat_player = player.copy() flat_player = player.copy()
# Flatten team object to just abbreviation # Flatten team object to just abbreviation
if isinstance(flat_player.get('team'), dict): if isinstance(flat_player.get("team"), dict):
flat_player['team'] = flat_player['team'].get('abbrev', '') flat_player["team"] = flat_player["team"].get("abbrev", "")
# Flatten sbaplayer object to just ID # Flatten sbaplayer object to just ID
if isinstance(flat_player.get('sbaplayer'), dict): if isinstance(flat_player.get("sbaplayer"), dict):
flat_player['sbaplayer'] = flat_player['sbaplayer'].get('id', '') flat_player["sbaplayer"] = flat_player["sbaplayer"].get("id", "")
flattened_players.append(flat_player) flattened_players.append(flat_player)

View File

@ -7,21 +7,18 @@ import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import sys import sys
import os import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.services.player_service import PlayerService from app.services.player_service import PlayerService
from app.services.base import ServiceConfig from app.services.base import ServiceConfig
from app.services.mocks import ( from app.services.mocks import MockPlayerRepository, MockCacheService, EnhancedMockCache
MockPlayerRepository,
MockCacheService,
EnhancedMockCache
)
# ============================================================================ # ============================================================================
# FIXTURES # FIXTURES
# ============================================================================ # ============================================================================
@pytest.fixture @pytest.fixture
def cache(): def cache():
"""Create fresh cache for each test.""" """Create fresh cache for each test."""
@ -32,20 +29,73 @@ def cache():
def repo(cache): def repo(cache):
"""Create fresh repo with test data.""" """Create fresh repo with test data."""
repo = MockPlayerRepository() repo = MockPlayerRepository()
# Add test players # Add test players
players = [ players = [
{'id': 1, 'name': 'Mike Trout', 'wara': 5.2, 'team_id': 1, 'season': 10, 'pos_1': 'CF', 'pos_2': 'LF', 'strat_code': 'Elite', 'injury_rating': 'A'}, {
{'id': 2, 'name': 'Aaron Judge', 'wara': 4.8, 'team_id': 2, 'season': 10, 'pos_1': 'RF', 'strat_code': 'Power', 'injury_rating': 'B'}, "id": 1,
{'id': 3, 'name': 'Mookie Betts', 'wara': 5.5, 'team_id': 3, 'season': 10, 'pos_1': 'RF', 'pos_2': '2B', 'strat_code': 'Elite', 'injury_rating': 'A'}, "name": "Mike Trout",
{'id': 4, 'name': 'Injured Player', 'wara': 2.0, 'team_id': 1, 'season': 10, 'pos_1': 'P', 'il_return': 'Week 5', 'injury_rating': 'C'}, "wara": 5.2,
{'id': 5, 'name': 'Old Player', 'wara': 1.0, 'team_id': 1, 'season': 5, 'pos_1': '1B'}, "team_id": 1,
{'id': 6, 'name': 'Juan Soto', 'wara': 4.5, 'team_id': 2, 'season': 10, 'pos_1': '1B', 'strat_code': 'Contact'}, "season": 10,
"pos_1": "CF",
"pos_2": "LF",
"strat_code": "Elite",
"injury_rating": "A",
},
{
"id": 2,
"name": "Aaron Judge",
"wara": 4.8,
"team_id": 2,
"season": 10,
"pos_1": "RF",
"strat_code": "Power",
"injury_rating": "B",
},
{
"id": 3,
"name": "Mookie Betts",
"wara": 5.5,
"team_id": 3,
"season": 10,
"pos_1": "RF",
"pos_2": "2B",
"strat_code": "Elite",
"injury_rating": "A",
},
{
"id": 4,
"name": "Injured Player",
"wara": 2.0,
"team_id": 1,
"season": 10,
"pos_1": "P",
"il_return": "Week 5",
"injury_rating": "C",
},
{
"id": 5,
"name": "Old Player",
"wara": 1.0,
"team_id": 1,
"season": 5,
"pos_1": "1B",
},
{
"id": 6,
"name": "Juan Soto",
"wara": 4.5,
"team_id": 2,
"season": 10,
"pos_1": "1B",
"strat_code": "Contact",
},
] ]
for player in players: for player in players:
repo.add_player(player) repo.add_player(player)
return repo return repo
@ -60,463 +110,453 @@ def service(repo, cache):
# TEST CLASSES # TEST CLASSES
# ============================================================================ # ============================================================================
class TestPlayerServiceGetPlayers: class TestPlayerServiceGetPlayers:
"""Tests for get_players method - 50+ lines covered.""" """Tests for get_players method - 50+ lines covered."""
def test_get_all_season_players(self, service, repo): def test_get_all_season_players(self, service, repo):
"""Get all players for a season.""" """Get all players for a season."""
result = service.get_players(season=10) result = service.get_players(season=10)
assert result['count'] >= 5 # We have 5 season 10 players assert result["count"] >= 5 # We have 5 season 10 players
assert len(result['players']) >= 5 assert len(result["players"]) >= 5
assert all(p.get('season') == 10 for p in result['players']) assert all(p.get("season") == 10 for p in result["players"])
def test_filter_by_single_team(self, service): def test_filter_by_single_team(self, service):
"""Filter by single team ID.""" """Filter by single team ID."""
result = service.get_players(season=10, team_id=[1]) result = service.get_players(season=10, team_id=[1])
assert result['count'] >= 1 assert result["count"] >= 1
assert all(p.get('team_id') == 1 for p in result['players']) assert all(p.get("team_id") == 1 for p in result["players"])
def test_filter_by_multiple_teams(self, service): def test_filter_by_multiple_teams(self, service):
"""Filter by multiple team IDs.""" """Filter by multiple team IDs."""
result = service.get_players(season=10, team_id=[1, 2]) result = service.get_players(season=10, team_id=[1, 2])
assert result['count'] >= 2 assert result["count"] >= 2
assert all(p.get('team_id') in [1, 2] for p in result['players']) assert all(p.get("team_id") in [1, 2] for p in result["players"])
def test_filter_by_position(self, service): def test_filter_by_position(self, service):
"""Filter by position.""" """Filter by position."""
result = service.get_players(season=10, pos=['CF']) result = service.get_players(season=10, pos=["CF"])
assert result['count'] >= 1 assert result["count"] >= 1
assert any(p.get('pos_1') == 'CF' or p.get('pos_2') == 'CF' for p in result['players']) assert any(
p.get("pos_1") == "CF" or p.get("pos_2") == "CF" for p in result["players"]
)
def test_filter_by_strat_code(self, service): def test_filter_by_strat_code(self, service):
"""Filter by strat code.""" """Filter by strat code."""
result = service.get_players(season=10, strat_code=['Elite']) result = service.get_players(season=10, strat_code=["Elite"])
assert result['count'] >= 2 # Trout and Betts assert result["count"] >= 2 # Trout and Betts
assert all('Elite' in str(p.get('strat_code', '')) for p in result['players']) assert all("Elite" in str(p.get("strat_code", "")) for p in result["players"])
def test_filter_injured_only(self, service): def test_filter_injured_only(self, service):
"""Filter injured players only.""" """Filter injured players only."""
result = service.get_players(season=10, is_injured=True) result = service.get_players(season=10, is_injured=True)
assert result['count'] >= 1 assert result["count"] >= 1
assert all(p.get('il_return') is not None for p in result['players']) assert all(p.get("il_return") is not None for p in result["players"])
def test_sort_cost_ascending(self, service): def test_sort_cost_ascending(self, service):
"""Sort by WARA ascending.""" """Sort by WARA ascending."""
result = service.get_players(season=10, sort='cost-asc') result = service.get_players(season=10, sort="cost-asc")
wara = [p.get('wara', 0) for p in result['players']] wara = [p.get("wara", 0) for p in result["players"]]
assert wara == sorted(wara) assert wara == sorted(wara)
def test_sort_cost_descending(self, service): def test_sort_cost_descending(self, service):
"""Sort by WARA descending.""" """Sort by WARA descending."""
result = service.get_players(season=10, sort='cost-desc') result = service.get_players(season=10, sort="cost-desc")
wara = [p.get('wara', 0) for p in result['players']] wara = [p.get("wara", 0) for p in result["players"]]
assert wara == sorted(wara, reverse=True) assert wara == sorted(wara, reverse=True)
def test_sort_name_ascending(self, service): def test_sort_name_ascending(self, service):
"""Sort by name ascending.""" """Sort by name ascending."""
result = service.get_players(season=10, sort='name-asc') result = service.get_players(season=10, sort="name-asc")
names = [p.get('name', '') for p in result['players']] names = [p.get("name", "") for p in result["players"]]
assert names == sorted(names) assert names == sorted(names)
def test_sort_name_descending(self, service): def test_sort_name_descending(self, service):
"""Sort by name descending.""" """Sort by name descending."""
result = service.get_players(season=10, sort='name-desc') result = service.get_players(season=10, sort="name-desc")
names = [p.get('name', '') for p in result['players']] names = [p.get("name", "") for p in result["players"]]
assert names == sorted(names, reverse=True) assert names == sorted(names, reverse=True)
class TestPlayerServiceSearch: class TestPlayerServiceSearch:
"""Tests for search_players method.""" """Tests for search_players method."""
def test_exact_name_match(self, service): def test_exact_name_match(self, service):
"""Search with exact name match.""" """Search with exact name match."""
result = service.search_players('Mike Trout', season=10) result = service.search_players("Mike Trout", season=10)
assert result['count'] >= 1 assert result["count"] >= 1
names = [p.get('name') for p in result['players']] names = [p.get("name") for p in result["players"]]
assert 'Mike Trout' in names assert "Mike Trout" in names
def test_partial_name_match(self, service): def test_partial_name_match(self, service):
"""Search with partial name match.""" """Search with partial name match."""
result = service.search_players('Trout', season=10) result = service.search_players("Trout", season=10)
assert result['count'] >= 1 assert result["count"] >= 1
assert any('Trout' in p.get('name', '') for p in result['players']) assert any("Trout" in p.get("name", "") for p in result["players"])
def test_case_insensitive_search(self, service): def test_case_insensitive_search(self, service):
"""Search is case insensitive.""" """Search is case insensitive."""
result1 = service.search_players('MIKE', season=10) result1 = service.search_players("MIKE", season=10)
result2 = service.search_players('mike', season=10) result2 = service.search_players("mike", season=10)
assert result1['count'] == result2['count'] assert result1["count"] == result2["count"]
def test_search_all_seasons(self, service): def test_search_all_seasons(self, service):
"""Search across all seasons.""" """Search across all seasons."""
result = service.search_players('Player', season=None) result = service.search_players("Player", season=None)
# Should find both current and old players # Should find both current and old players
assert result['all_seasons'] == True assert result["all_seasons"] == True
assert result['count'] >= 2 assert result["count"] >= 2
def test_search_limit(self, service): def test_search_limit(self, service):
"""Limit search results.""" """Limit search results."""
result = service.search_players('a', season=10, limit=2) result = service.search_players("a", season=10, limit=2)
assert result['count'] <= 2 assert result["count"] <= 2
def test_search_no_results(self, service): def test_search_no_results(self, service):
"""Search returns empty when no matches.""" """Search returns empty when no matches."""
result = service.search_players('XYZ123NotExist', season=10) result = service.search_players("XYZ123NotExist", season=10)
assert result['count'] == 0 assert result["count"] == 0
assert result['players'] == [] assert result["players"] == []
class TestPlayerServiceGetPlayer: class TestPlayerServiceGetPlayer:
"""Tests for get_player method.""" """Tests for get_player method."""
def test_get_existing_player(self, service): def test_get_existing_player(self, service):
"""Get existing player by ID.""" """Get existing player by ID."""
result = service.get_player(1) result = service.get_player(1)
assert result is not None assert result is not None
assert result.get('id') == 1 assert result.get("id") == 1
assert result.get('name') == 'Mike Trout' assert result.get("name") == "Mike Trout"
def test_get_nonexistent_player(self, service): def test_get_nonexistent_player(self, service):
"""Get player that doesn't exist.""" """Get player that doesn't exist."""
result = service.get_player(99999) result = service.get_player(99999)
assert result is None assert result is None
def test_get_player_short_output(self, service): def test_get_player_short_output(self, service):
"""Get player with short output.""" """Get player with short output."""
result = service.get_player(1, short_output=True) result = service.get_player(1, short_output=True)
# Should still have basic fields # Should still have basic fields
assert result.get('id') == 1 assert result.get("id") == 1
assert result.get('name') == 'Mike Trout' assert result.get("name") == "Mike Trout"
class TestPlayerServiceCreate: class TestPlayerServiceCreate:
"""Tests for create_players method.""" """Tests for create_players method."""
def test_create_single_player(self, repo, cache): def test_create_single_player(self, repo, cache):
"""Create a single new player.""" """Create a single new player."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
new_player = [{ new_player = [
'name': 'New Player', {
'wara': 3.0, "name": "New Player",
'team_id': 1, "wara": 3.0,
'season': 10, "team_id": 1,
'pos_1': 'SS' "season": 10,
}] "pos_1": "SS",
}
]
# Mock auth # Mock auth
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
result = service.create_players(new_player, 'valid_token') result = service.create_players(new_player, "valid_token")
assert 'Inserted' in str(result) assert "Inserted" in str(result)
# Verify player was added (ID 7 since fixture has players 1-6) # Verify player was added (ID 7 since fixture has players 1-6)
player = repo.get_by_id(7) # Next ID after fixture data player = repo.get_by_id(7) # Next ID after fixture data
assert player is not None assert player is not None
assert player['name'] == 'New Player' assert player["name"] == "New Player"
def test_create_multiple_players(self, repo, cache): def test_create_multiple_players(self, repo, cache):
"""Create multiple new players.""" """Create multiple new players."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
new_players = [ new_players = [
{'name': 'Player A', 'wara': 2.0, 'team_id': 1, 'season': 10, 'pos_1': '2B'}, {
{'name': 'Player B', 'wara': 2.5, 'team_id': 2, 'season': 10, 'pos_1': '3B'}, "name": "Player A",
"wara": 2.0,
"team_id": 1,
"season": 10,
"pos_1": "2B",
},
{
"name": "Player B",
"wara": 2.5,
"team_id": 2,
"season": 10,
"pos_1": "3B",
},
] ]
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
result = service.create_players(new_players, 'valid_token') result = service.create_players(new_players, "valid_token")
assert 'Inserted 2 players' in str(result) assert "Inserted 2 players" in str(result)
def test_create_duplicate_fails(self, repo, cache): def test_create_duplicate_fails(self, repo, cache):
"""Creating duplicate player should fail.""" """Creating duplicate player should fail."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
duplicate = [{'name': 'Mike Trout', 'wara': 5.0, 'team_id': 1, 'season': 10, 'pos_1': 'CF'}] duplicate = [
{
with patch.object(service, 'require_auth', return_value=True): "name": "Mike Trout",
"wara": 5.0,
"team_id": 1,
"season": 10,
"pos_1": "CF",
}
]
with patch.object(service, "require_auth", return_value=True):
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
service.create_players(duplicate, 'valid_token') service.create_players(duplicate, "valid_token")
assert 'already exists' in str(exc_info.value) assert "already exists" in str(exc_info.value)
def test_create_requires_auth(self, repo, cache): def test_create_requires_auth(self, repo, cache):
"""Creating players requires authentication.""" """Creating players requires authentication."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
new_player = [{'name': 'Test', 'wara': 1.0, 'team_id': 1, 'season': 10, 'pos_1': 'P'}] new_player = [
{"name": "Test", "wara": 1.0, "team_id": 1, "season": 10, "pos_1": "P"}
]
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
service.create_players(new_player, 'bad_token') service.create_players(new_player, "bad_token")
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
class TestPlayerServiceUpdate: class TestPlayerServiceUpdate:
"""Tests for update_player and patch_player methods.""" """Tests for update_player and patch_player methods."""
def test_patch_player_name(self, repo, cache): def test_patch_player_name(self, repo, cache):
"""Patch player's name.""" """Patch player's name."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
result = service.patch_player(1, {'name': 'New Name'}, 'valid_token') result = service.patch_player(1, {"name": "New Name"}, "valid_token")
assert result is not None assert result is not None
assert result.get('name') == 'New Name' assert result.get("name") == "New Name"
def test_patch_player_wara(self, repo, cache): def test_patch_player_wara(self, repo, cache):
"""Patch player's WARA.""" """Patch player's WARA."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
result = service.patch_player(1, {'wara': 6.0}, 'valid_token') result = service.patch_player(1, {"wara": 6.0}, "valid_token")
assert result.get('wara') == 6.0 assert result.get("wara") == 6.0
def test_patch_multiple_fields(self, repo, cache): def test_patch_multiple_fields(self, repo, cache):
"""Patch multiple fields at once.""" """Patch multiple fields at once."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
updates = { updates = {"name": "Updated Name", "wara": 7.0, "strat_code": "Super Elite"}
'name': 'Updated Name',
'wara': 7.0, with patch.object(service, "require_auth", return_value=True):
'strat_code': 'Super Elite' result = service.patch_player(1, updates, "valid_token")
}
assert result.get("name") == "Updated Name"
with patch.object(service, 'require_auth', return_value=True): assert result.get("wara") == 7.0
result = service.patch_player(1, updates, 'valid_token') assert result.get("strat_code") == "Super Elite"
assert result.get('name') == 'Updated Name'
assert result.get('wara') == 7.0
assert result.get('strat_code') == 'Super Elite'
def test_patch_nonexistent_player(self, repo, cache): def test_patch_nonexistent_player(self, repo, cache):
"""Patch fails for non-existent player.""" """Patch fails for non-existent player."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
service.patch_player(99999, {'name': 'Test'}, 'valid_token') service.patch_player(99999, {"name": "Test"}, "valid_token")
assert 'not found' in str(exc_info.value) assert "not found" in str(exc_info.value)
def test_patch_requires_auth(self, repo, cache): def test_patch_requires_auth(self, repo, cache):
"""Patching requires authentication.""" """Patching requires authentication."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
service.patch_player(1, {'name': 'Test'}, 'bad_token') service.patch_player(1, {"name": "Test"}, "bad_token")
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
class TestPlayerServiceDelete: class TestPlayerServiceDelete:
"""Tests for delete_player method.""" """Tests for delete_player method."""
def test_delete_player(self, repo, cache): def test_delete_player(self, repo, cache):
"""Delete existing player.""" """Delete existing player."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
# Verify player exists # Verify player exists
assert repo.get_by_id(1) is not None assert repo.get_by_id(1) is not None
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
result = service.delete_player(1, 'valid_token') result = service.delete_player(1, "valid_token")
assert 'deleted' in str(result) assert "deleted" in str(result)
# Verify player is gone # Verify player is gone
assert repo.get_by_id(1) is None assert repo.get_by_id(1) is None
def test_delete_nonexistent_player(self, repo, cache): def test_delete_nonexistent_player(self, repo, cache):
"""Delete fails for non-existent player.""" """Delete fails for non-existent player."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
service.delete_player(99999, 'valid_token') service.delete_player(99999, "valid_token")
assert 'not found' in str(exc_info.value) assert "not found" in str(exc_info.value)
def test_delete_requires_auth(self, repo, cache): def test_delete_requires_auth(self, repo, cache):
"""Deleting requires authentication.""" """Deleting requires authentication."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
service.delete_player(1, 'bad_token') service.delete_player(1, "bad_token")
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
class TestPlayerServiceCache:
"""Tests for cache functionality."""
@pytest.mark.skip(reason="Caching not yet implemented in service methods")
def test_cache_set_on_read(self, service, cache):
"""Cache is set on player read."""
service.get_players(season=10)
assert cache.was_called('set')
@pytest.mark.skip(reason="Caching not yet implemented in service methods")
def test_cache_invalidation_on_update(self, repo, cache):
"""Cache is invalidated on player update."""
config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config)
# Read to set cache
service.get_players(season=10)
initial_calls = len(cache.get_calls('set'))
# Update should invalidate cache
with patch.object(service, 'require_auth', return_value=True):
service.patch_player(1, {'name': 'Test'}, 'valid_token')
# Should have more delete calls after update
delete_calls = [c for c in cache.get_calls() if c.get('method') == 'delete']
assert len(delete_calls) > 0
@pytest.mark.skip(reason="Caching not yet implemented in service methods")
def test_cache_hit_rate(self, repo, cache):
"""Test cache hit rate tracking."""
config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config)
# First call - cache miss
service.get_players(season=10)
miss_count = cache._miss_count
# Second call - cache hit
service.get_players(season=10)
# Hit rate should have improved
assert cache.hit_rate > 0
class TestPlayerServiceValidation: class TestPlayerServiceValidation:
"""Tests for input validation and edge cases.""" """Tests for input validation and edge cases."""
def test_invalid_season_returns_empty(self, service): def test_invalid_season_returns_empty(self, service):
"""Invalid season returns empty result.""" """Invalid season returns empty result."""
result = service.get_players(season=999) result = service.get_players(season=999)
assert result['count'] == 0 or result['players'] == [] assert result["count"] == 0 or result["players"] == []
def test_empty_search_returns_all(self, service): def test_empty_search_returns_all(self, service):
"""Empty search query returns all players.""" """Empty search query returns all players."""
result = service.search_players('', season=10) result = service.search_players("", season=10)
assert result['count'] >= 1 assert result["count"] >= 1
def test_sort_with_no_results(self, service): def test_sort_with_no_results(self, service):
"""Sorting with no results doesn't error.""" """Sorting with no results doesn't error."""
result = service.get_players(season=999, sort='cost-desc') result = service.get_players(season=999, sort="cost-desc")
assert result['count'] == 0 or result['players'] == [] assert result["count"] == 0 or result["players"] == []
def test_cache_clear_on_create(self, repo, cache): def test_cache_clear_on_create(self, repo, cache):
"""Cache is cleared when new players are created.""" """Cache is cleared when new players are created."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
# Set up some cache data # Set up some cache data
cache.set('test:key', 'value', 300) cache.set("test:key", "value", 300)
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
service.create_players([{ service.create_players(
'name': 'New', [
'wara': 1.0, {
'team_id': 1, "name": "New",
'season': 10, "wara": 1.0,
'pos_1': 'P' "team_id": 1,
}], 'valid_token') "season": 10,
"pos_1": "P",
}
],
"valid_token",
)
# Should have invalidate calls # Should have invalidate calls
assert len(cache.get_calls()) > 0 assert len(cache.get_calls()) > 0
class TestPlayerServiceIntegration: class TestPlayerServiceIntegration:
"""Integration tests combining multiple operations.""" """Integration tests combining multiple operations."""
def test_full_crud_cycle(self, repo, cache): def test_full_crud_cycle(self, repo, cache):
"""Test complete CRUD cycle.""" """Test complete CRUD cycle."""
config = ServiceConfig(player_repo=repo, cache=cache) config = ServiceConfig(player_repo=repo, cache=cache)
service = PlayerService(config=config) service = PlayerService(config=config)
# CREATE # CREATE
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
create_result = service.create_players([{ create_result = service.create_players(
'name': 'CRUD Test', [
'wara': 3.0, {
'team_id': 1, "name": "CRUD Test",
'season': 10, "wara": 3.0,
'pos_1': 'DH' "team_id": 1,
}], 'valid_token') "season": 10,
"pos_1": "DH",
}
],
"valid_token",
)
# READ # READ
search_result = service.search_players('CRUD', season=10) search_result = service.search_players("CRUD", season=10)
assert search_result['count'] >= 1 assert search_result["count"] >= 1
player_id = search_result['players'][0].get('id') player_id = search_result["players"][0].get("id")
# UPDATE # UPDATE
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
update_result = service.patch_player(player_id, {'wara': 4.0}, 'valid_token') update_result = service.patch_player(
assert update_result.get('wara') == 4.0 player_id, {"wara": 4.0}, "valid_token"
)
assert update_result.get("wara") == 4.0
# DELETE # DELETE
with patch.object(service, 'require_auth', return_value=True): with patch.object(service, "require_auth", return_value=True):
delete_result = service.delete_player(player_id, 'valid_token') delete_result = service.delete_player(player_id, "valid_token")
assert 'deleted' in str(delete_result) assert "deleted" in str(delete_result)
# VERIFY DELETED # VERIFY DELETED
get_result = service.get_player(player_id) get_result = service.get_player(player_id)
assert get_result is None assert get_result is None
def test_search_then_filter(self, service): def test_search_then_filter(self, service):
"""Search and then filter operations.""" """Search and then filter operations."""
# First get all players # First get all players
all_result = service.get_players(season=10) all_result = service.get_players(season=10)
initial_count = all_result['count'] initial_count = all_result["count"]
# Then filter by team # Then filter by team
filtered = service.get_players(season=10, team_id=[1]) filtered = service.get_players(season=10, team_id=[1])
# Filtered should be <= all # Filtered should be <= all
assert filtered['count'] <= initial_count assert filtered["count"] <= initial_count
# ============================================================================ # ============================================================================