fix: remove hardcoded Discord webhook URL from dependencies.py (#19) #56

Open
cal wants to merge 1 commit from ai/major-domo-database-19 into next-release

View File

@ -11,8 +11,8 @@ from fastapi import HTTPException, Response
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from redis import Redis from redis import Redis
date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' date = f"{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}"
logger = logging.getLogger('discord_app') logger = logging.getLogger("discord_app")
# date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
# log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN' # log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN'
@ -23,10 +23,10 @@ logger = logging.getLogger('discord_app')
# ) # )
# Redis configuration # Redis configuration
REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost') REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379')) REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
REDIS_DB = int(os.environ.get('REDIS_DB', '0')) REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true' CACHE_ENABLED = os.environ.get("CACHE_ENABLED", "true").lower() == "true"
# Initialize Redis client with connection error handling # Initialize Redis client with connection error handling
if not CACHE_ENABLED: if not CACHE_ENABLED:
@ -40,7 +40,7 @@ else:
db=REDIS_DB, db=REDIS_DB,
decode_responses=True, decode_responses=True,
socket_connect_timeout=5, socket_connect_timeout=5,
socket_timeout=5 socket_timeout=5,
) )
# Test connection # Test connection
redis_client.ping() redis_client.ping()
@ -50,12 +50,16 @@ else:
redis_client = None redis_client = None
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Controls whether private endpoints appear in the OpenAPI schema.
# Read the env var once (the original evaluated os.environ.get twice).
# priv_help stays False when the variable is unset/empty, otherwise the
# upper-cased value — kept as a module name for backward compatibility.
_private_in_schema_raw = os.environ.get("PRIVATE_IN_SCHEMA")
priv_help = _private_in_schema_raw.upper() if _private_in_schema_raw else False
# `x == "TRUE"` is already a bool; no need for `True if ... else False`.
PRIVATE_IN_SCHEMA = priv_help == "TRUE"
def valid_token(token):
    """Check a bearer token against the API_TOKEN environment variable.

    Security fixes over the original ``token == os.environ.get('API_TOKEN')``:
    - When API_TOKEN is unset, ``None == None`` made the original accept a
      missing token; this version always returns False in that case.
    - Uses hmac.compare_digest for a constant-time comparison, avoiding a
      timing side channel on the token value.
    """
    import hmac  # local import: keeps the module-level import block untouched

    expected = os.environ.get("API_TOKEN")
    # Reject outright when the server has no token configured or the
    # client supplied a non-string (e.g. None from a missing header).
    if not expected or not isinstance(token, str):
        return False
    return hmac.compare_digest(token, expected)
def update_season_batting_stats(player_ids, season, db_connection): def update_season_batting_stats(player_ids, season, db_connection):
@ -63,17 +67,19 @@ def update_season_batting_stats(player_ids, season, db_connection):
Update season batting stats for specific players in a given season. Update season batting stats for specific players in a given season.
Recalculates stats from stratplay data and upserts into seasonbattingstats table. Recalculates stats from stratplay data and upserts into seasonbattingstats table.
""" """
if not player_ids: if not player_ids:
logger.warning("update_season_batting_stats called with empty player_ids list") logger.warning("update_season_batting_stats called with empty player_ids list")
return return
# Convert single player_id to list for consistency # Convert single player_id to list for consistency
if isinstance(player_ids, int): if isinstance(player_ids, int):
player_ids = [player_ids] player_ids = [player_ids]
logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}") logger.info(
f"Updating season batting stats for {len(player_ids)} players in season {season}"
)
try: try:
# SQL query to recalculate and upsert batting stats # SQL query to recalculate and upsert batting stats
query = """ query = """
@ -217,12 +223,14 @@ def update_season_batting_stats(player_ids, season, db_connection):
sb = EXCLUDED.sb, sb = EXCLUDED.sb,
cs = EXCLUDED.cs; cs = EXCLUDED.cs;
""" """
# Execute the query with parameters using the passed database connection # Execute the query with parameters using the passed database connection
db_connection.execute_sql(query, [season, player_ids, season, player_ids]) db_connection.execute_sql(query, [season, player_ids, season, player_ids])
logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}") logger.info(
f"Successfully updated season batting stats for {len(player_ids)} players in season {season}"
)
except Exception as e: except Exception as e:
logger.error(f"Error updating season batting stats: {e}") logger.error(f"Error updating season batting stats: {e}")
raise raise
@ -233,17 +241,19 @@ def update_season_pitching_stats(player_ids, season, db_connection):
Update season pitching stats for specific players in a given season. Update season pitching stats for specific players in a given season.
Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table. Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table.
""" """
if not player_ids: if not player_ids:
logger.warning("update_season_pitching_stats called with empty player_ids list") logger.warning("update_season_pitching_stats called with empty player_ids list")
return return
# Convert single player_id to list for consistency # Convert single player_id to list for consistency
if isinstance(player_ids, int): if isinstance(player_ids, int):
player_ids = [player_ids] player_ids = [player_ids]
logger.info(f"Updating season pitching stats for {len(player_ids)} players in season {season}") logger.info(
f"Updating season pitching stats for {len(player_ids)} players in season {season}"
)
try: try:
# SQL query to recalculate and upsert pitching stats # SQL query to recalculate and upsert pitching stats
query = """ query = """
@ -460,12 +470,14 @@ def update_season_pitching_stats(player_ids, season, db_connection):
rbipercent = EXCLUDED.rbipercent, rbipercent = EXCLUDED.rbipercent,
re24 = EXCLUDED.re24; re24 = EXCLUDED.re24;
""" """
# Execute the query with parameters using the passed database connection # Execute the query with parameters using the passed database connection
db_connection.execute_sql(query, [season, player_ids, season, player_ids]) db_connection.execute_sql(query, [season, player_ids, season, player_ids])
logger.info(f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}") logger.info(
f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}"
)
except Exception as e: except Exception as e:
logger.error(f"Error updating season pitching stats: {e}") logger.error(f"Error updating season pitching stats: {e}")
raise raise
@ -474,26 +486,27 @@ def update_season_pitching_stats(player_ids, season, db_connection):
def send_webhook_message(message: str) -> bool: def send_webhook_message(message: str) -> bool:
""" """
Send a message to Discord via webhook. Send a message to Discord via webhook.
Args: Args:
message: The message content to send message: The message content to send
Returns: Returns:
bool: True if successful, False otherwise bool: True if successful, False otherwise
""" """
webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW" webhook_url = os.environ.get("DISCORD_WEBHOOK_URL")
if not webhook_url:
logger.error("DISCORD_WEBHOOK_URL environment variable is not set")
return False
try: try:
payload = { payload = {"content": message}
"content": message
}
response = requests.post(webhook_url, json=payload, timeout=10) response = requests.post(webhook_url, json=payload, timeout=10)
response.raise_for_status() response.raise_for_status()
logger.info(f"Webhook message sent successfully: {message[:100]}...") logger.info(f"Webhook message sent successfully: {message[:100]}...")
return True return True
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.error(f"Failed to send webhook message: {e}") logger.error(f"Failed to send webhook message: {e}")
return False return False
@ -502,99 +515,106 @@ def send_webhook_message(message: str) -> bool:
return False return False
def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True):
    """
    Decorator to cache async function results in Redis with parameter
    normalization.

    Args:
        ttl: Time to live in seconds (default: 5 minutes)
        key_prefix: Prefix for cache keys (default: "api")
        normalize_params: Remove None/empty values to reduce cache
            variations (default: True)

    Usage:
        @cache_result(ttl=600, key_prefix="stats")
        async def get_player_stats(player_id: int, season: Optional[int] = None):
            # expensive operation
            return stats

        # These share one cache entry when normalize_params=True:
        # get_player_stats(123, None) and get_player_stats(123)
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Skip caching entirely when Redis is not available.
            if redis_client is None:
                return await func(*args, **kwargs)
            try:
                # Normalize parameters to reduce cache-key variations.
                normalized_kwargs = kwargs.copy()
                if normalize_params:
                    # Drop None values and empty collections.
                    normalized_kwargs = {
                        k: v
                        for k, v in kwargs.items()
                        if v is not None and v != [] and v != "" and v != {}
                    }

                # Build a human-readable cache key from args and kwargs.
                args_str = "_".join(str(arg) for arg in args if arg is not None)
                kwargs_str = "_".join(
                    f"{k}={v}" for k, v in sorted(normalized_kwargs.items())
                )
                key_parts = [key_prefix, func.__name__]
                if args_str:
                    key_parts.append(args_str)
                if kwargs_str:
                    key_parts.append(kwargs_str)
                cache_key = ":".join(key_parts)

                # Truncate very long keys to stay under Redis key-size limits.
                # FIX: use a stable digest instead of the builtin hash(),
                # which is salted per process (PYTHONHASHSEED) — with hash()
                # each worker/restart produced different keys for the same
                # call, defeating a shared Redis cache.
                if len(cache_key) > 200:
                    digest = hashlib.sha256(cache_key.encode("utf-8")).hexdigest()
                    cache_key = f"{key_prefix}:{func.__name__}:{digest}"

                # Try the cache first.
                cached_result = redis_client.get(cache_key)
                if cached_result is not None:
                    logger.debug(f"Cache hit for key: {cache_key}")
                    return json.loads(cached_result)

                # Cache miss — execute the wrapped function.
                logger.debug(f"Cache miss for key: {cache_key}")
                result = await func(*args, **kwargs)

                # Response objects (e.g. CSV downloads) can't be JSON
                # serialized; skip caching those.
                if not isinstance(result, Response):
                    redis_client.setex(
                        cache_key,
                        ttl,
                        json.dumps(result, default=str, ensure_ascii=False),
                    )
                else:
                    logger.debug(
                        f"Skipping cache for Response object from {func.__name__}"
                    )
                return result
            except Exception as e:
                # If anything in the caching path fails, log and fall back
                # to an uncached call rather than surfacing the error.
                logger.error(f"Cache error for {func.__name__}: {e}")
                return await func(*args, **kwargs)

        return wrapper

    return decorator
def invalidate_cache(pattern: str = "*"): def invalidate_cache(pattern: str = "*"):
""" """
Invalidate cache entries matching a pattern. Invalidate cache entries matching a pattern.
Args: Args:
pattern: Redis pattern to match keys (default: "*" for all) pattern: Redis pattern to match keys (default: "*" for all)
Usage: Usage:
invalidate_cache("stats:*") # Clear all stats cache invalidate_cache("stats:*") # Clear all stats cache
invalidate_cache("api:get_player_*") # Clear specific player cache invalidate_cache("api:get_player_*") # Clear specific player cache
@ -602,12 +622,14 @@ def invalidate_cache(pattern: str = "*"):
if redis_client is None: if redis_client is None:
logger.warning("Cannot invalidate cache: Redis not available") logger.warning("Cannot invalidate cache: Redis not available")
return 0 return 0
try: try:
keys = redis_client.keys(pattern) keys = redis_client.keys(pattern)
if keys: if keys:
deleted = redis_client.delete(*keys) deleted = redis_client.delete(*keys)
logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}") logger.info(
f"Invalidated {deleted} cache entries matching pattern: {pattern}"
)
return deleted return deleted
else: else:
logger.debug(f"No cache entries found matching pattern: {pattern}") logger.debug(f"No cache entries found matching pattern: {pattern}")
@ -620,13 +642,13 @@ def invalidate_cache(pattern: str = "*"):
def get_cache_stats() -> dict: def get_cache_stats() -> dict:
""" """
Get Redis cache statistics. Get Redis cache statistics.
Returns: Returns:
dict: Cache statistics including memory usage, key count, etc. dict: Cache statistics including memory usage, key count, etc.
""" """
if redis_client is None: if redis_client is None:
return {"status": "unavailable", "message": "Redis not connected"} return {"status": "unavailable", "message": "Redis not connected"}
try: try:
info = redis_client.info() info = redis_client.info()
return { return {
@ -634,7 +656,7 @@ def get_cache_stats() -> dict:
"memory_used": info.get("used_memory_human", "unknown"), "memory_used": info.get("used_memory_human", "unknown"),
"total_keys": redis_client.dbsize(), "total_keys": redis_client.dbsize(),
"connected_clients": info.get("connected_clients", 0), "connected_clients": info.get("connected_clients", 0),
"uptime_seconds": info.get("uptime_in_seconds", 0) "uptime_seconds": info.get("uptime_in_seconds", 0),
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting cache stats: {e}") logger.error(f"Error getting cache stats: {e}")
@ -642,34 +664,35 @@ def get_cache_stats() -> dict:
def add_cache_headers( def add_cache_headers(
max_age: int = 300, max_age: int = 300,
cache_type: str = "public", cache_type: str = "public",
vary: Optional[str] = None, vary: Optional[str] = None,
etag: bool = False etag: bool = False,
): ):
""" """
Decorator to add HTTP cache headers to FastAPI responses. Decorator to add HTTP cache headers to FastAPI responses.
Args: Args:
max_age: Cache duration in seconds (default: 5 minutes) max_age: Cache duration in seconds (default: 5 minutes)
cache_type: "public", "private", or "no-cache" (default: "public") cache_type: "public", "private", or "no-cache" (default: "public")
vary: Vary header value (e.g., "Accept-Encoding, Authorization") vary: Vary header value (e.g., "Accept-Encoding, Authorization")
etag: Whether to generate ETag based on response content etag: Whether to generate ETag based on response content
Usage: Usage:
@add_cache_headers(max_age=1800, cache_type="public") @add_cache_headers(max_age=1800, cache_type="public")
async def get_static_data(): async def get_static_data():
return {"data": "static content"} return {"data": "static content"}
@add_cache_headers(max_age=60, cache_type="private", vary="Authorization") @add_cache_headers(max_age=60, cache_type="private", vary="Authorization")
async def get_user_data(): async def get_user_data():
return {"data": "user specific"} return {"data": "user specific"}
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
# Handle different response types # Handle different response types
if isinstance(result, Response): if isinstance(result, Response):
response = result response = result
@ -677,38 +700,41 @@ def add_cache_headers(
# Convert to Response with JSON content # Convert to Response with JSON content
response = Response( response = Response(
content=json.dumps(result, default=str, ensure_ascii=False), content=json.dumps(result, default=str, ensure_ascii=False),
media_type="application/json" media_type="application/json",
) )
else: else:
# Handle other response types # Handle other response types
response = Response(content=str(result)) response = Response(content=str(result))
# Build Cache-Control header # Build Cache-Control header
cache_control_parts = [cache_type] cache_control_parts = [cache_type]
if cache_type != "no-cache" and max_age > 0: if cache_type != "no-cache" and max_age > 0:
cache_control_parts.append(f"max-age={max_age}") cache_control_parts.append(f"max-age={max_age}")
response.headers["Cache-Control"] = ", ".join(cache_control_parts) response.headers["Cache-Control"] = ", ".join(cache_control_parts)
# Add Vary header if specified # Add Vary header if specified
if vary: if vary:
response.headers["Vary"] = vary response.headers["Vary"] = vary
# Add ETag if requested # Add ETag if requested
if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))): if etag and (
hasattr(result, "__dict__") or isinstance(result, (dict, list))
):
content_hash = hashlib.md5( content_hash = hashlib.md5(
json.dumps(result, default=str, sort_keys=True).encode() json.dumps(result, default=str, sort_keys=True).encode()
).hexdigest() ).hexdigest()
response.headers["ETag"] = f'"{content_hash}"' response.headers["ETag"] = f'"{content_hash}"'
# Add Last-Modified header with current time for dynamic content # Add Last-Modified header with current time for dynamic content
response.headers["Last-Modified"] = datetime.datetime.now(datetime.timezone.utc).strftime( response.headers["Last-Modified"] = datetime.datetime.now(
"%a, %d %b %Y %H:%M:%S GMT" datetime.timezone.utc
) ).strftime("%a, %d %b %Y %H:%M:%S GMT")
return response return response
return wrapper return wrapper
return decorator return decorator
@ -718,52 +744,59 @@ def handle_db_errors(func):
Ensures proper cleanup of database connections and provides consistent error handling. Ensures proper cleanup of database connections and provides consistent error handling.
Includes comprehensive logging with function context, timing, and stack traces. Includes comprehensive logging with function context, timing, and stack traces.
""" """
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
import time import time
import traceback import traceback
from .db_engine import db # Import here to avoid circular imports from .db_engine import db # Import here to avoid circular imports
start_time = time.time() start_time = time.time()
func_name = f"{func.__module__}.{func.__name__}" func_name = f"{func.__module__}.{func.__name__}"
# Sanitize arguments for logging (exclude sensitive data) # Sanitize arguments for logging (exclude sensitive data)
safe_args = [] safe_args = []
safe_kwargs = {} safe_kwargs = {}
try: try:
# Log sanitized arguments (avoid logging tokens, passwords, etc.) # Log sanitized arguments (avoid logging tokens, passwords, etc.)
for arg in args: for arg in args:
if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object if hasattr(arg, "__dict__") and hasattr(
safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})") arg, "url"
): # FastAPI Request object
safe_args.append(
f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})"
)
else: else:
safe_args.append(str(arg)[:100]) # Truncate long values safe_args.append(str(arg)[:100]) # Truncate long values
for key, value in kwargs.items(): for key, value in kwargs.items():
if key.lower() in ['token', 'password', 'secret', 'key']: if key.lower() in ["token", "password", "secret", "key"]:
safe_kwargs[key] = '[REDACTED]' safe_kwargs[key] = "[REDACTED]"
else: else:
safe_kwargs[key] = str(value)[:100] # Truncate long values safe_kwargs[key] = str(value)[:100] # Truncate long values
logger.info(f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}") logger.info(
f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}"
)
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s") logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s")
return result return result
except Exception as e: except Exception as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
error_trace = traceback.format_exc() error_trace = traceback.format_exc()
logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s") logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s")
logger.error(f"Function args: {safe_args}") logger.error(f"Function args: {safe_args}")
logger.error(f"Function kwargs: {safe_kwargs}") logger.error(f"Function kwargs: {safe_kwargs}")
logger.error(f"Exception: {str(e)}") logger.error(f"Exception: {str(e)}")
logger.error(f"Full traceback:\n{error_trace}") logger.error(f"Full traceback:\n{error_trace}")
try: try:
logger.info(f"Attempting database rollback for {func_name}") logger.info(f"Attempting database rollback for {func_name}")
db.rollback() db.rollback()
@ -775,8 +808,12 @@ def handle_db_errors(func):
db.close() db.close()
logger.info(f"Database connection closed for {func_name}") logger.info(f"Database connection closed for {func_name}")
except Exception as close_error: except Exception as close_error:
logger.error(f"Error closing database connection in {func_name}: {close_error}") logger.error(
f"Error closing database connection in {func_name}: {close_error}"
raise HTTPException(status_code=500, detail=f'Database error in {func_name}: {str(e)}') )
raise HTTPException(
status_code=500, detail=f"Database error in {func_name}: {str(e)}"
)
return wrapper return wrapper