Compare commits
2 Commits
main
...
ai/major-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55bf035db0 | ||
|
|
7734e558a9 |
@ -11,8 +11,8 @@ from fastapi import HTTPException, Response
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from redis import Redis
|
||||
|
||||
date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
|
||||
logger = logging.getLogger('discord_app')
|
||||
date = f"{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}"
|
||||
logger = logging.getLogger("discord_app")
|
||||
|
||||
# date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
|
||||
# log_level = logger.info if os.environ.get('LOG_LEVEL') == 'INFO' else 'WARN'
|
||||
@ -23,10 +23,10 @@ logger = logging.getLogger('discord_app')
|
||||
# )
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost')
|
||||
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379'))
|
||||
REDIS_DB = int(os.environ.get('REDIS_DB', '0'))
|
||||
CACHE_ENABLED = os.environ.get('CACHE_ENABLED', 'true').lower() == 'true'
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
|
||||
REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
|
||||
CACHE_ENABLED = os.environ.get("CACHE_ENABLED", "true").lower() == "true"
|
||||
|
||||
# Initialize Redis client with connection error handling
|
||||
if not CACHE_ENABLED:
|
||||
@ -40,7 +40,7 @@ else:
|
||||
db=REDIS_DB,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=5,
|
||||
socket_timeout=5
|
||||
socket_timeout=5,
|
||||
)
|
||||
# Test connection
|
||||
redis_client.ping()
|
||||
@ -50,12 +50,16 @@ else:
|
||||
redis_client = None
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper()
|
||||
PRIVATE_IN_SCHEMA = True if priv_help == 'TRUE' else False
|
||||
priv_help = (
|
||||
False
|
||||
if not os.environ.get("PRIVATE_IN_SCHEMA")
|
||||
else os.environ.get("PRIVATE_IN_SCHEMA").upper()
|
||||
)
|
||||
PRIVATE_IN_SCHEMA = True if priv_help == "TRUE" else False
|
||||
|
||||
|
||||
def valid_token(token):
|
||||
return token == os.environ.get('API_TOKEN')
|
||||
return token == os.environ.get("API_TOKEN")
|
||||
|
||||
|
||||
def update_season_batting_stats(player_ids, season, db_connection):
|
||||
@ -63,17 +67,19 @@ def update_season_batting_stats(player_ids, season, db_connection):
|
||||
Update season batting stats for specific players in a given season.
|
||||
Recalculates stats from stratplay data and upserts into seasonbattingstats table.
|
||||
"""
|
||||
|
||||
|
||||
if not player_ids:
|
||||
logger.warning("update_season_batting_stats called with empty player_ids list")
|
||||
return
|
||||
|
||||
|
||||
# Convert single player_id to list for consistency
|
||||
if isinstance(player_ids, int):
|
||||
player_ids = [player_ids]
|
||||
|
||||
logger.info(f"Updating season batting stats for {len(player_ids)} players in season {season}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Updating season batting stats for {len(player_ids)} players in season {season}"
|
||||
)
|
||||
|
||||
try:
|
||||
# SQL query to recalculate and upsert batting stats
|
||||
query = """
|
||||
@ -217,12 +223,14 @@ def update_season_batting_stats(player_ids, season, db_connection):
|
||||
sb = EXCLUDED.sb,
|
||||
cs = EXCLUDED.cs;
|
||||
"""
|
||||
|
||||
|
||||
# Execute the query with parameters using the passed database connection
|
||||
db_connection.execute_sql(query, [season, player_ids, season, player_ids])
|
||||
|
||||
logger.info(f"Successfully updated season batting stats for {len(player_ids)} players in season {season}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Successfully updated season batting stats for {len(player_ids)} players in season {season}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating season batting stats: {e}")
|
||||
raise
|
||||
@ -233,17 +241,19 @@ def update_season_pitching_stats(player_ids, season, db_connection):
|
||||
Update season pitching stats for specific players in a given season.
|
||||
Recalculates stats from stratplay and decision data and upserts into seasonpitchingstats table.
|
||||
"""
|
||||
|
||||
|
||||
if not player_ids:
|
||||
logger.warning("update_season_pitching_stats called with empty player_ids list")
|
||||
return
|
||||
|
||||
|
||||
# Convert single player_id to list for consistency
|
||||
if isinstance(player_ids, int):
|
||||
player_ids = [player_ids]
|
||||
|
||||
logger.info(f"Updating season pitching stats for {len(player_ids)} players in season {season}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Updating season pitching stats for {len(player_ids)} players in season {season}"
|
||||
)
|
||||
|
||||
try:
|
||||
# SQL query to recalculate and upsert pitching stats
|
||||
query = """
|
||||
@ -460,12 +470,14 @@ def update_season_pitching_stats(player_ids, season, db_connection):
|
||||
rbipercent = EXCLUDED.rbipercent,
|
||||
re24 = EXCLUDED.re24;
|
||||
"""
|
||||
|
||||
|
||||
# Execute the query with parameters using the passed database connection
|
||||
db_connection.execute_sql(query, [season, player_ids, season, player_ids])
|
||||
|
||||
logger.info(f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Successfully updated season pitching stats for {len(player_ids)} players in season {season}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating season pitching stats: {e}")
|
||||
raise
|
||||
@ -474,26 +486,24 @@ def update_season_pitching_stats(player_ids, season, db_connection):
|
||||
def send_webhook_message(message: str) -> bool:
|
||||
"""
|
||||
Send a message to Discord via webhook.
|
||||
|
||||
|
||||
Args:
|
||||
message: The message content to send
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
webhook_url = "https://discord.com/api/webhooks/1408811717424840876/7RXG_D5IqovA3Jwa9YOobUjVcVMuLc6cQyezABcWuXaHo5Fvz1en10M7J43o3OJ3bzGW"
|
||||
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"content": message
|
||||
}
|
||||
|
||||
payload = {"content": message}
|
||||
|
||||
response = requests.post(webhook_url, json=payload, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
logger.info(f"Webhook message sent successfully: {message[:100]}...")
|
||||
return True
|
||||
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to send webhook message: {e}")
|
||||
return False
|
||||
@ -502,99 +512,106 @@ def send_webhook_message(message: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True):
|
||||
def cache_result(
|
||||
ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True
|
||||
):
|
||||
"""
|
||||
Decorator to cache function results in Redis with parameter normalization.
|
||||
|
||||
|
||||
Args:
|
||||
ttl: Time to live in seconds (default: 5 minutes)
|
||||
key_prefix: Prefix for cache keys (default: "api")
|
||||
normalize_params: Remove None/empty values to reduce cache variations (default: True)
|
||||
|
||||
|
||||
Usage:
|
||||
@cache_result(ttl=600, key_prefix="stats")
|
||||
async def get_player_stats(player_id: int, season: Optional[int] = None):
|
||||
# expensive operation
|
||||
return stats
|
||||
|
||||
|
||||
# These will use the same cache entry when normalize_params=True:
|
||||
# get_player_stats(123, None) and get_player_stats(123)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Skip caching if Redis is not available
|
||||
if redis_client is None:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
# Normalize parameters to reduce cache variations
|
||||
normalized_kwargs = kwargs.copy()
|
||||
if normalize_params:
|
||||
# Remove None values and empty collections
|
||||
normalized_kwargs = {
|
||||
k: v for k, v in kwargs.items()
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if v is not None and v != [] and v != "" and v != {}
|
||||
}
|
||||
|
||||
|
||||
# Generate more readable cache key
|
||||
args_str = "_".join(str(arg) for arg in args if arg is not None)
|
||||
kwargs_str = "_".join([
|
||||
f"{k}={v}" for k, v in sorted(normalized_kwargs.items())
|
||||
])
|
||||
|
||||
kwargs_str = "_".join(
|
||||
[f"{k}={v}" for k, v in sorted(normalized_kwargs.items())]
|
||||
)
|
||||
|
||||
# Combine args and kwargs for cache key
|
||||
key_parts = [key_prefix, func.__name__]
|
||||
if args_str:
|
||||
key_parts.append(args_str)
|
||||
if kwargs_str:
|
||||
key_parts.append(kwargs_str)
|
||||
|
||||
|
||||
cache_key = ":".join(key_parts)
|
||||
|
||||
|
||||
# Truncate very long cache keys to prevent Redis key size limits
|
||||
if len(cache_key) > 200:
|
||||
cache_key = f"{key_prefix}:{func.__name__}:{hash(cache_key)}"
|
||||
|
||||
|
||||
# Try to get from cache
|
||||
cached_result = redis_client.get(cache_key)
|
||||
if cached_result is not None:
|
||||
logger.debug(f"Cache hit for key: {cache_key}")
|
||||
return json.loads(cached_result)
|
||||
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for key: {cache_key}")
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
|
||||
# Skip caching for Response objects (like CSV downloads) as they can't be properly serialized
|
||||
if not isinstance(result, Response):
|
||||
# Store in cache with TTL
|
||||
redis_client.setex(
|
||||
cache_key,
|
||||
ttl,
|
||||
json.dumps(result, default=str, ensure_ascii=False)
|
||||
cache_key,
|
||||
ttl,
|
||||
json.dumps(result, default=str, ensure_ascii=False),
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Skipping cache for Response object from {func.__name__}")
|
||||
|
||||
logger.debug(
|
||||
f"Skipping cache for Response object from {func.__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# If caching fails, log error and continue without caching
|
||||
logger.error(f"Cache error for {func.__name__}: {e}")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def invalidate_cache(pattern: str = "*"):
|
||||
"""
|
||||
Invalidate cache entries matching a pattern.
|
||||
|
||||
|
||||
Args:
|
||||
pattern: Redis pattern to match keys (default: "*" for all)
|
||||
|
||||
|
||||
Usage:
|
||||
invalidate_cache("stats:*") # Clear all stats cache
|
||||
invalidate_cache("api:get_player_*") # Clear specific player cache
|
||||
@ -602,12 +619,14 @@ def invalidate_cache(pattern: str = "*"):
|
||||
if redis_client is None:
|
||||
logger.warning("Cannot invalidate cache: Redis not available")
|
||||
return 0
|
||||
|
||||
|
||||
try:
|
||||
keys = redis_client.keys(pattern)
|
||||
if keys:
|
||||
deleted = redis_client.delete(*keys)
|
||||
logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}")
|
||||
logger.info(
|
||||
f"Invalidated {deleted} cache entries matching pattern: {pattern}"
|
||||
)
|
||||
return deleted
|
||||
else:
|
||||
logger.debug(f"No cache entries found matching pattern: {pattern}")
|
||||
@ -620,13 +639,13 @@ def invalidate_cache(pattern: str = "*"):
|
||||
def get_cache_stats() -> dict:
|
||||
"""
|
||||
Get Redis cache statistics.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Cache statistics including memory usage, key count, etc.
|
||||
"""
|
||||
if redis_client is None:
|
||||
return {"status": "unavailable", "message": "Redis not connected"}
|
||||
|
||||
|
||||
try:
|
||||
info = redis_client.info()
|
||||
return {
|
||||
@ -634,7 +653,7 @@ def get_cache_stats() -> dict:
|
||||
"memory_used": info.get("used_memory_human", "unknown"),
|
||||
"total_keys": redis_client.dbsize(),
|
||||
"connected_clients": info.get("connected_clients", 0),
|
||||
"uptime_seconds": info.get("uptime_in_seconds", 0)
|
||||
"uptime_seconds": info.get("uptime_in_seconds", 0),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting cache stats: {e}")
|
||||
@ -642,34 +661,35 @@ def get_cache_stats() -> dict:
|
||||
|
||||
|
||||
def add_cache_headers(
|
||||
max_age: int = 300,
|
||||
max_age: int = 300,
|
||||
cache_type: str = "public",
|
||||
vary: Optional[str] = None,
|
||||
etag: bool = False
|
||||
etag: bool = False,
|
||||
):
|
||||
"""
|
||||
Decorator to add HTTP cache headers to FastAPI responses.
|
||||
|
||||
|
||||
Args:
|
||||
max_age: Cache duration in seconds (default: 5 minutes)
|
||||
cache_type: "public", "private", or "no-cache" (default: "public")
|
||||
vary: Vary header value (e.g., "Accept-Encoding, Authorization")
|
||||
etag: Whether to generate ETag based on response content
|
||||
|
||||
|
||||
Usage:
|
||||
@add_cache_headers(max_age=1800, cache_type="public")
|
||||
async def get_static_data():
|
||||
return {"data": "static content"}
|
||||
|
||||
|
||||
@add_cache_headers(max_age=60, cache_type="private", vary="Authorization")
|
||||
async def get_user_data():
|
||||
return {"data": "user specific"}
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
|
||||
# Handle different response types
|
||||
if isinstance(result, Response):
|
||||
response = result
|
||||
@ -677,38 +697,41 @@ def add_cache_headers(
|
||||
# Convert to Response with JSON content
|
||||
response = Response(
|
||||
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||
media_type="application/json"
|
||||
media_type="application/json",
|
||||
)
|
||||
else:
|
||||
# Handle other response types
|
||||
response = Response(content=str(result))
|
||||
|
||||
|
||||
# Build Cache-Control header
|
||||
cache_control_parts = [cache_type]
|
||||
if cache_type != "no-cache" and max_age > 0:
|
||||
cache_control_parts.append(f"max-age={max_age}")
|
||||
|
||||
|
||||
response.headers["Cache-Control"] = ", ".join(cache_control_parts)
|
||||
|
||||
|
||||
# Add Vary header if specified
|
||||
if vary:
|
||||
response.headers["Vary"] = vary
|
||||
|
||||
|
||||
# Add ETag if requested
|
||||
if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))):
|
||||
if etag and (
|
||||
hasattr(result, "__dict__") or isinstance(result, (dict, list))
|
||||
):
|
||||
content_hash = hashlib.md5(
|
||||
json.dumps(result, default=str, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
response.headers["ETag"] = f'"{content_hash}"'
|
||||
|
||||
|
||||
# Add Last-Modified header with current time for dynamic content
|
||||
response.headers["Last-Modified"] = datetime.datetime.now(datetime.timezone.utc).strftime(
|
||||
"%a, %d %b %Y %H:%M:%S GMT"
|
||||
)
|
||||
|
||||
response.headers["Last-Modified"] = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).strftime("%a, %d %b %Y %H:%M:%S GMT")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@ -718,52 +741,59 @@ def handle_db_errors(func):
|
||||
Ensures proper cleanup of database connections and provides consistent error handling.
|
||||
Includes comprehensive logging with function context, timing, and stack traces.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
import time
|
||||
import traceback
|
||||
from .db_engine import db # Import here to avoid circular imports
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
func_name = f"{func.__module__}.{func.__name__}"
|
||||
|
||||
|
||||
# Sanitize arguments for logging (exclude sensitive data)
|
||||
safe_args = []
|
||||
safe_kwargs = {}
|
||||
|
||||
|
||||
try:
|
||||
# Log sanitized arguments (avoid logging tokens, passwords, etc.)
|
||||
for arg in args:
|
||||
if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object
|
||||
safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})")
|
||||
if hasattr(arg, "__dict__") and hasattr(
|
||||
arg, "url"
|
||||
): # FastAPI Request object
|
||||
safe_args.append(
|
||||
f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})"
|
||||
)
|
||||
else:
|
||||
safe_args.append(str(arg)[:100]) # Truncate long values
|
||||
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key.lower() in ['token', 'password', 'secret', 'key']:
|
||||
safe_kwargs[key] = '[REDACTED]'
|
||||
if key.lower() in ["token", "password", "secret", "key"]:
|
||||
safe_kwargs[key] = "[REDACTED]"
|
||||
else:
|
||||
safe_kwargs[key] = str(value)[:100] # Truncate long values
|
||||
|
||||
logger.info(f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Starting {func_name} - args: {safe_args}, kwargs: {safe_kwargs}"
|
||||
)
|
||||
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"Completed {func_name} successfully in {elapsed_time:.3f}s")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error_trace = traceback.format_exc()
|
||||
|
||||
|
||||
logger.error(f"Database error in {func_name} after {elapsed_time:.3f}s")
|
||||
logger.error(f"Function args: {safe_args}")
|
||||
logger.error(f"Function kwargs: {safe_kwargs}")
|
||||
logger.error(f"Exception: {str(e)}")
|
||||
logger.error(f"Full traceback:\n{error_trace}")
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting database rollback for {func_name}")
|
||||
db.rollback()
|
||||
@ -775,8 +805,12 @@ def handle_db_errors(func):
|
||||
db.close()
|
||||
logger.info(f"Database connection closed for {func_name}")
|
||||
except Exception as close_error:
|
||||
logger.error(f"Error closing database connection in {func_name}: {close_error}")
|
||||
|
||||
raise HTTPException(status_code=500, detail=f'Database error in {func_name}: {str(e)}')
|
||||
|
||||
logger.error(
|
||||
f"Error closing database connection in {func_name}: {close_error}"
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Database error in {func_name}: {str(e)}"
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
Loading…
Reference in New Issue
Block a user