Compare commits

...

2 Commits

Author SHA1 Message Date
Cal Corum
55bf035db0 fix: restore finally block in handle_db_errors to prevent connection leaks (#38)
Reviewer correctly identified that removing the finally block introduced
real connection leaks for handlers that do not call db.close() on their
own error paths. Peewee's PooledDatabase.close() is a no-op on the second
call, so double-close is harmless — the finally block provides a necessary
safety net.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-07 02:01:48 -06:00
Cal Corum
7734e558a9 fix: remove db.close() from handle_db_errors error handler (#38)
All checks were successful
Build Docker Image / build (pull_request) Successful in 1m59s
Removes the `finally` block that called `db.close()` after rollback in
the `handle_db_errors` decorator. With connection pooling, route handlers
already call `db.close()` themselves, so closing again in the error handler
could return connections to the pool twice, corrupting pool state.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-05 10:02:45 -06:00

View File

@ -11,8 +11,8 @@ from fastapi import HTTPException, Response
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from redis import Redis from redis import Redis
date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' date = f"{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}"
logger = logging.getLogger('discord_app') logger = logging.getLogger("discord_app")
# date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}' # date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
# log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN' # log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN'
@ -23,10 +23,10 @@ logger = logging.getLogger('discord_app')
# ) # )
# Redis configuration # Redis configuration
REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost') REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379')) REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
REDIS_DB = int(os.environ.get('REDIS_DB', '0')) REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true' CACHE_ENABLED = os.environ.get("CACHE_ENABLED", "true").lower() == "true"
# Initialize Redis client with connection error handling # Initialize Redis client with connection error handling
if not CACHE_ENABLED: if not CACHE_ENABLED:
@ -40,7 +40,7 @@ else:
db=REDIS_DB, db=REDIS_DB,
decode_responses=True, decode_responses=True,
socket_connect_timeout=5, socket_connect_timeout=5,
socket_timeout=5 socket_timeout=5,
) )
# Test connection # Test connection
redis_client.ping() redis_client.ping()
@ -50,12 +50,16 @@ else:
redis_client = None redis_client = None
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper() priv_help = (
PRIVATE_IN_SCHEMA = True if priv_help == 'TRUE' else False False
if not os.environ.get("PRIVATE_IN_SCHEMA")
else os.environ.get("PRIVATE_IN_SCHEMA").upper()
)
PRIVATE_IN_SCHEMA = True if priv_help == "TRUE" else False
def valid_token(token): def valid_token(token):
return token == os.environ.get('API_TOKEN') return token == os.environ.get("API_TOKEN")
def update_season_batting_stats(player_ids, season, db_connection): def update_season_batting_stats(player_ids, season, db_connection):
@ -72,7 +76,9 @@ def update_season_batting_stats(player_ids, season, db_connection):
if isinstance(player_ids, int): if isinstance(player_ids, int):
player_ids = [player_ids] player_ids = [player_ids]
logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}") logger.info(
f"Updating season batting stats for {len(player_ids)} players in season {season}"
)
try: try:
# SQL query to recalculate and upsert batting stats # SQL query to recalculate and upsert batting stats
@ -221,7 +227,9 @@ def update_season_batting_stats(player_ids, season, db_connection):
# Execute the query with parameters using the passed database connection # Execute the query with parameters using the passed database connection
db_connection.execute_sql(query, [season, player_ids, season, player_ids]) db_connection.execute_sql(query, [season, player_ids, season, player_ids])
logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}") logger.info(
f"Successfully updated season batting stats for {len(player_ids)} players in season {season}"
)
except Exception as e: except Exception as e:
logger.error(f"Error updating season batting stats: {e}") logger.error(f"Error updating season batting stats: {e}")
@ -242,7 +250,9 @@ def update_season_pitching_stats(player_ids, season, db_connection):
if isinstance(player_ids, int): if isinstance(player_ids, int):
player_ids = [player_ids] player_ids = [player_ids]
logger.info(f"Updating season pitching stats for {len(player_ids)} players in season {season}") logger.info(
f"Updating season pitching stats for {len(player_ids)} players in season {season}"
)
try: try:
# SQL query to recalculate and upsert pitching stats # SQL query to recalculate and upsert pitching stats
@ -464,7 +474,9 @@ def update_season_pitching_stats(player_ids, season, db_connection):
# Execute the query with parameters using the passed database connection # Execute the query with parameters using the passed database connection
db_connection.execute_sql(query, [season, player_ids, season, player_ids]) db_connection.execute_sql(query, [season, player_ids, season, player_ids])
logger.info(f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}") logger.info(
f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}"
)
except Exception as e: except Exception as e:
logger.error(f"Error updating season pitching stats: {e}") logger.error(f"Error updating season pitching stats: {e}")
@ -484,9 +496,7 @@ def send_webhook_message(message: str) -> bool:
webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW" webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW"
try: try:
payload = { payload = {"content": message}
"content": message
}
response = requests.post(webhook_url, json=payload, timeout=10) response = requests.post(webhook_url, json=payload, timeout=10)
response.raise_for_status() response.raise_for_status()
@ -502,7 +512,9 @@ def send_webhook_message(message: str) -> bool:
return False return False
def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True): def cache_result(
ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True
):
""" """
Decorator to cache function results in Redis with parameter normalization. Decorator to cache function results in Redis with parameter normalization.
@ -520,6 +532,7 @@ def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool
# These will use the same cache entry when normalize_params=True: # These will use the same cache entry when normalize_params=True:
# get_player_stats(123, None) and get_player_stats(123) # get_player_stats(123, None) and get_player_stats(123)
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
@ -533,15 +546,16 @@ def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool
if normalize_params: if normalize_params:
# Remove None values and empty collections # Remove None values and empty collections
normalized_kwargs = { normalized_kwargs = {
k: v for k, v in kwargs.items() k: v
for k, v in kwargs.items()
if v is not None and v != [] and v != "" and v != {} if v is not None and v != [] and v != "" and v != {}
} }
# Generate more readable cache key # Generate more readable cache key
args_str = "_".join(str(arg) for arg in args if arg is not None) args_str = "_".join(str(arg) for arg in args if arg is not None)
kwargs_str = "_".join([ kwargs_str = "_".join(
f"{k}={v}" for k, v in sorted(normalized_kwargs.items()) [f"{k}={v}" for k, v in sorted(normalized_kwargs.items())]
]) )
# Combine args and kwargs for cache key # Combine args and kwargs for cache key
key_parts = [key_prefix, func.__name__] key_parts = [key_prefix, func.__name__]
@ -572,10 +586,12 @@ def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool
redis_client.setex( redis_client.setex(
cache_key, cache_key,
ttl, ttl,
json.dumps(result, default=str, ensure_ascii=False) json.dumps(result, default=str, ensure_ascii=False),
) )
else: else:
logger.debug(f"Skipping cache for Response object from {func.__name__}") logger.debug(
f"Skipping cache for Response object from {func.__name__}"
)
return result return result
@ -585,6 +601,7 @@ def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool
return await func(*args, **kwargs) return await func(*args, **kwargs)
return wrapper return wrapper
return decorator return decorator
@ -607,7 +624,9 @@ def invalidate_cache(pattern: str = "*"):
keys = redis_client.keys(pattern) keys = redis_client.keys(pattern)
if keys: if keys:
deleted = redis_client.delete(*keys) deleted = redis_client.delete(*keys)
logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}") logger.info(
f"Invalidated {deleted} cache entries matching pattern: {pattern}"
)
return deleted return deleted
else: else:
logger.debug(f"No cache entries found matching pattern: {pattern}") logger.debug(f"No cache entries found matching pattern: {pattern}")
@ -634,7 +653,7 @@ def get_cache_stats() -> dict:
"memory_used": info.get("used_memory_human", "unknown"), "memory_used": info.get("used_memory_human", "unknown"),
"total_keys": redis_client.dbsize(), "total_keys": redis_client.dbsize(),
"connected_clients": info.get("connected_clients", 0), "connected_clients": info.get("connected_clients", 0),
"uptime_seconds": info.get("uptime_in_seconds", 0) "uptime_seconds": info.get("uptime_in_seconds", 0),
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting cache stats: {e}") logger.error(f"Error getting cache stats: {e}")
@ -645,7 +664,7 @@ def add_cache_headers(
max_age: int = 300, max_age: int = 300,
cache_type: str = "public", cache_type: str = "public",
vary: Optional[str] = None, vary: Optional[str] = None,
etag: bool = False etag: bool = False,
): ):
""" """
Decorator to add HTTP cache headers to FastAPI responses. Decorator to add HTTP cache headers to FastAPI responses.
@ -665,6 +684,7 @@ def add_cache_headers(
async def get_user_data(): async def get_user_data():
return {"data": "user specific"} return {"data": "user specific"}
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
@ -677,7 +697,7 @@ def add_cache_headers(
# Convert to Response with JSON content # Convert to Response with JSON content
response = Response( response = Response(
content=json.dumps(result, default=str, ensure_ascii=False), content=json.dumps(result, default=str, ensure_ascii=False),
media_type="application/json" media_type="application/json",
) )
else: else:
# Handle other response types # Handle other response types
@ -695,20 +715,23 @@ def add_cache_headers(
response.headers["Vary"] = vary response.headers["Vary"] = vary
# Add ETag if requested # Add ETag if requested
if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))): if etag and (
hasattr(result, "__dict__") or isinstance(result, (dict, list))
):
content_hash = hashlib.md5( content_hash = hashlib.md5(
json.dumps(result, default=str, sort_keys=True).encode() json.dumps(result, default=str, sort_keys=True).encode()
).hexdigest() ).hexdigest()
response.headers["ETag"] = f'"{content_hash}"' response.headers["ETag"] = f'"{content_hash}"'
# Add Last-Modified header with current time for dynamic content # Add Last-Modified header with current time for dynamic content
response.headers["Last-Modified"] = datetime.datetime.now(datetime.timezone.utc).strftime( response.headers["Last-Modified"] = datetime.datetime.now(
"%a, %d %b %Y %H:%M:%S GMT" datetime.timezone.utc
) ).strftime("%a, %d %b %Y %H:%M:%S GMT")
return response return response
return wrapper return wrapper
return decorator return decorator
@ -718,6 +741,7 @@ def handle_db_errors(func):
Ensures proper cleanup of database connections and provides consistent error handling. Ensures proper cleanup of database connections and provides consistent error handling.
Includes comprehensive logging with function context, timing, and stack traces. Includes comprehensive logging with function context, timing, and stack traces.
""" """
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
import time import time
@ -734,18 +758,24 @@ def handle_db_errors(func):
try: try:
# Log sanitized arguments (avoid logging tokens, passwords, etc.) # Log sanitized arguments (avoid logging tokens, passwords, etc.)
for arg in args: for arg in args:
if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object if hasattr(arg, "__dict__") and hasattr(
safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})") arg, "url"
): # FastAPI Request object
safe_args.append(
f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})"
)
else: else:
safe_args.append(str(arg)[:100]) # Truncate long values safe_args.append(str(arg)[:100]) # Truncate long values
for key, value in kwargs.items(): for key, value in kwargs.items():
if key.lower() in ['token', 'password', 'secret', 'key']: if key.lower() in ["token", "password", "secret", "key"]:
safe_kwargs[key] = '[REDACTED]' safe_kwargs[key] = "[REDACTED]"
else: else:
safe_kwargs[key] = str(value)[:100] # Truncate long values safe_kwargs[key] = str(value)[:100] # Truncate long values
logger.info(f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}") logger.info(
f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}"
)
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
@ -775,8 +805,12 @@ def handle_db_errors(func):
db.close() db.close()
logger.info(f"Database connection closed for {func_name}") logger.info(f"Database connection closed for {func_name}")
except Exception as close_error: except Exception as close_error:
logger.error(f"Error closing database connection in {func_name}: {close_error}") logger.error(
f"Error closing database connection in {func_name}: {close_error}"
)
raise HTTPException(status_code=500, detail=f'Database error in {func_name}: {str(e)}') raise HTTPException(
status_code=500, detail=f"Database error in {func_name}: {str(e)}"
)
return wrapper return wrapper