diff --git a/app/db_engine.py b/app/db_engine.py index 3b7af32..c1979cf 100644 --- a/app/db_engine.py +++ b/app/db_engine.py @@ -51,12 +51,6 @@ Per season updates: """ -WEEK_NUMS = { - 'regular': { - - } -} - def model_csv_headers(this_obj, exclude=None) -> List: data = model_to_dict(this_obj, recurse=False, exclude=exclude) @@ -458,7 +452,7 @@ class Team(BaseModel): active_roster['WARa'] -= move.player.wara try: active_roster['players'].remove(move.player) - except: + except Exception: print(f'I could not drop {move.player.name}') for move in all_adds: @@ -519,7 +513,7 @@ class Team(BaseModel): # print(f'SIL dropping {move.player.name} id ({move.player.get_id()}) for {move.player.wara} WARa') try: short_roster['players'].remove(move.player) - except: + except Exception: print(f'I could not drop {move.player.name}') for move in all_adds: @@ -580,7 +574,7 @@ class Team(BaseModel): # print(f'LIL dropping {move.player.name} id ({move.player.get_id()}) for {move.player.wara} WARa') try: long_roster['players'].remove(move.player) - except: + except Exception: print(f'I could not drop {move.player.name}') for move in all_adds: @@ -2351,7 +2345,7 @@ class CustomCommand(BaseModel): try: import json return json.loads(self.tags) - except: + except Exception: return [] def set_tags_list(self, tags_list): diff --git a/app/dependencies.py b/app/dependencies.py index b95747d..6441155 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -11,8 +11,8 @@ from fastapi import HTTPException, Response from fastapi.security import OAuth2PasswordBearer from redis import Redis -date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' -logger = logging.getLogger('discord_app') +date = f"{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}" +logger = logging.getLogger("discord_app") # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' # log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN' @@ -23,10 +23,10 @@ logger = logging.getLogger('discord_app') # ) # Redis configuration -REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost') -REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379')) -REDIS_DB = int(os.environ.get('REDIS_DB', '0')) -CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true' +REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") +REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379")) +REDIS_DB = int(os.environ.get("REDIS_DB", "0")) +CACHE_ENABLED = os.environ.get("CACHE_ENABLED", "true").lower() == "true" # Initialize Redis client with connection error handling if not CACHE_ENABLED: @@ -40,7 +40,7 @@ else: db=REDIS_DB, decode_responses=True, socket_connect_timeout=5, - socket_timeout=5 + socket_timeout=5, ) # Test connection redis_client.ping() @@ -50,12 +50,16 @@ else: redis_client = None oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper() -PRIVATE_IN_SCHEMA = True if priv_help == 'TRUE' else False +priv_help = ( + False + if not os.environ.get("PRIVATE_IN_SCHEMA") + else os.environ.get("PRIVATE_IN_SCHEMA").upper() +) +PRIVATE_IN_SCHEMA = True if priv_help == "TRUE" else False def valid_token(token): - return token == os.environ.get('API_TOKEN') + return token == os.environ.get("API_TOKEN") def update_season_batting_stats(player_ids, season, db_connection): @@ -63,17 +67,19 @@ def update_season_batting_stats(player_ids, season, db_connection): Update season batting stats for specific players in a given season. Recalculates stats from stratplay data and upserts into seasonbattingstats table. """ - + if not player_ids: logger.warning("update_season_batting_stats called with empty player_ids list") return - + # Convert single player_id to list for consistency if isinstance(player_ids, int): player_ids = [player_ids] - - logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}") - + + logger.info( + f"Updating season batting stats for {len(player_ids)} players in season {season}" + ) + try: # SQL query to recalculate and upsert batting stats query = """ @@ -217,12 +223,14 @@ def update_season_batting_stats(player_ids, season, db_connection): sb = EXCLUDED.sb, cs = EXCLUDED.cs; """ - + # Execute the query with parameters using the passed database connection db_connection.execute_sql(query, [season, player_ids, season, player_ids]) - - logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}") - + + logger.info( + f"Successfully updated season batting stats for {len(player_ids)} players in season {season}" + ) + except Exception as e: logger.error(f"Error updating season batting stats: {e}") raise @@ -233,17 +241,19 @@ def update_season_pitching_stats(player_ids, season, db_connection): Update season pitching stats for specific players in a given season. Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table. """ - + if not player_ids: logger.warning("update_season_pitching_stats called with empty player_ids list") return - + # Convert single player_id to list for consistency if isinstance(player_ids, int): player_ids = [player_ids] - - logger.info(f"Updating season pitching stats for {len(player_ids)} players in season {season}") - + + logger.info( + f"Updating season pitching stats for {len(player_ids)} players in season {season}" + ) + try: # SQL query to recalculate and upsert pitching stats query = """ @@ -357,8 +367,28 @@ def update_season_pitching_stats(player_ids, season, db_connection): WHEN SUM(sp.bb) > 0 THEN ROUND(SUM(sp.so)::DECIMAL / SUM(sp.bb), 2) ELSE 0.0 - END AS kperbb - + END AS kperbb, + + -- Runners left on base when pitcher recorded the 3rd out + SUM(CASE WHEN sp.on_first_final IS NOT NULL AND sp.on_first_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) + + SUM(CASE WHEN sp.on_second_final IS NOT NULL AND sp.on_second_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) + + SUM(CASE WHEN sp.on_third_final IS NOT NULL AND sp.on_third_final != 4 AND sp.starting_outs + sp.outs = 3 THEN 1 ELSE 0 END) AS lob_2outs, + + -- RBI allowed (excluding HR) per runner opportunity + CASE + WHEN (SUM(CASE WHEN sp.on_first IS NOT NULL THEN 1 ELSE 0 END) + + SUM(CASE WHEN sp.on_second IS NOT NULL THEN 1 ELSE 0 END) + + SUM(CASE WHEN sp.on_third IS NOT NULL THEN 1 ELSE 0 END)) > 0 + THEN ROUND( + (SUM(sp.rbi) - SUM(sp.homerun))::DECIMAL / + (SUM(CASE WHEN sp.on_first IS NOT NULL THEN 1 ELSE 0 END) + + SUM(CASE WHEN sp.on_second IS NOT NULL THEN 1 ELSE 0 END) + + SUM(CASE WHEN sp.on_third IS NOT NULL THEN 1 ELSE 0 END)), + 3 + ) + ELSE 0.000 + END AS rbipercent + FROM stratplay sp JOIN stratgame sg ON sg.id = sp.game_id JOIN player p ON p.id = sp.pitcher_id @@ -402,7 +432,7 @@ def update_season_pitching_stats(player_ids, season, db_connection): ps.bphr, ps.bpfo, ps.bp1b, ps.bplo, ps.wp, ps.balk, ps.wpa * -1, ps.era, ps.whip, ps.avg, ps.obp, ps.slg, ps.ops, ps.woba, ps.hper9, ps.kper9, ps.bbper9, ps.kperbb, - 0.0, 0.0, COALESCE(ps.re24 * -1, 0.0) + ps.lob_2outs, ps.rbipercent, COALESCE(ps.re24 * -1, 0.0) FROM pitching_stats ps LEFT JOIN decision_stats ds ON ps.player_id = ds.player_id AND ps.season = ds.season ON CONFLICT (player_id, season) @@ -460,12 +490,14 @@ def update_season_pitching_stats(player_ids, season, db_connection): rbipercent = EXCLUDED.rbipercent, re24 = EXCLUDED.re24; """ - + # Execute the query with parameters using the passed database connection db_connection.execute_sql(query, [season, player_ids, season, player_ids]) - - logger.info(f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}") - + + logger.info( + f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}" + ) + except Exception as e: logger.error(f"Error updating season pitching stats: {e}") raise @@ -474,26 +506,24 @@ def update_season_pitching_stats(player_ids, season, db_connection): def send_webhook_message(message: str) -> bool: """ Send a message to Discord via webhook. - + Args: message: The message content to send - + Returns: bool: True if successful, False otherwise """ webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW" - + try: - payload = { - "content": message - } - + payload = {"content": message} + response = requests.post(webhook_url, json=payload, timeout=10) response.raise_for_status() - + logger.info(f"Webhook message sent successfully: {message[:100]}...") return True - + except requests.exceptions.RequestException as e: logger.error(f"Failed to send webhook message: {e}") return False @@ -502,99 +532,106 @@ def send_webhook_message(message: str) -> bool: return False -def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True): +def cache_result( + ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True +): """ Decorator to cache function results in Redis with parameter normalization. - + Args: ttl: Time to live in seconds (default: 5 minutes) key_prefix: Prefix for cache keys (default: "api") normalize_params: Remove None/empty values to reduce cache variations (default: True) - + Usage: @cache_result(ttl=600, key_prefix="stats") async def get_player_stats(player_id: int, season: Optional[int] = None): # expensive operation return stats - + # These will use the same cache entry when normalize_params=True: # get_player_stats(123, None) and get_player_stats(123) """ + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): # Skip caching if Redis is not available if redis_client is None: return await func(*args, **kwargs) - + try: # Normalize parameters to reduce cache variations normalized_kwargs = kwargs.copy() if normalize_params: # Remove None values and empty collections normalized_kwargs = { - k: v for k, v in kwargs.items() + k: v + for k, v in kwargs.items() if v is not None and v != [] and v != "" and v != {} } - + # Generate more readable cache key args_str = "_".join(str(arg) for arg in args if arg is not None) - kwargs_str = "_".join([ - f"{k}={v}" for k, v in sorted(normalized_kwargs.items()) - ]) - + kwargs_str = "_".join( + [f"{k}={v}" for k, v in sorted(normalized_kwargs.items())] + ) + # Combine args and kwargs for cache key key_parts = [key_prefix, func.__name__] if args_str: key_parts.append(args_str) if kwargs_str: key_parts.append(kwargs_str) - + cache_key = ":".join(key_parts) - + # Truncate very long cache keys to prevent Redis key size limits if len(cache_key) > 200: cache_key = f"{key_prefix}:{func.__name__}:{hash(cache_key)}" - + # Try to get from cache cached_result = redis_client.get(cache_key) if cached_result is not None: logger.debug(f"Cache hit for key: {cache_key}") return json.loads(cached_result) - + # Cache miss - execute function logger.debug(f"Cache miss for key: {cache_key}") result = await func(*args, **kwargs) - + # Skip caching for Response objects (like CSV downloads) as they can't be properly serialized if not isinstance(result, Response): # Store in cache with TTL redis_client.setex( - cache_key, - ttl, - json.dumps(result, default=str, ensure_ascii=False) + cache_key, + ttl, + json.dumps(result, default=str, ensure_ascii=False), ) else: - logger.debug(f"Skipping cache for Response object from {func.__name__}") - + logger.debug( + f"Skipping cache for Response object from {func.__name__}" + ) + return result - + except Exception as e: # If caching fails, log error and continue without caching logger.error(f"Cache error for {func.__name__}: {e}") return await func(*args, **kwargs) - + return wrapper + return decorator def invalidate_cache(pattern: str = "*"): """ Invalidate cache entries matching a pattern. - + Args: pattern: Redis pattern to match keys (default: "*" for all) - + Usage: invalidate_cache("stats:*") # Clear all stats cache invalidate_cache("api:get_player_*") # Clear specific player cache @@ -602,12 +639,14 @@ def invalidate_cache(pattern: str = "*"): if redis_client is None: logger.warning("Cannot invalidate cache: Redis not available") return 0 - + try: keys = redis_client.keys(pattern) if keys: deleted = redis_client.delete(*keys) - logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}") + logger.info( + f"Invalidated {deleted} cache entries matching pattern: {pattern}" + ) return deleted else: logger.debug(f"No cache entries found matching pattern: {pattern}") @@ -620,13 +659,13 @@ def invalidate_cache(pattern: str = "*"): def get_cache_stats() -> dict: """ Get Redis cache statistics. - + Returns: dict: Cache statistics including memory usage, key count, etc. """ if redis_client is None: return {"status": "unavailable", "message": "Redis not connected"} - + try: info = redis_client.info() return { @@ -634,7 +673,7 @@ def get_cache_stats() -> dict: "memory_used": info.get("used_memory_human", "unknown"), "total_keys": redis_client.dbsize(), "connected_clients": info.get("connected_clients", 0), - "uptime_seconds": info.get("uptime_in_seconds", 0) + "uptime_seconds": info.get("uptime_in_seconds", 0), } except Exception as e: logger.error(f"Error getting cache stats: {e}") @@ -642,34 +681,35 @@ def get_cache_stats() -> dict: def add_cache_headers( - max_age: int = 300, + max_age: int = 300, cache_type: str = "public", vary: Optional[str] = None, - etag: bool = False + etag: bool = False, ): """ Decorator to add HTTP cache headers to FastAPI responses. - + Args: max_age: Cache duration in seconds (default: 5 minutes) cache_type: "public", "private", or "no-cache" (default: "public") vary: Vary header value (e.g., "Accept-Encoding, Authorization") etag: Whether to generate ETag based on response content - + Usage: @add_cache_headers(max_age=1800, cache_type="public") async def get_static_data(): return {"data": "static content"} - + @add_cache_headers(max_age=60, cache_type="private", vary="Authorization") async def get_user_data(): return {"data": "user specific"} """ + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): result = await func(*args, **kwargs) - + # Handle different response types if isinstance(result, Response): response = result @@ -677,38 +717,41 @@ def add_cache_headers( # Convert to Response with JSON content response = Response( content=json.dumps(result, default=str, ensure_ascii=False), - media_type="application/json" + media_type="application/json", ) else: # Handle other response types response = Response(content=str(result)) - + # Build Cache-Control header cache_control_parts = [cache_type] if cache_type != "no-cache" and max_age > 0: cache_control_parts.append(f"max-age={max_age}") - + response.headers["Cache-Control"] = ", ".join(cache_control_parts) - + # Add Vary header if specified if vary: response.headers["Vary"] = vary - + # Add ETag if requested - if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))): + if etag and ( + hasattr(result, "__dict__") or isinstance(result, (dict, list)) + ): content_hash = hashlib.md5( json.dumps(result, default=str, sort_keys=True).encode() ).hexdigest() response.headers["ETag"] = f'"{content_hash}"' - + # Add Last-Modified header with current time for dynamic content - response.headers["Last-Modified"] = datetime.datetime.now(datetime.timezone.utc).strftime( - "%a, %d %b %Y %H:%M:%S GMT" - ) - + response.headers["Last-Modified"] = datetime.datetime.now( + datetime.timezone.utc + ).strftime("%a, %d %b %Y %H:%M:%S GMT") + return response - + return wrapper + return decorator @@ -718,52 +761,59 @@ def handle_db_errors(func): Ensures proper cleanup of database connections and provides consistent error handling. Includes comprehensive logging with function context, timing, and stack traces. """ + @wraps(func) async def wrapper(*args, **kwargs): import time import traceback from .db_engine import db # Import here to avoid circular imports - + start_time = time.time() func_name = f"{func.__module__}.{func.__name__}" - + # Sanitize arguments for logging (exclude sensitive data) safe_args = [] safe_kwargs = {} - + try: # Log sanitized arguments (avoid logging tokens, passwords, etc.) for arg in args: - if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object - safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})") + if hasattr(arg, "__dict__") and hasattr( + arg, "url" + ): # FastAPI Request object + safe_args.append( + f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})" + ) else: safe_args.append(str(arg)[:100]) # Truncate long values - + for key, value in kwargs.items(): - if key.lower() in ['token', 'password', 'secret', 'key']: - safe_kwargs[key] = '[REDACTED]' + if key.lower() in ["token", "password", "secret", "key"]: + safe_kwargs[key] = "[REDACTED]" else: safe_kwargs[key] = str(value)[:100] # Truncate long values - - logger.info(f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}") - + + logger.info( + f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}" + ) + result = await func(*args, **kwargs) - + elapsed_time = time.time() - start_time logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s") - + return result - + except Exception as e: elapsed_time = time.time() - start_time error_trace = traceback.format_exc() - + logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s") logger.error(f"Function args: {safe_args}") logger.error(f"Function kwargs: {safe_kwargs}") logger.error(f"Exception: {str(e)}") logger.error(f"Full traceback:\n{error_trace}") - + try: logger.info(f"Attempting database rollback for {func_name}") db.rollback() @@ -775,8 +825,12 @@ def handle_db_errors(func): db.close() logger.info(f"Database connection closed for {func_name}") except Exception as close_error: - logger.error(f"Error closing database connection in {func_name}: {close_error}") - - raise HTTPException(status_code=500, detail=f'Database error in {func_name}: {str(e)}') - + logger.error( + f"Error closing database connection in {func_name}: {close_error}" + ) + + raise HTTPException( + status_code=500, detail=f"Database error in {func_name}: {str(e)}" + ) + return wrapper diff --git a/app/main.py b/app/main.py index ab04918..3de0bd3 100644 --- a/app/main.py +++ b/app/main.py @@ -10,38 +10,64 @@ from fastapi.openapi.utils import get_openapi # from fastapi.openapi.docs import get_swagger_ui_html # from fastapi.openapi.utils import get_openapi -from .routers_v3 import current, players, results, schedules, standings, teams, transactions, battingstats, pitchingstats, fieldingstats, draftpicks, draftlist, managers, awards, draftdata, keepers, stratgame, stratplay, injuries, decisions, divisions, sbaplayers, custom_commands, help_commands, views +from .routers_v3 import ( + current, + players, + results, + schedules, + standings, + teams, + transactions, + battingstats, + pitchingstats, + fieldingstats, + draftpicks, + draftlist, + managers, + awards, + draftdata, + keepers, + stratgame, + stratplay, + injuries, + decisions, + divisions, + sbaplayers, + custom_commands, + help_commands, + views, +) # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' -log_level = logging.INFO if os.environ.get('LOG_LEVEL') == 'INFO' else logging.WARNING +log_level = logging.INFO if os.environ.get("LOG_LEVEL") == "INFO" else logging.WARNING # logging.basicConfig( # filename=f'logs/database/{date}.log', # format='%(asctime)s - sba-database - %(levelname)s - %(message)s', # level=log_level # ) -logger = logging.getLogger('discord_app') +logger = logging.getLogger("discord_app") logger.setLevel(log_level) handler = RotatingFileHandler( - filename='./logs/sba-database.log', + filename="./logs/sba-database.log", # encoding='utf-8', maxBytes=8 * 1024 * 1024, # 8 MiB backupCount=5, # Rotate through 5 files ) -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) app = FastAPI( # root_path='/api', - responses={404: {'description': 'Not found'}}, - docs_url='/api/docs', - redoc_url='/api/redoc' + responses={404: {"description": "Not found"}}, + docs_url="/api/docs", + redoc_url="/api/redoc", ) -logger.info(f'Starting up now...') +logger.info(f"Starting up now...") app.include_router(current.router) @@ -70,18 +96,20 @@ app.include_router(custom_commands.router) app.include_router(help_commands.router) app.include_router(views.router) -logger.info(f'Loaded all routers.') +logger.info(f"Loaded all routers.") @app.get("/api/docs", include_in_schema=False) async def get_docs(req: Request): - print(req.scope) - return get_swagger_ui_html(openapi_url=req.scope.get('root_path')+'/openapi.json', title='Swagger') + logger.debug(req.scope) + return get_swagger_ui_html( + openapi_url=req.scope.get("root_path") + "/openapi.json", title="Swagger" + ) @app.get("/api/openapi.json", include_in_schema=False) async def openapi(): - return get_openapi(title='SBa API Docs', version=f'0.1.1', routes=app.routes) + return get_openapi(title="SBa API Docs", version=f"0.1.1", routes=app.routes) # @app.get("/api") diff --git a/app/routers_v3/battingstats.py b/app/routers_v3/battingstats.py index 68df4d5..49d3fa8 100644 --- a/app/routers_v3/battingstats.py +++ b/app/routers_v3/battingstats.py @@ -381,14 +381,25 @@ async def post_batstats(s_list: BatStatList, token: str = Depends(oauth2_scheme) all_stats = [] + all_team_ids = list(set(x.team_id for x in s_list.stats)) + all_player_ids = list(set(x.player_id for x in s_list.stats)) + found_team_ids = ( + set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids)) + if all_team_ids + else set() + ) + found_player_ids = ( + set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids)) + if all_player_ids + else set() + ) + for x in s_list.stats: - team = Team.get_or_none(Team.id == x.team_id) - this_player = Player.get_or_none(Player.id == x.player_id) - if team is None: + if x.team_id not in found_team_ids: raise HTTPException( status_code=404, detail=f"Team ID {x.team_id} not found" ) - if this_player is None: + if x.player_id not in found_player_ids: raise HTTPException( status_code=404, detail=f"Player ID {x.player_id} not found" ) diff --git a/app/routers_v3/custom_commands.py b/app/routers_v3/custom_commands.py index bbe3077..577e78d 100644 --- a/app/routers_v3/custom_commands.py +++ b/app/routers_v3/custom_commands.py @@ -296,9 +296,8 @@ async def get_custom_commands( if command_dict.get("tags"): try: command_dict["tags"] = json.loads(command_dict["tags"]) - except: + except Exception: command_dict["tags"] = [] - # Get full creator information creator_id = command_dict["creator_id"] creator_cursor = db.execute_sql( @@ -406,7 +405,7 @@ async def create_custom_command_endpoint( if command_dict.get("tags"): try: command_dict["tags"] = json.loads(command_dict["tags"]) - except: + except Exception: command_dict["tags"] = [] creator_created_at = command_dict.pop("creator_created_at") @@ -467,7 +466,7 @@ async def update_custom_command_endpoint( if command_dict.get("tags"): try: command_dict["tags"] = json.loads(command_dict["tags"]) - except: + except Exception: command_dict["tags"] = [] creator_created_at = command_dict.pop("creator_created_at") @@ -552,7 +551,7 @@ async def patch_custom_command( if command_dict.get("tags"): try: command_dict["tags"] = json.loads(command_dict["tags"]) - except: + except Exception: command_dict["tags"] = [] creator_created_at = command_dict.pop("creator_created_at") @@ -781,7 +780,7 @@ async def get_custom_command_stats(): if command_dict.get("tags"): try: command_dict["tags"] = json.loads(command_dict["tags"]) - except: + except Exception: command_dict["tags"] = [] command_dict["creator"] = { "discord_id": command_dict.pop("creator_discord_id"), @@ -881,7 +880,7 @@ async def get_custom_command_by_name_endpoint(command_name: str): if command_dict.get("tags"): try: command_dict["tags"] = json.loads(command_dict["tags"]) - except: + except Exception: command_dict["tags"] = [] # Add creator info - get full creator record @@ -966,7 +965,7 @@ async def execute_custom_command( if updated_dict.get("tags"): try: updated_dict["tags"] = json.loads(updated_dict["tags"]) - except: + except Exception: updated_dict["tags"] = [] # Build creator object from the fields returned by get_custom_command_by_id @@ -1053,7 +1052,7 @@ async def get_custom_command(command_id: int): if command_dict.get("tags"): try: command_dict["tags"] = json.loads(command_dict["tags"]) - except: + except Exception: command_dict["tags"] = [] creator_created_at = command_dict.pop("creator_created_at") diff --git a/app/routers_v3/pitchingstats.py b/app/routers_v3/pitchingstats.py index cf3357d..d318013 100644 --- a/app/routers_v3/pitchingstats.py +++ b/app/routers_v3/pitchingstats.py @@ -1,6 +1,3 @@ -import datetime -import os - from fastapi import APIRouter, Depends, HTTPException, Query from typing import List, Optional, Literal import logging diff --git a/app/routers_v3/results.py b/app/routers_v3/results.py index 76bd440..7ba46b8 100644 --- a/app/routers_v3/results.py +++ b/app/routers_v3/results.py @@ -159,12 +159,23 @@ async def post_results(result_list: ResultList, token: str = Depends(oauth2_sche raise HTTPException(status_code=401, detail="Unauthorized") new_results = [] + + all_team_ids = list( + set(x.awayteam_id for x in result_list.results) + | set(x.hometeam_id for x in result_list.results) + ) + found_team_ids = ( + set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids)) + if all_team_ids + else set() + ) + for x in result_list.results: - if Team.get_or_none(Team.id == x.awayteam_id) is None: + if x.awayteam_id not in found_team_ids: raise HTTPException( status_code=404, detail=f"Team ID {x.awayteam_id} not found" ) - if Team.get_or_none(Team.id == x.hometeam_id) is None: + if x.hometeam_id not in found_team_ids: raise HTTPException( status_code=404, detail=f"Team ID {x.hometeam_id} not found" ) diff --git a/app/routers_v3/schedules.py b/app/routers_v3/schedules.py index a7bca82..afcaabf 100644 --- a/app/routers_v3/schedules.py +++ b/app/routers_v3/schedules.py @@ -144,12 +144,23 @@ async def post_schedules(sched_list: ScheduleList, token: str = Depends(oauth2_s raise HTTPException(status_code=401, detail="Unauthorized") new_sched = [] + + all_team_ids = list( + set(x.awayteam_id for x in sched_list.schedules) + | set(x.hometeam_id for x in sched_list.schedules) + ) + found_team_ids = ( + set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids)) + if all_team_ids + else set() + ) + for x in sched_list.schedules: - if Team.get_or_none(Team.id == x.awayteam_id) is None: + if x.awayteam_id not in found_team_ids: raise HTTPException( status_code=404, detail=f"Team ID {x.awayteam_id} not found" ) - if Team.get_or_none(Team.id == x.hometeam_id) is None: + if x.hometeam_id not in found_team_ids: raise HTTPException( status_code=404, detail=f"Team ID {x.hometeam_id} not found" ) diff --git a/app/routers_v3/standings.py b/app/routers_v3/standings.py index b59dc26..f5ef37e 100644 --- a/app/routers_v3/standings.py +++ b/app/routers_v3/standings.py @@ -1,24 +1,29 @@ from fastapi import APIRouter, Depends, HTTPException, Query from typing import List, Optional import logging -import pydantic from ..db_engine import db, Standings, Team, Division, model_to_dict, chunked, fn -from ..dependencies import oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors - -logger = logging.getLogger('discord_app') - -router = APIRouter( - prefix='/api/v3/standings', - tags=['standings'] +from ..dependencies import ( + oauth2_scheme, + valid_token, + PRIVATE_IN_SCHEMA, + handle_db_errors, ) +logger = logging.getLogger("discord_app") -@router.get('') +router = APIRouter(prefix="/api/v3/standings", tags=["standings"]) + + +@router.get("") @handle_db_errors async def get_standings( - season: int, team_id: list = Query(default=None), league_abbrev: Optional[str] = None, - division_abbrev: Optional[str] = None, short_output: Optional[bool] = False): + season: int, + team_id: list = Query(default=None), + league_abbrev: Optional[str] = None, + division_abbrev: Optional[str] = None, + short_output: Optional[bool] = False, +): standings = Standings.select_season(season) # if standings.count() == 0: @@ -30,55 +35,66 @@ async def get_standings( standings = standings.where(Standings.team << t_query) if league_abbrev is not None: - l_query = Division.select().where(fn.Lower(Division.league_abbrev) == league_abbrev.lower()) + l_query = Division.select().where( + fn.Lower(Division.league_abbrev) == league_abbrev.lower() + ) standings = standings.where(Standings.team.division << l_query) if division_abbrev is not None: - d_query = Division.select().where(fn.Lower(Division.division_abbrev) == division_abbrev.lower()) + d_query = Division.select().where( + fn.Lower(Division.division_abbrev) == division_abbrev.lower() + ) standings = standings.where(Standings.team.division << d_query) def win_pct(this_team_stan): if this_team_stan.wins + this_team_stan.losses == 0: return 0 else: - return (this_team_stan.wins / (this_team_stan.wins + this_team_stan.losses)) + \ - (this_team_stan.run_diff * .000001) + return ( + this_team_stan.wins / (this_team_stan.wins + this_team_stan.losses) + ) + (this_team_stan.run_diff * 0.000001) div_teams = [x for x in standings] div_teams.sort(key=lambda team: win_pct(team), reverse=True) return_standings = { - 'count': len(div_teams), - 'standings': [model_to_dict(x, recurse=not short_output) for x in div_teams] + "count": len(div_teams), + "standings": [model_to_dict(x, recurse=not short_output) for x in div_teams], } db.close() return return_standings -@router.get('/team/{team_id}') +@router.get("/team/{team_id}") @handle_db_errors async def get_team_standings(team_id: int): this_stan = Standings.get_or_none(Standings.team_id == team_id) if this_stan is None: - raise HTTPException(status_code=404, detail=f'No standings found for team id {team_id}') + raise HTTPException( + status_code=404, detail=f"No standings found for team id {team_id}" + ) return model_to_dict(this_stan) -@router.patch('/{stan_id}', include_in_schema=PRIVATE_IN_SCHEMA) +@router.patch("/{stan_id}", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def patch_standings( - stan_id, wins: Optional[int] = None, losses: Optional[int] = None, token: str = Depends(oauth2_scheme)): + stan_id, + wins: Optional[int] = None, + losses: Optional[int] = None, + token: str = Depends(oauth2_scheme), +): if not valid_token(token): - logger.warning(f'patch_standings - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"patch_standings - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") try: this_stan = Standings.get_by_id(stan_id) except Exception as e: db.close() - raise HTTPException(status_code=404, detail=f'No team found with id {stan_id}') + raise HTTPException(status_code=404, detail=f"No team found with id {stan_id}") if wins: this_stan.wins = wins @@ -91,35 +107,35 @@ async def patch_standings( return model_to_dict(this_stan) -@router.post('/s{season}/new', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("/s{season}/new", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def post_standings(season: int, token: str = Depends(oauth2_scheme)): if not valid_token(token): - logger.warning(f'post_standings - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"post_standings - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") new_teams = [] all_teams = Team.select().where(Team.season == season) for x in all_teams: - new_teams.append(Standings({'team_id': x.id})) - + new_teams.append(Standings({"team_id": x.id})) + with db.atomic(): for batch in chunked(new_teams, 16): Standings.insert_many(batch).on_conflict_ignore().execute() db.close() - return f'Inserted {len(new_teams)} standings' + return f"Inserted {len(new_teams)} standings" -@router.post('/s{season}/recalculate', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("/s{season}/recalculate", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def recalculate_standings(season: int, token: str = Depends(oauth2_scheme)): if not valid_token(token): - logger.warning(f'recalculate_standings - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"recalculate_standings - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") code = Standings.recalculate(season) db.close() if code == 69: - raise HTTPException(status_code=500, detail=f'Error recreating Standings rows') - return f'Just recalculated standings for season {season}' + raise HTTPException(status_code=500, detail=f"Error recreating Standings rows") + return f"Just recalculated standings for season {season}" diff --git a/app/routers_v3/transactions.py b/app/routers_v3/transactions.py index 4d11e3a..1880dcc 100644 --- a/app/routers_v3/transactions.py +++ b/app/routers_v3/transactions.py @@ -143,16 +143,31 @@ async def post_transactions( all_moves = [] + all_team_ids = list( + set(x.oldteam_id for x in moves.moves) | set(x.newteam_id for x in moves.moves) + ) + all_player_ids = list(set(x.player_id for x in moves.moves)) + found_team_ids = ( + set(t.id for t in Team.select(Team.id).where(Team.id << all_team_ids)) + if all_team_ids + else set() + ) + found_player_ids = ( + set(p.id for p in Player.select(Player.id).where(Player.id << all_player_ids)) + if all_player_ids + else set() + ) + for x in moves.moves: - if Team.get_or_none(Team.id == x.oldteam_id) is None: + if x.oldteam_id not in found_team_ids: raise HTTPException( status_code=404, detail=f"Team ID {x.oldteam_id} not found" ) - if Team.get_or_none(Team.id == x.newteam_id) is None: + if x.newteam_id not in found_team_ids: raise HTTPException( status_code=404, detail=f"Team ID {x.newteam_id} not found" ) - if Player.get_or_none(Player.id == x.player_id) is None: + if x.player_id not in found_player_ids: raise HTTPException( status_code=404, detail=f"Player ID {x.player_id} not found" ) diff --git a/app/routers_v3/views.py b/app/routers_v3/views.py index ba77a0d..c658262 100644 --- a/app/routers_v3/views.py +++ b/app/routers_v3/views.py @@ -3,239 +3,331 @@ from typing import List, Literal, Optional import logging import pydantic -from ..db_engine import SeasonBattingStats, SeasonPitchingStats, db, Manager, Team, Current, model_to_dict, fn, query_to_csv, StratPlay, StratGame -from ..dependencies import add_cache_headers, cache_result, oauth2_scheme, valid_token, PRIVATE_IN_SCHEMA, handle_db_errors, update_season_batting_stats, update_season_pitching_stats, get_cache_stats - -logger = logging.getLogger('discord_app') - -router = APIRouter( - prefix='/api/v3/views', - tags=['views'] +from ..db_engine import ( + SeasonBattingStats, + SeasonPitchingStats, + db, + Manager, + Team, + Current, + model_to_dict, + fn, + query_to_csv, + StratPlay, + StratGame, +) +from ..dependencies import ( + add_cache_headers, + cache_result, + oauth2_scheme, + valid_token, + PRIVATE_IN_SCHEMA, + handle_db_errors, + update_season_batting_stats, + update_season_pitching_stats, + get_cache_stats, ) -@router.get('/season-stats/batting') +logger = logging.getLogger("discord_app") + +router = APIRouter(prefix="/api/v3/views", tags=["views"]) + + +@router.get("/season-stats/batting") @handle_db_errors -@add_cache_headers(max_age=10*60) -@cache_result(ttl=5*60, key_prefix='season-batting') +@add_cache_headers(max_age=10 * 60) +@cache_result(ttl=5 * 60, key_prefix="season-batting") async def get_season_batting_stats( season: Optional[int] = None, team_id: Optional[int] = None, player_id: Optional[int] = None, sbaplayer_id: Optional[int] = None, min_pa: Optional[int] = None, # Minimum plate appearances - sort_by: str = "woba", # Default sort field - sort_order: Literal['asc', 'desc'] = 'desc', # asc or desc + sort_by: Literal[ + "pa", + "ab", + "run", + "hit", + "double", + "triple", + "homerun", + "rbi", + "bb", + "so", + "bphr", + "bpfo", + "bp1b", + "bplo", + "gidp", + "hbp", + "sac", + "ibb", + "avg", + "obp", + "slg", + "ops", + "woba", + "k_pct", + "sb", + "cs", + ] = "woba", # Sort field + sort_order: Literal["asc", "desc"] = "desc", # asc or desc limit: Optional[int] = 200, offset: int = 0, - csv: Optional[bool] = False + csv: Optional[bool] = False, ): - logger.info(f'Getting season {season} batting stats - team_id: {team_id}, player_id: {player_id}, min_pa: {min_pa}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}') + logger.info( + f"Getting season {season} batting stats - team_id: {team_id}, player_id: {player_id}, min_pa: {min_pa}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}" + ) # Use the enhanced get_top_hitters method query = SeasonBattingStats.get_top_hitters( season=season, stat=sort_by, limit=limit if limit != 0 else None, - desc=(sort_order.lower() == 'desc'), + desc=(sort_order.lower() == "desc"), team_id=team_id, player_id=player_id, sbaplayer_id=sbaplayer_id, min_pa=min_pa, - offset=offset + offset=offset, ) - + # Build applied filters for response applied_filters = {} if season is not None: - applied_filters['season'] = season + applied_filters["season"] = season if team_id is not None: - applied_filters['team_id'] = team_id + applied_filters["team_id"] = team_id if player_id is not None: - applied_filters['player_id'] = player_id + applied_filters["player_id"] = player_id if min_pa is not None: - applied_filters['min_pa'] = min_pa - + applied_filters["min_pa"] = min_pa + if csv: return_val = query_to_csv(query) - return Response(content=return_val, media_type='text/csv') + return Response(content=return_val, media_type="text/csv") else: stat_list = [model_to_dict(stat) for stat in query] - return { - 'count': len(stat_list), - 'filters': applied_filters, - 'stats': stat_list - } + return {"count": len(stat_list), "filters": applied_filters, "stats": stat_list} -@router.post('/season-stats/batting/refresh', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("/season-stats/batting/refresh", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def refresh_season_batting_stats( - season: int, - token: str = Depends(oauth2_scheme) + season: int, token: str = Depends(oauth2_scheme) ) -> dict: """ Refresh batting stats for all players in a specific season. Useful for full season updates. """ if not valid_token(token): - logger.warning(f'refresh_season_batting_stats - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"refresh_season_batting_stats - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") + + logger.info(f"Refreshing all batting stats for season {season}") - logger.info(f'Refreshing all batting stats for season {season}') - try: # Get all player IDs who have stratplay records in this season - batter_ids = [row.batter_id for row in - StratPlay.select(StratPlay.batter_id.distinct()) - .join(StratGame).where(StratGame.season == season)] - + batter_ids = [ + row.batter_id + for row in StratPlay.select(StratPlay.batter_id.distinct()) + .join(StratGame) + .where(StratGame.season == season) + ] + if batter_ids: update_season_batting_stats(batter_ids, season, db) - logger.info(f'Successfully refreshed {len(batter_ids)} players for season {season}') - + logger.info( + f"Successfully refreshed {len(batter_ids)} players for season {season}" + ) + return { - 'message': f'Season {season} batting stats refreshed', - 'players_updated': len(batter_ids) + "message": f"Season {season} batting stats refreshed", + "players_updated": len(batter_ids), } else: - logger.warning(f'No batting data found for season {season}') + logger.warning(f"No batting data found for season {season}") return { - 'message': f'No batting data found for season {season}', - 'players_updated': 0 + "message": f"No batting data found for season {season}", + "players_updated": 0, } - + except Exception as e: - logger.error(f'Error refreshing season {season}: {e}') - raise HTTPException(status_code=500, detail=f'Refresh failed: {str(e)}') + logger.error(f"Error refreshing season {season}: {e}") + raise HTTPException(status_code=500, detail=f"Refresh failed: {str(e)}") -@router.get('/season-stats/pitching') +@router.get("/season-stats/pitching") @handle_db_errors -@add_cache_headers(max_age=10*60) -@cache_result(ttl=5*60, key_prefix='season-pitching') +@add_cache_headers(max_age=10 * 60) +@cache_result(ttl=5 * 60, key_prefix="season-pitching") async def get_season_pitching_stats( season: Optional[int] = None, team_id: Optional[int] = None, player_id: Optional[int] = None, sbaplayer_id: Optional[int] = None, min_outs: Optional[int] = None, # Minimum outs pitched - sort_by: str = "era", # Default sort field - sort_order: Literal['asc', 'desc'] = 'asc', # asc or desc (asc default for ERA) + sort_by: Literal[ + "tbf", + "outs", + "games", + "gs", + "win", + "loss", + "hold", + "saves", + "bsave", + "ir", + "irs", + "ab", + "run", + "e_run", + "hits", + "double", + "triple", + "homerun", + "bb", + "so", + "hbp", + "sac", + "ibb", + "gidp", + "sb", + "cs", + "bphr", + "bpfo", + "bp1b", + "bplo", + "wp", + "balk", + "wpa", + "era", + "whip", + "avg", + "obp", + "slg", + "ops", + "woba", + "hper9", + "kper9", + "bbper9", + "kperbb", + "lob_2outs", + "rbipercent", + "re24", + ] = "era", # Sort field + sort_order: Literal["asc", "desc"] = "asc", # asc or desc (asc default for ERA) limit: Optional[int] = 200, offset: int = 0, - csv: Optional[bool] = False + csv: Optional[bool] = False, ): - logger.info(f'Getting season {season} pitching stats - team_id: {team_id}, player_id: {player_id}, min_outs: {min_outs}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}') + logger.info( + f"Getting season {season} pitching stats - team_id: {team_id}, player_id: {player_id}, min_outs: {min_outs}, sort_by: {sort_by}, sort_order: {sort_order}, limit: {limit}, offset: {offset}" + ) # Use the get_top_pitchers method query = SeasonPitchingStats.get_top_pitchers( season=season, stat=sort_by, limit=limit if limit != 0 else None, - desc=(sort_order.lower() == 'desc'), + desc=(sort_order.lower() == "desc"), team_id=team_id, player_id=player_id, sbaplayer_id=sbaplayer_id, min_outs=min_outs, - offset=offset + offset=offset, ) - + # Build applied filters for response applied_filters = {} if season is not None: - applied_filters['season'] = season + applied_filters["season"] = season if team_id is not None: - applied_filters['team_id'] = team_id + applied_filters["team_id"] = team_id if player_id is not None: - applied_filters['player_id'] = player_id + applied_filters["player_id"] = player_id if min_outs is not None: - applied_filters['min_outs'] = min_outs - + applied_filters["min_outs"] = min_outs + if csv: return_val = query_to_csv(query) - return Response(content=return_val, media_type='text/csv') + return Response(content=return_val, media_type="text/csv") else: stat_list = [model_to_dict(stat) for stat in query] - return { - 'count': len(stat_list), - 'filters': applied_filters, - 'stats': stat_list - } + return {"count": len(stat_list), "filters": applied_filters, "stats": stat_list} -@router.post('/season-stats/pitching/refresh', include_in_schema=PRIVATE_IN_SCHEMA) +@router.post("/season-stats/pitching/refresh", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors async def refresh_season_pitching_stats( - season: int, - token: str = Depends(oauth2_scheme) + season: int, token: str = Depends(oauth2_scheme) ) -> dict: """ Refresh pitching statistics for a specific season by aggregating from individual games. Private endpoint - not included in public API documentation. """ if not valid_token(token): - logger.warning(f'refresh_season_batting_stats - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"refresh_season_batting_stats - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") + + logger.info(f"Refreshing season {season} pitching stats") - logger.info(f'Refreshing season {season} pitching stats') - try: # Get all pitcher IDs for this season pitcher_query = ( - StratPlay - .select(StratPlay.pitcher_id) + StratPlay.select(StratPlay.pitcher_id) .join(StratGame, on=(StratPlay.game_id == StratGame.id)) .where((StratGame.season == season) & (StratPlay.pitcher_id.is_null(False))) .distinct() ) pitcher_ids = [row.pitcher_id for row in pitcher_query] - + if not pitcher_ids: - logger.warning(f'No pitchers found for season {season}') + logger.warning(f"No pitchers found for season {season}") return { - 'status': 'success', - 'message': f'No pitchers found for season {season}', - 'players_updated': 0 + "status": "success", + "message": f"No pitchers found for season {season}", + "players_updated": 0, } - + # Use the dependency function to update pitching stats update_season_pitching_stats(pitcher_ids, season, db) - - logger.info(f'Season {season} pitching stats refreshed successfully - {len(pitcher_ids)} players updated') + + logger.info( + f"Season {season} pitching stats refreshed successfully - {len(pitcher_ids)} players updated" + ) return { - 'status': 'success', - 'message': f'Season {season} pitching stats refreshed', - 'players_updated': len(pitcher_ids) + "status": "success", + "message": f"Season {season} pitching stats refreshed", + "players_updated": len(pitcher_ids), } - + except Exception as e: - logger.error(f'Error refreshing season {season} pitching stats: {e}') - raise HTTPException(status_code=500, detail=f'Refresh failed: {str(e)}') + logger.error(f"Error refreshing season {season} pitching stats: {e}") + raise HTTPException(status_code=500, detail=f"Refresh failed: {str(e)}") -@router.get('/admin/cache', include_in_schema=PRIVATE_IN_SCHEMA) +@router.get("/admin/cache", include_in_schema=PRIVATE_IN_SCHEMA) @handle_db_errors -async def get_admin_cache_stats( - token: str = Depends(oauth2_scheme) -) -> dict: +async def get_admin_cache_stats(token: str = Depends(oauth2_scheme)) -> dict: """ Get Redis cache statistics and status. Private endpoint - requires authentication. """ if not valid_token(token): - logger.warning(f'get_admin_cache_stats - Bad Token: {token}') - raise HTTPException(status_code=401, detail='Unauthorized') + logger.warning(f"get_admin_cache_stats - Bad Token: {token}") + raise HTTPException(status_code=401, detail="Unauthorized") + + logger.info("Getting cache statistics") - logger.info('Getting cache statistics') - try: cache_stats = get_cache_stats() - logger.info(f'Cache stats retrieved: {cache_stats}') - return { - 'status': 'success', - 'cache_info': cache_stats - } - + logger.info(f"Cache stats retrieved: {cache_stats}") + return {"status": "success", "cache_info": cache_stats} + except Exception as e: - logger.error(f'Error getting cache stats: {e}') - raise HTTPException(status_code=500, detail=f'Failed to get cache stats: {str(e)}') + logger.error(f"Error getting cache stats: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to get cache stats: {str(e)}" + ) diff --git a/app/services/player_service.py b/app/services/player_service.py index 23ab8d8..bddea6a 100644 --- a/app/services/player_service.py +++ b/app/services/player_service.py @@ -39,7 +39,7 @@ class PlayerService(BaseService): cache_patterns = ["players*", "players-search*", "player*", "team-roster*"] # Deprecated fields to exclude from player responses - EXCLUDED_FIELDS = ['pitcher_injury'] + EXCLUDED_FIELDS = ["pitcher_injury"] # Class-level repository for dependency injection _injected_repo: Optional[AbstractPlayerRepository] = None @@ -135,17 +135,21 @@ class PlayerService(BaseService): # Apply sorting query = cls._apply_player_sort(query, sort) - # Convert to list of dicts - players_data = cls._query_to_player_dicts(query, short_output) - - # Store total count before pagination - total_count = len(players_data) - - # Apply pagination (offset and limit) - if offset is not None: - players_data = players_data[offset:] - if limit is not None: - players_data = players_data[:limit] + # Apply pagination at DB level for real queries, Python level for mocks + if isinstance(query, InMemoryQueryResult): + total_count = len(query) + players_data = cls._query_to_player_dicts(query, short_output) + if offset is not None: + players_data = players_data[offset:] + if limit is not None: + players_data = players_data[:limit] + else: + total_count = query.count() + if offset is not None: + query = query.offset(offset) + if limit is not None: + query = query.limit(limit) + players_data = cls._query_to_player_dicts(query, short_output) # Return format if as_csv: @@ -154,7 +158,7 @@ class PlayerService(BaseService): return { "count": len(players_data), "total": total_count, - "players": players_data + "players": players_data, } except Exception as e: @@ -204,9 +208,9 @@ class PlayerService(BaseService): p_list = [x.upper() for x in pos] # Expand generic "P" to match all pitcher positions - pitcher_positions = ['SP', 'RP', 'CP'] - if 'P' in p_list: - p_list.remove('P') + pitcher_positions = ["SP", "RP", "CP"] + if "P" in p_list: + p_list.remove("P") p_list.extend(pitcher_positions) pos_conditions = ( @@ -245,9 +249,9 @@ class PlayerService(BaseService): p_list = [p.upper() for p in pos] # Expand generic "P" to match all pitcher positions - pitcher_positions = ['SP', 'RP', 'CP'] - if 'P' in p_list: - p_list.remove('P') + pitcher_positions = ["SP", "RP", "CP"] + if "P" in p_list: + p_list.remove("P") p_list.extend(pitcher_positions) player_pos = [ @@ -385,19 +389,23 @@ class PlayerService(BaseService): # This filters at the database level instead of loading all players if search_all_seasons: # Search all seasons, order by season DESC (newest first) - query = (Player.select() - .where(fn.Lower(Player.name).contains(query_lower)) - .order_by(Player.season.desc(), Player.name) - .limit(limit * 2)) # Get extra for exact match sorting + query = ( + Player.select() + .where(fn.Lower(Player.name).contains(query_lower)) + .order_by(Player.season.desc(), Player.name) + .limit(limit * 2) + ) # Get extra for exact match sorting else: # Search specific season - query = (Player.select() - .where( - (Player.season == season) & - (fn.Lower(Player.name).contains(query_lower)) - ) - .order_by(Player.name) - .limit(limit * 2)) # Get extra for exact match sorting + query = ( + Player.select() + .where( + (Player.season == season) + & (fn.Lower(Player.name).contains(query_lower)) + ) + .order_by(Player.name) + .limit(limit * 2) + ) # Get extra for exact match sorting # Execute query and convert limited results to dicts players = list(query) @@ -468,19 +476,29 @@ class PlayerService(BaseService): # Use backrefs=False to avoid circular reference issues player_dict = model_to_dict(player, recurse=recurse, backrefs=False) # Filter out excluded fields - return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS} + return { + k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS + } except (ImportError, AttributeError, TypeError) as e: # Log the error and fall back to non-recursive serialization - logger.warning(f"Error in recursive player serialization: {e}, falling back to non-recursive") + logger.warning( + f"Error in recursive player serialization: {e}, falling back to non-recursive" + ) try: # Fallback to non-recursive serialization player_dict = model_to_dict(player, recurse=False) - return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS} + return { + k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS + } except Exception as fallback_error: # Final fallback to basic dict conversion - logger.error(f"Error in non-recursive serialization: {fallback_error}, using basic dict") + logger.error( + f"Error in non-recursive serialization: {fallback_error}, using basic dict" + ) player_dict = dict(player) - return {k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS} + return { + k: v for k, v in player_dict.items() if k not in cls.EXCLUDED_FIELDS + } @classmethod def update_player( @@ -508,6 +526,8 @@ class PlayerService(BaseService): raise HTTPException( status_code=500, detail=f"Error updating player {player_id}: {str(e)}" ) + finally: + temp_service.invalidate_related_cache(cls.cache_patterns) @classmethod def patch_player( @@ -535,6 +555,8 @@ class PlayerService(BaseService): raise HTTPException( status_code=500, detail=f"Error patching player {player_id}: {str(e)}" ) + finally: + temp_service.invalidate_related_cache(cls.cache_patterns) @classmethod def create_players( @@ -567,6 +589,8 @@ class PlayerService(BaseService): raise HTTPException( status_code=500, detail=f"Error creating players: {str(e)}" ) + finally: + temp_service.invalidate_related_cache(cls.cache_patterns) @classmethod def delete_player(cls, player_id: int, token: str) -> Dict[str, str]: @@ -590,6 +614,8 @@ class PlayerService(BaseService): raise HTTPException( status_code=500, detail=f"Error deleting player {player_id}: {str(e)}" ) + finally: + temp_service.invalidate_related_cache(cls.cache_patterns) @classmethod def _format_player_csv(cls, players: List[Dict]) -> str: @@ -603,12 +629,12 @@ class PlayerService(BaseService): flat_player = player.copy() # Flatten team object to just abbreviation - if isinstance(flat_player.get('team'), dict): - flat_player['team'] = flat_player['team'].get('abbrev', '') + if isinstance(flat_player.get("team"), dict): + flat_player["team"] = flat_player["team"].get("abbrev", "") # Flatten sbaplayer object to just ID - if isinstance(flat_player.get('sbaplayer'), dict): - flat_player['sbaplayer'] = flat_player['sbaplayer'].get('id', '') + if isinstance(flat_player.get("sbaplayer"), dict): + flat_player["sbaplayer"] = flat_player["sbaplayer"].get("id", "") flattened_players.append(flat_player) diff --git a/tests/unit/test_player_service.py b/tests/unit/test_player_service.py index 8a1b00a..6555757 100644 --- a/tests/unit/test_player_service.py +++ b/tests/unit/test_player_service.py @@ -7,21 +7,18 @@ import pytest from unittest.mock import MagicMock, patch import sys import os + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from app.services.player_service import PlayerService from app.services.base import ServiceConfig -from app.services.mocks import ( - MockPlayerRepository, - MockCacheService, - EnhancedMockCache -) - +from app.services.mocks import MockPlayerRepository, MockCacheService, EnhancedMockCache # ============================================================================ # FIXTURES # ============================================================================ + @pytest.fixture def cache(): """Create fresh cache for each test.""" @@ -32,20 +29,73 @@ def cache(): def repo(cache): """Create fresh repo with test data.""" repo = MockPlayerRepository() - + # Add test players players = [ - {'id': 1, 'name': 'Mike Trout', 'wara': 5.2, 'team_id': 1, 'season': 10, 'pos_1': 'CF', 'pos_2': 'LF', 'strat_code': 'Elite', 'injury_rating': 'A'}, - {'id': 2, 'name': 'Aaron Judge', 'wara': 4.8, 'team_id': 2, 'season': 10, 'pos_1': 'RF', 'strat_code': 'Power', 'injury_rating': 'B'}, - {'id': 3, 'name': 'Mookie Betts', 'wara': 5.5, 'team_id': 3, 'season': 10, 'pos_1': 'RF', 'pos_2': '2B', 'strat_code': 'Elite', 'injury_rating': 'A'}, - {'id': 4, 'name': 'Injured Player', 'wara': 2.0, 'team_id': 1, 'season': 10, 'pos_1': 'P', 'il_return': 'Week 5', 'injury_rating': 'C'}, - {'id': 5, 'name': 'Old Player', 'wara': 1.0, 'team_id': 1, 'season': 5, 'pos_1': '1B'}, - {'id': 6, 'name': 'Juan Soto', 'wara': 4.5, 'team_id': 2, 'season': 10, 'pos_1': '1B', 'strat_code': 'Contact'}, + { + "id": 1, + "name": "Mike Trout", + "wara": 5.2, + "team_id": 1, + "season": 10, + "pos_1": "CF", + "pos_2": "LF", + "strat_code": "Elite", + "injury_rating": "A", + }, + { + "id": 2, + "name": "Aaron Judge", + "wara": 4.8, + "team_id": 2, + "season": 10, + "pos_1": "RF", + "strat_code": "Power", + "injury_rating": "B", + }, + { + "id": 3, + "name": "Mookie Betts", + "wara": 5.5, + "team_id": 3, + "season": 10, + "pos_1": "RF", + "pos_2": "2B", + "strat_code": "Elite", + "injury_rating": "A", + }, + { + "id": 4, + "name": "Injured Player", + "wara": 2.0, + "team_id": 1, + "season": 10, + "pos_1": "P", + "il_return": "Week 5", + "injury_rating": "C", + }, + { + "id": 5, + "name": "Old Player", + "wara": 1.0, + "team_id": 1, + "season": 5, + "pos_1": "1B", + }, + { + "id": 6, + "name": "Juan Soto", + "wara": 4.5, + "team_id": 2, + "season": 10, + "pos_1": "1B", + "strat_code": "Contact", + }, ] - + for player in players: repo.add_player(player) - + return repo @@ -60,463 +110,453 @@ def service(repo, cache): # TEST CLASSES # ============================================================================ + class TestPlayerServiceGetPlayers: """Tests for get_players method - 50+ lines covered.""" - + def test_get_all_season_players(self, service, repo): """Get all players for a season.""" result = service.get_players(season=10) - - assert result['count'] >= 5 # We have 5 season 10 players - assert len(result['players']) >= 5 - assert all(p.get('season') == 10 for p in result['players']) - + + assert result["count"] >= 5 # We have 5 season 10 players + assert len(result["players"]) >= 5 + assert all(p.get("season") == 10 for p in result["players"]) + def test_filter_by_single_team(self, service): """Filter by single team ID.""" result = service.get_players(season=10, team_id=[1]) - - assert result['count'] >= 1 - assert all(p.get('team_id') == 1 for p in result['players']) - + + assert result["count"] >= 1 + assert all(p.get("team_id") == 1 for p in result["players"]) + def test_filter_by_multiple_teams(self, service): """Filter by multiple team IDs.""" result = service.get_players(season=10, team_id=[1, 2]) - - assert result['count'] >= 2 - assert all(p.get('team_id') in [1, 2] for p in result['players']) - + + assert result["count"] >= 2 + assert all(p.get("team_id") in [1, 2] for p in result["players"]) + def test_filter_by_position(self, service): """Filter by position.""" - result = service.get_players(season=10, pos=['CF']) - - assert result['count'] >= 1 - assert any(p.get('pos_1') == 'CF' or p.get('pos_2') == 'CF' for p in result['players']) - + result = service.get_players(season=10, pos=["CF"]) + + assert result["count"] >= 1 + assert any( + p.get("pos_1") == "CF" or p.get("pos_2") == "CF" for p in result["players"] + ) + def test_filter_by_strat_code(self, service): """Filter by strat code.""" - result = service.get_players(season=10, strat_code=['Elite']) - - assert result['count'] >= 2 # Trout and Betts - assert all('Elite' in str(p.get('strat_code', '')) for p in result['players']) - + result = service.get_players(season=10, strat_code=["Elite"]) + + assert result["count"] >= 2 # Trout and Betts + assert all("Elite" in str(p.get("strat_code", "")) for p in result["players"]) + def test_filter_injured_only(self, service): """Filter injured players only.""" result = service.get_players(season=10, is_injured=True) - - assert result['count'] >= 1 - assert all(p.get('il_return') is not None for p in result['players']) - + + assert result["count"] >= 1 + assert all(p.get("il_return") is not None for p in result["players"]) + def test_sort_cost_ascending(self, service): """Sort by WARA ascending.""" - result = service.get_players(season=10, sort='cost-asc') - - wara = [p.get('wara', 0) for p in result['players']] + result = service.get_players(season=10, sort="cost-asc") + + wara = [p.get("wara", 0) for p in result["players"]] assert wara == sorted(wara) - + def test_sort_cost_descending(self, service): """Sort by WARA descending.""" - result = service.get_players(season=10, sort='cost-desc') - - wara = [p.get('wara', 0) for p in result['players']] + result = service.get_players(season=10, sort="cost-desc") + + wara = [p.get("wara", 0) for p in result["players"]] assert wara == sorted(wara, reverse=True) - + def test_sort_name_ascending(self, service): """Sort by name ascending.""" - result = service.get_players(season=10, sort='name-asc') - - names = [p.get('name', '') for p in result['players']] + result = service.get_players(season=10, sort="name-asc") + + names = [p.get("name", "") for p in result["players"]] assert names == sorted(names) - + def test_sort_name_descending(self, service): """Sort by name descending.""" - result = service.get_players(season=10, sort='name-desc') - - names = [p.get('name', '') for p in result['players']] + result = service.get_players(season=10, sort="name-desc") + + names = [p.get("name", "") for p in result["players"]] assert names == sorted(names, reverse=True) class TestPlayerServiceSearch: """Tests for search_players method.""" - + def test_exact_name_match(self, service): """Search with exact name match.""" - result = service.search_players('Mike Trout', season=10) - - assert result['count'] >= 1 - names = [p.get('name') for p in result['players']] - assert 'Mike Trout' in names - + result = service.search_players("Mike Trout", season=10) + + assert result["count"] >= 1 + names = [p.get("name") for p in result["players"]] + assert "Mike Trout" in names + def test_partial_name_match(self, service): """Search with partial name match.""" - result = service.search_players('Trout', season=10) - - assert result['count'] >= 1 - assert any('Trout' in p.get('name', '') for p in result['players']) - + result = service.search_players("Trout", season=10) + + assert result["count"] >= 1 + assert any("Trout" in p.get("name", "") for p in result["players"]) + def test_case_insensitive_search(self, service): """Search is case insensitive.""" - result1 = service.search_players('MIKE', season=10) - result2 = service.search_players('mike', season=10) - - assert result1['count'] == result2['count'] - + result1 = service.search_players("MIKE", season=10) + result2 = service.search_players("mike", season=10) + + assert result1["count"] == result2["count"] + def test_search_all_seasons(self, service): """Search across all seasons.""" - result = service.search_players('Player', season=None) - + result = service.search_players("Player", season=None) + # Should find both current and old players - assert result['all_seasons'] == True - assert result['count'] >= 2 - + assert result["all_seasons"] == True + assert result["count"] >= 2 + def test_search_limit(self, service): """Limit search results.""" - result = service.search_players('a', season=10, limit=2) - - assert result['count'] <= 2 - + result = service.search_players("a", season=10, limit=2) + + assert result["count"] <= 2 + def test_search_no_results(self, service): """Search returns empty when no matches.""" - result = service.search_players('XYZ123NotExist', season=10) - - assert result['count'] == 0 - assert result['players'] == [] + result = service.search_players("XYZ123NotExist", season=10) + + assert result["count"] == 0 + assert result["players"] == [] class TestPlayerServiceGetPlayer: """Tests for get_player method.""" - + def test_get_existing_player(self, service): """Get existing player by ID.""" result = service.get_player(1) - + assert result is not None - assert result.get('id') == 1 - assert result.get('name') == 'Mike Trout' - + assert result.get("id") == 1 + assert result.get("name") == "Mike Trout" + def test_get_nonexistent_player(self, service): """Get player that doesn't exist.""" result = service.get_player(99999) - + assert result is None - + def test_get_player_short_output(self, service): """Get player with short output.""" result = service.get_player(1, short_output=True) - + # Should still have basic fields - assert result.get('id') == 1 - assert result.get('name') == 'Mike Trout' + assert result.get("id") == 1 + assert result.get("name") == "Mike Trout" class TestPlayerServiceCreate: """Tests for create_players method.""" - + def test_create_single_player(self, repo, cache): """Create a single new player.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - - new_player = [{ - 'name': 'New Player', - 'wara': 3.0, - 'team_id': 1, - 'season': 10, - 'pos_1': 'SS' - }] - + + new_player = [ + { + "name": "New Player", + "wara": 3.0, + "team_id": 1, + "season": 10, + "pos_1": "SS", + } + ] + # Mock auth - with patch.object(service, 'require_auth', return_value=True): - result = service.create_players(new_player, 'valid_token') - - assert 'Inserted' in str(result) + with patch.object(service, "require_auth", return_value=True): + result = service.create_players(new_player, "valid_token") + + assert "Inserted" in str(result) # Verify player was added (ID 7 since fixture has players 1-6) player = repo.get_by_id(7) # Next ID after fixture data assert player is not None - assert player['name'] == 'New Player' - + assert player["name"] == "New Player" + def test_create_multiple_players(self, repo, cache): """Create multiple new players.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - + new_players = [ - {'name': 'Player A', 'wara': 2.0, 'team_id': 1, 'season': 10, 'pos_1': '2B'}, - {'name': 'Player B', 'wara': 2.5, 'team_id': 2, 'season': 10, 'pos_1': '3B'}, + { + "name": "Player A", + "wara": 2.0, + "team_id": 1, + "season": 10, + "pos_1": "2B", + }, + { + "name": "Player B", + "wara": 2.5, + "team_id": 2, + "season": 10, + "pos_1": "3B", + }, ] - - with patch.object(service, 'require_auth', return_value=True): - result = service.create_players(new_players, 'valid_token') - - assert 'Inserted 2 players' in str(result) - + + with patch.object(service, "require_auth", return_value=True): + result = service.create_players(new_players, "valid_token") + + assert "Inserted 2 players" in str(result) + def test_create_duplicate_fails(self, repo, cache): """Creating duplicate player should fail.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - - duplicate = [{'name': 'Mike Trout', 'wara': 5.0, 'team_id': 1, 'season': 10, 'pos_1': 'CF'}] - - with patch.object(service, 'require_auth', return_value=True): + + duplicate = [ + { + "name": "Mike Trout", + "wara": 5.0, + "team_id": 1, + "season": 10, + "pos_1": "CF", + } + ] + + with patch.object(service, "require_auth", return_value=True): with pytest.raises(Exception) as exc_info: - service.create_players(duplicate, 'valid_token') - - assert 'already exists' in str(exc_info.value) - + service.create_players(duplicate, "valid_token") + + assert "already exists" in str(exc_info.value) + def test_create_requires_auth(self, repo, cache): """Creating players requires authentication.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - - new_player = [{'name': 'Test', 'wara': 1.0, 'team_id': 1, 'season': 10, 'pos_1': 'P'}] - + + new_player = [ + {"name": "Test", "wara": 1.0, "team_id": 1, "season": 10, "pos_1": "P"} + ] + with pytest.raises(Exception) as exc_info: - service.create_players(new_player, 'bad_token') - + service.create_players(new_player, "bad_token") + assert exc_info.value.status_code == 401 class TestPlayerServiceUpdate: """Tests for update_player and patch_player methods.""" - + def test_patch_player_name(self, repo, cache): """Patch player's name.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - - with patch.object(service, 'require_auth', return_value=True): - result = service.patch_player(1, {'name': 'New Name'}, 'valid_token') - + + with patch.object(service, "require_auth", return_value=True): + result = service.patch_player(1, {"name": "New Name"}, "valid_token") + assert result is not None - assert result.get('name') == 'New Name' - + assert result.get("name") == "New Name" + def test_patch_player_wara(self, repo, cache): """Patch player's WARA.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - - with patch.object(service, 'require_auth', return_value=True): - result = service.patch_player(1, {'wara': 6.0}, 'valid_token') - - assert result.get('wara') == 6.0 - + + with patch.object(service, "require_auth", return_value=True): + result = service.patch_player(1, {"wara": 6.0}, "valid_token") + + assert result.get("wara") == 6.0 + def test_patch_multiple_fields(self, repo, cache): """Patch multiple fields at once.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - - updates = { - 'name': 'Updated Name', - 'wara': 7.0, - 'strat_code': 'Super Elite' - } - - with patch.object(service, 'require_auth', return_value=True): - result = service.patch_player(1, updates, 'valid_token') - - assert result.get('name') == 'Updated Name' - assert result.get('wara') == 7.0 - assert result.get('strat_code') == 'Super Elite' - + + updates = {"name": "Updated Name", "wara": 7.0, "strat_code": "Super Elite"} + + with patch.object(service, "require_auth", return_value=True): + result = service.patch_player(1, updates, "valid_token") + + assert result.get("name") == "Updated Name" + assert result.get("wara") == 7.0 + assert result.get("strat_code") == "Super Elite" + def test_patch_nonexistent_player(self, repo, cache): """Patch fails for non-existent player.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - - with patch.object(service, 'require_auth', return_value=True): + + with patch.object(service, "require_auth", return_value=True): with pytest.raises(Exception) as exc_info: - service.patch_player(99999, {'name': 'Test'}, 'valid_token') - - assert 'not found' in str(exc_info.value) - + service.patch_player(99999, {"name": "Test"}, "valid_token") + + assert "not found" in str(exc_info.value) + def test_patch_requires_auth(self, repo, cache): """Patching requires authentication.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - + with pytest.raises(Exception) as exc_info: - service.patch_player(1, {'name': 'Test'}, 'bad_token') - + service.patch_player(1, {"name": "Test"}, "bad_token") + assert exc_info.value.status_code == 401 class TestPlayerServiceDelete: """Tests for delete_player method.""" - + def test_delete_player(self, repo, cache): """Delete existing player.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - + # Verify player exists assert repo.get_by_id(1) is not None - - with patch.object(service, 'require_auth', return_value=True): - result = service.delete_player(1, 'valid_token') - - assert 'deleted' in str(result) - + + with patch.object(service, "require_auth", return_value=True): + result = service.delete_player(1, "valid_token") + + assert "deleted" in str(result) + # Verify player is gone assert repo.get_by_id(1) is None - + def test_delete_nonexistent_player(self, repo, cache): """Delete fails for non-existent player.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - - with patch.object(service, 'require_auth', return_value=True): + + with patch.object(service, "require_auth", return_value=True): with pytest.raises(Exception) as exc_info: - service.delete_player(99999, 'valid_token') - - assert 'not found' in str(exc_info.value) - + service.delete_player(99999, "valid_token") + + assert "not found" in str(exc_info.value) + def test_delete_requires_auth(self, repo, cache): """Deleting requires authentication.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - + with pytest.raises(Exception) as exc_info: - service.delete_player(1, 'bad_token') - + service.delete_player(1, "bad_token") + assert exc_info.value.status_code == 401 -class TestPlayerServiceCache: - """Tests for cache functionality.""" - - @pytest.mark.skip(reason="Caching not yet implemented in service methods") - def test_cache_set_on_read(self, service, cache): - """Cache is set on player read.""" - service.get_players(season=10) - - assert cache.was_called('set') - - @pytest.mark.skip(reason="Caching not yet implemented in service methods") - def test_cache_invalidation_on_update(self, repo, cache): - """Cache is invalidated on player update.""" - config = ServiceConfig(player_repo=repo, cache=cache) - service = PlayerService(config=config) - - # Read to set cache - service.get_players(season=10) - initial_calls = len(cache.get_calls('set')) - - # Update should invalidate cache - with patch.object(service, 'require_auth', return_value=True): - service.patch_player(1, {'name': 'Test'}, 'valid_token') - - # Should have more delete calls after update - delete_calls = [c for c in cache.get_calls() if c.get('method') == 'delete'] - assert len(delete_calls) > 0 - - @pytest.mark.skip(reason="Caching not yet implemented in service methods") - def test_cache_hit_rate(self, repo, cache): - """Test cache hit rate tracking.""" - config = ServiceConfig(player_repo=repo, cache=cache) - service = PlayerService(config=config) - - # First call - cache miss - service.get_players(season=10) - miss_count = cache._miss_count - - # Second call - cache hit - service.get_players(season=10) - - # Hit rate should have improved - assert cache.hit_rate > 0 - - class TestPlayerServiceValidation: """Tests for input validation and edge cases.""" - + def test_invalid_season_returns_empty(self, service): """Invalid season returns empty result.""" result = service.get_players(season=999) - - assert result['count'] == 0 or result['players'] == [] - + + assert result["count"] == 0 or result["players"] == [] + def test_empty_search_returns_all(self, service): """Empty search query returns all players.""" - result = service.search_players('', season=10) - - assert result['count'] >= 1 - + result = service.search_players("", season=10) + + assert result["count"] >= 1 + def test_sort_with_no_results(self, service): """Sorting with no results doesn't error.""" - result = service.get_players(season=999, sort='cost-desc') - - assert result['count'] == 0 or result['players'] == [] - + result = service.get_players(season=999, sort="cost-desc") + + assert result["count"] == 0 or result["players"] == [] + def test_cache_clear_on_create(self, repo, cache): """Cache is cleared when new players are created.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - + # Set up some cache data - cache.set('test:key', 'value', 300) - - with patch.object(service, 'require_auth', return_value=True): - service.create_players([{ - 'name': 'New', - 'wara': 1.0, - 'team_id': 1, - 'season': 10, - 'pos_1': 'P' - }], 'valid_token') - + cache.set("test:key", "value", 300) + + with patch.object(service, "require_auth", return_value=True): + service.create_players( + [ + { + "name": "New", + "wara": 1.0, + "team_id": 1, + "season": 10, + "pos_1": "P", + } + ], + "valid_token", + ) + # Should have invalidate calls assert len(cache.get_calls()) > 0 class TestPlayerServiceIntegration: """Integration tests combining multiple operations.""" - + def test_full_crud_cycle(self, repo, cache): """Test complete CRUD cycle.""" config = ServiceConfig(player_repo=repo, cache=cache) service = PlayerService(config=config) - + # CREATE - with patch.object(service, 'require_auth', return_value=True): - create_result = service.create_players([{ - 'name': 'CRUD Test', - 'wara': 3.0, - 'team_id': 1, - 'season': 10, - 'pos_1': 'DH' - }], 'valid_token') - + with patch.object(service, "require_auth", return_value=True): + create_result = service.create_players( + [ + { + "name": "CRUD Test", + "wara": 3.0, + "team_id": 1, + "season": 10, + "pos_1": "DH", + } + ], + "valid_token", + ) + # READ - search_result = service.search_players('CRUD', season=10) - assert search_result['count'] >= 1 - - player_id = search_result['players'][0].get('id') - + search_result = service.search_players("CRUD", season=10) + assert search_result["count"] >= 1 + + player_id = search_result["players"][0].get("id") + # UPDATE - with patch.object(service, 'require_auth', return_value=True): - update_result = service.patch_player(player_id, {'wara': 4.0}, 'valid_token') - assert update_result.get('wara') == 4.0 - + with patch.object(service, "require_auth", return_value=True): + update_result = service.patch_player( + player_id, {"wara": 4.0}, "valid_token" + ) + assert update_result.get("wara") == 4.0 + # DELETE - with patch.object(service, 'require_auth', return_value=True): - delete_result = service.delete_player(player_id, 'valid_token') - assert 'deleted' in str(delete_result) - + with patch.object(service, "require_auth", return_value=True): + delete_result = service.delete_player(player_id, "valid_token") + assert "deleted" in str(delete_result) + # VERIFY DELETED get_result = service.get_player(player_id) assert get_result is None - + def test_search_then_filter(self, service): """Search and then filter operations.""" # First get all players all_result = service.get_players(season=10) - initial_count = all_result['count'] - + initial_count = all_result["count"] + # Then filter by team filtered = service.get_players(season=10, team_id=[1]) - + # Filtered should be <= all - assert filtered['count'] <= initial_count + assert filtered["count"] <= initial_count # ============================================================================