diff --git a/app/dependencies.py b/app/dependencies.py index b95747d..0d19626 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -11,8 +11,8 @@ from fastapi import HTTPException, Response from fastapi.security import OAuth2PasswordBearer from redis import Redis -date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' -logger = logging.getLogger('discord_app') +date = f"{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}" +logger = logging.getLogger("discord_app") # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' # log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN' @@ -23,10 +23,10 @@ logger = logging.getLogger('discord_app') # ) # Redis configuration -REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost') -REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379')) -REDIS_DB = int(os.environ.get('REDIS_DB', '0')) -CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true' +REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") +REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379")) +REDIS_DB = int(os.environ.get("REDIS_DB", "0")) +CACHE_ENABLED = os.environ.get("CACHE_ENABLED", "true").lower() == "true" # Initialize Redis client with connection error handling if not CACHE_ENABLED: @@ -40,7 +40,7 @@ else: db=REDIS_DB, decode_responses=True, socket_connect_timeout=5, - socket_timeout=5 + socket_timeout=5, ) # Test connection redis_client.ping() @@ -50,12 +50,16 @@ else: redis_client = None oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper() -PRIVATE_IN_SCHEMA = True if priv_help == 'TRUE' else False +priv_help = ( + False + if not os.environ.get("PRIVATE_IN_SCHEMA") + else os.environ.get("PRIVATE_IN_SCHEMA").upper() +) +PRIVATE_IN_SCHEMA = True if priv_help == "TRUE" else False def valid_token(token): - return token == os.environ.get('API_TOKEN') + return token == os.environ.get("API_TOKEN") def update_season_batting_stats(player_ids, season, db_connection): @@ -63,17 +67,19 @@ def update_season_batting_stats(player_ids, season, db_connection): Update season batting stats for specific players in a given season. Recalculates stats from stratplay data and upserts into seasonbattingstats table. """ - + if not player_ids: logger.warning("update_season_batting_stats called with empty player_ids list") return - + # Convert single player_id to list for consistency if isinstance(player_ids, int): player_ids = [player_ids] - - logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}") - + + logger.info( + f"Updating season batting stats for {len(player_ids)} players in season {season}" + ) + try: # SQL query to recalculate and upsert batting stats query = """ @@ -217,12 +223,14 @@ def update_season_batting_stats(player_ids, season, db_connection): sb = EXCLUDED.sb, cs = EXCLUDED.cs; """ - + # Execute the query with parameters using the passed database connection db_connection.execute_sql(query, [season, player_ids, season, player_ids]) - - logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}") - + + logger.info( + f"Successfully updated season batting stats for {len(player_ids)} players in season {season}" + ) + except Exception as e: logger.error(f"Error updating season batting stats: {e}") raise @@ -233,17 +241,19 @@ def update_season_pitching_stats(player_ids, season, db_connection): Update season pitching stats for specific players in a given season. Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table. """ - + if not player_ids: logger.warning("update_season_pitching_stats called with empty player_ids list") return - + # Convert single player_id to list for consistency if isinstance(player_ids, int): player_ids = [player_ids] - - logger.info(f"Updating season pitching stats for {len(player_ids)} players in season {season}") - + + logger.info( + f"Updating season pitching stats for {len(player_ids)} players in season {season}" + ) + try: # SQL query to recalculate and upsert pitching stats query = """ @@ -460,12 +470,14 @@ def update_season_pitching_stats(player_ids, season, db_connection): rbipercent = EXCLUDED.rbipercent, re24 = EXCLUDED.re24; """ - + # Execute the query with parameters using the passed database connection db_connection.execute_sql(query, [season, player_ids, season, player_ids]) - - logger.info(f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}") - + + logger.info( + f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}" + ) + except Exception as e: logger.error(f"Error updating season pitching stats: {e}") raise @@ -474,26 +486,27 @@ def update_season_pitching_stats(player_ids, season, db_connection): def send_webhook_message(message: str) -> bool: """ Send a message to Discord via webhook. - + Args: message: The message content to send - + Returns: bool: True if successful, False otherwise """ - webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW" - + webhook_url = os.environ.get("DISCORD_WEBHOOK_URL") + if not webhook_url: + logger.error("DISCORD_WEBHOOK_URL environment variable is not set") + return False + try: - payload = { - "content": message - } - + payload = {"content": message} + response = requests.post(webhook_url, json=payload, timeout=10) response.raise_for_status() - + logger.info(f"Webhook message sent successfully: {message[:100]}...") return True - + except requests.exceptions.RequestException as e: logger.error(f"Failed to send webhook message: {e}") return False @@ -502,99 +515,106 @@ def send_webhook_message(message: str) -> bool: return False -def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True): +def cache_result( + ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True +): """ Decorator to cache function results in Redis with parameter normalization. - + Args: ttl: Time to live in seconds (default: 5 minutes) key_prefix: Prefix for cache keys (default: "api") normalize_params: Remove None/empty values to reduce cache variations (default: True) - + Usage: @cache_result(ttl=600, key_prefix="stats") async def get_player_stats(player_id: int, season: Optional[int] = None): # expensive operation return stats - + # These will use the same cache entry when normalize_params=True: # get_player_stats(123, None) and get_player_stats(123) """ + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): # Skip caching if Redis is not available if redis_client is None: return await func(*args, **kwargs) - + try: # Normalize parameters to reduce cache variations normalized_kwargs = kwargs.copy() if normalize_params: # Remove None values and empty collections normalized_kwargs = { - k: v for k, v in kwargs.items() + k: v + for k, v in kwargs.items() if v is not None and v != [] and v != "" and v != {} } - + # Generate more readable cache key args_str = "_".join(str(arg) for arg in args if arg is not None) - kwargs_str = "_".join([ - f"{k}={v}" for k, v in sorted(normalized_kwargs.items()) - ]) - + kwargs_str = "_".join( + [f"{k}={v}" for k, v in sorted(normalized_kwargs.items())] + ) + # Combine args and kwargs for cache key key_parts = [key_prefix, func.__name__] if args_str: key_parts.append(args_str) if kwargs_str: key_parts.append(kwargs_str) - + cache_key = ":".join(key_parts) - + # Truncate very long cache keys to prevent Redis key size limits if len(cache_key) > 200: cache_key = f"{key_prefix}:{func.__name__}:{hash(cache_key)}" - + # Try to get from cache cached_result = redis_client.get(cache_key) if cached_result is not None: logger.debug(f"Cache hit for key: {cache_key}") return json.loads(cached_result) - + # Cache miss - execute function logger.debug(f"Cache miss for key: {cache_key}") result = await func(*args, **kwargs) - + # Skip caching for Response objects (like CSV downloads) as they can't be properly serialized if not isinstance(result, Response): # Store in cache with TTL redis_client.setex( - cache_key, - ttl, - json.dumps(result, default=str, ensure_ascii=False) + cache_key, + ttl, + json.dumps(result, default=str, ensure_ascii=False), ) else: - logger.debug(f"Skipping cache for Response object from {func.__name__}") - + logger.debug( + f"Skipping cache for Response object from {func.__name__}" + ) + return result - + except Exception as e: # If caching fails, log error and continue without caching logger.error(f"Cache error for {func.__name__}: {e}") return await func(*args, **kwargs) - + return wrapper + return decorator def invalidate_cache(pattern: str = "*"): """ Invalidate cache entries matching a pattern. - + Args: pattern: Redis pattern to match keys (default: "*" for all) - + Usage: invalidate_cache("stats:*") # Clear all stats cache invalidate_cache("api:get_player_*") # Clear specific player cache @@ -602,12 +622,14 @@ def invalidate_cache(pattern: str = "*"): if redis_client is None: logger.warning("Cannot invalidate cache: Redis not available") return 0 - + try: keys = redis_client.keys(pattern) if keys: deleted = redis_client.delete(*keys) - logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}") + logger.info( + f"Invalidated {deleted} cache entries matching pattern: {pattern}" + ) return deleted else: logger.debug(f"No cache entries found matching pattern: {pattern}") @@ -620,13 +642,13 @@ def invalidate_cache(pattern: str = "*"): def get_cache_stats() -> dict: """ Get Redis cache statistics. - + Returns: dict: Cache statistics including memory usage, key count, etc. """ if redis_client is None: return {"status": "unavailable", "message": "Redis not connected"} - + try: info = redis_client.info() return { @@ -634,7 +656,7 @@ def get_cache_stats() -> dict: "memory_used": info.get("used_memory_human", "unknown"), "total_keys": redis_client.dbsize(), "connected_clients": info.get("connected_clients", 0), - "uptime_seconds": info.get("uptime_in_seconds", 0) + "uptime_seconds": info.get("uptime_in_seconds", 0), } except Exception as e: logger.error(f"Error getting cache stats: {e}") @@ -642,34 +664,35 @@ def get_cache_stats() -> dict: def add_cache_headers( - max_age: int = 300, + max_age: int = 300, cache_type: str = "public", vary: Optional[str] = None, - etag: bool = False + etag: bool = False, ): """ Decorator to add HTTP cache headers to FastAPI responses. - + Args: max_age: Cache duration in seconds (default: 5 minutes) cache_type: "public", "private", or "no-cache" (default: "public") vary: Vary header value (e.g., "Accept-Encoding, Authorization") etag: Whether to generate ETag based on response content - + Usage: @add_cache_headers(max_age=1800, cache_type="public") async def get_static_data(): return {"data": "static content"} - + @add_cache_headers(max_age=60, cache_type="private", vary="Authorization") async def get_user_data(): return {"data": "user specific"} """ + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): result = await func(*args, **kwargs) - + # Handle different response types if isinstance(result, Response): response = result @@ -677,38 +700,41 @@ def add_cache_headers( # Convert to Response with JSON content response = Response( content=json.dumps(result, default=str, ensure_ascii=False), - media_type="application/json" + media_type="application/json", ) else: # Handle other response types response = Response(content=str(result)) - + # Build Cache-Control header cache_control_parts = [cache_type] if cache_type != "no-cache" and max_age > 0: cache_control_parts.append(f"max-age={max_age}") - + response.headers["Cache-Control"] = ", ".join(cache_control_parts) - + # Add Vary header if specified if vary: response.headers["Vary"] = vary - + # Add ETag if requested - if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))): + if etag and ( + hasattr(result, "__dict__") or isinstance(result, (dict, list)) + ): content_hash = hashlib.md5( json.dumps(result, default=str, sort_keys=True).encode() ).hexdigest() response.headers["ETag"] = f'"{content_hash}"' - + # Add Last-Modified header with current time for dynamic content - response.headers["Last-Modified"] = datetime.datetime.now(datetime.timezone.utc).strftime( - "%a, %d %b %Y %H:%M:%S GMT" - ) - + response.headers["Last-Modified"] = datetime.datetime.now( + datetime.timezone.utc + ).strftime("%a, %d %b %Y %H:%M:%S GMT") + return response - + return wrapper + return decorator @@ -718,52 +744,59 @@ def handle_db_errors(func): Ensures proper cleanup of database connections and provides consistent error handling. Includes comprehensive logging with function context, timing, and stack traces. """ + @wraps(func) async def wrapper(*args, **kwargs): import time import traceback from .db_engine import db # Import here to avoid circular imports - + start_time = time.time() func_name = f"{func.__module__}.{func.__name__}" - + # Sanitize arguments for logging (exclude sensitive data) safe_args = [] safe_kwargs = {} - + try: # Log sanitized arguments (avoid logging tokens, passwords, etc.) for arg in args: - if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object - safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})") + if hasattr(arg, "__dict__") and hasattr( + arg, "url" + ): # FastAPI Request object + safe_args.append( + f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})" + ) else: safe_args.append(str(arg)[:100]) # Truncate long values - + for key, value in kwargs.items(): - if key.lower() in ['token', 'password', 'secret', 'key']: - safe_kwargs[key] = '[REDACTED]' + if key.lower() in ["token", "password", "secret", "key"]: + safe_kwargs[key] = "[REDACTED]" else: safe_kwargs[key] = str(value)[:100] # Truncate long values - - logger.info(f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}") - + + logger.info( + f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}" + ) + result = await func(*args, **kwargs) - + elapsed_time = time.time() - start_time logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s") - + return result - + except Exception as e: elapsed_time = time.time() - start_time error_trace = traceback.format_exc() - + logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s") logger.error(f"Function args: {safe_args}") logger.error(f"Function kwargs: {safe_kwargs}") logger.error(f"Exception: {str(e)}") logger.error(f"Full traceback:\n{error_trace}") - + try: logger.info(f"Attempting database rollback for {func_name}") db.rollback() @@ -775,8 +808,12 @@ def handle_db_errors(func): db.close() logger.info(f"Database connection closed for {func_name}") except Exception as close_error: - logger.error(f"Error closing database connection in {func_name}: {close_error}") - - raise HTTPException(status_code=500, detail=f'Database error in {func_name}: {str(e)}') - + logger.error( + f"Error closing database connection in {func_name}: {close_error}" + ) + + raise HTTPException( + status_code=500, detail=f"Database error in {func_name}: {str(e)}" + ) + return wrapper