CLAUDE: Fix cache_result decorator to handle Response objects properly
- Skip caching for FastAPI Response objects (CSV downloads, etc.) - Response objects can't be JSON-serialized/deserialized without corruption - Regular JSON responses continue to be cached normally - Fixes issue where CSV endpoints returned Response object string representation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
2c3835c8ac
commit
abf4435931
@ -1,11 +1,15 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException, Response
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from redis import Redis
|
||||||
|
|
||||||
date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
|
date = f'{datetime.datetime.now().year}-{datetime.datetime.now().month}-{datetime.datetime.now().day}'
|
||||||
logger = logging.getLogger('discord_app')
|
logger = logging.getLogger('discord_app')
|
||||||
@ -18,6 +22,27 @@ logger = logging.getLogger('discord_app')
|
|||||||
# level=log_level
|
# level=log_level
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
# Redis configuration
|
||||||
|
REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost')
|
||||||
|
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379'))
|
||||||
|
REDIS_DB = int(os.environ.get('REDIS_DB', '0'))
|
||||||
|
|
||||||
|
# Initialize Redis client with connection error handling
|
||||||
|
try:
|
||||||
|
redis_client = Redis(
|
||||||
|
host=REDIS_HOST,
|
||||||
|
port=REDIS_PORT,
|
||||||
|
db=REDIS_DB,
|
||||||
|
decode_responses=True,
|
||||||
|
socket_connect_timeout=5,
|
||||||
|
socket_timeout=5
|
||||||
|
)
|
||||||
|
# Test connection
|
||||||
|
redis_client.ping()
|
||||||
|
logger.info(f"Redis connected successfully at {REDIS_HOST}:{REDIS_PORT}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Redis connection failed: {e}. Caching will be disabled.")
|
||||||
|
redis_client = None
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper()
|
priv_help = False if not os.environ.get('PRIVATE_IN_SCHEMA') else os.environ.get('PRIVATE_IN_SCHEMA').upper()
|
||||||
@ -472,6 +497,216 @@ def send_webhook_message(message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def cache_result(ttl: int = 300, key_prefix: str = "api", normalize_params: bool = True):
|
||||||
|
"""
|
||||||
|
Decorator to cache function results in Redis with parameter normalization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ttl: Time to live in seconds (default: 5 minutes)
|
||||||
|
key_prefix: Prefix for cache keys (default: "api")
|
||||||
|
normalize_params: Remove None/empty values to reduce cache variations (default: True)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@cache_result(ttl=600, key_prefix="stats")
|
||||||
|
async def get_player_stats(player_id: int, season: Optional[int] = None):
|
||||||
|
# expensive operation
|
||||||
|
return stats
|
||||||
|
|
||||||
|
# These will use the same cache entry when normalize_params=True:
|
||||||
|
# get_player_stats(123, None) and get_player_stats(123)
|
||||||
|
"""
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
# Skip caching if Redis is not available
|
||||||
|
if redis_client is None:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Normalize parameters to reduce cache variations
|
||||||
|
normalized_kwargs = kwargs.copy()
|
||||||
|
if normalize_params:
|
||||||
|
# Remove None values and empty collections
|
||||||
|
normalized_kwargs = {
|
||||||
|
k: v for k, v in kwargs.items()
|
||||||
|
if v is not None and v != [] and v != "" and v != {}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate more readable cache key
|
||||||
|
args_str = "_".join(str(arg) for arg in args if arg is not None)
|
||||||
|
kwargs_str = "_".join([
|
||||||
|
f"{k}={v}" for k, v in sorted(normalized_kwargs.items())
|
||||||
|
])
|
||||||
|
|
||||||
|
# Combine args and kwargs for cache key
|
||||||
|
key_parts = [key_prefix, func.__name__]
|
||||||
|
if args_str:
|
||||||
|
key_parts.append(args_str)
|
||||||
|
if kwargs_str:
|
||||||
|
key_parts.append(kwargs_str)
|
||||||
|
|
||||||
|
cache_key = ":".join(key_parts)
|
||||||
|
|
||||||
|
# Truncate very long cache keys to prevent Redis key size limits
|
||||||
|
if len(cache_key) > 200:
|
||||||
|
cache_key = f"{key_prefix}:{func.__name__}:{hash(cache_key)}"
|
||||||
|
|
||||||
|
# Try to get from cache
|
||||||
|
cached_result = redis_client.get(cache_key)
|
||||||
|
if cached_result is not None:
|
||||||
|
logger.debug(f"Cache hit for key: {cache_key}")
|
||||||
|
return json.loads(cached_result)
|
||||||
|
|
||||||
|
# Cache miss - execute function
|
||||||
|
logger.debug(f"Cache miss for key: {cache_key}")
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Skip caching for Response objects (like CSV downloads) as they can't be properly serialized
|
||||||
|
if not isinstance(result, Response):
|
||||||
|
# Store in cache with TTL
|
||||||
|
redis_client.setex(
|
||||||
|
cache_key,
|
||||||
|
ttl,
|
||||||
|
json.dumps(result, default=str, ensure_ascii=False)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(f"Skipping cache for Response object from {func.__name__}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# If caching fails, log error and continue without caching
|
||||||
|
logger.error(f"Cache error for {func.__name__}: {e}")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate_cache(pattern: str = "*"):
|
||||||
|
"""
|
||||||
|
Invalidate cache entries matching a pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Redis pattern to match keys (default: "*" for all)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
invalidate_cache("stats:*") # Clear all stats cache
|
||||||
|
invalidate_cache("api:get_player_*") # Clear specific player cache
|
||||||
|
"""
|
||||||
|
if redis_client is None:
|
||||||
|
logger.warning("Cannot invalidate cache: Redis not available")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
keys = redis_client.keys(pattern)
|
||||||
|
if keys:
|
||||||
|
deleted = redis_client.delete(*keys)
|
||||||
|
logger.info(f"Invalidated {deleted} cache entries matching pattern: {pattern}")
|
||||||
|
return deleted
|
||||||
|
else:
|
||||||
|
logger.debug(f"No cache entries found matching pattern: {pattern}")
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error invalidating cache: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_stats() -> dict:
|
||||||
|
"""
|
||||||
|
Get Redis cache statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Cache statistics including memory usage, key count, etc.
|
||||||
|
"""
|
||||||
|
if redis_client is None:
|
||||||
|
return {"status": "unavailable", "message": "Redis not connected"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
info = redis_client.info()
|
||||||
|
return {
|
||||||
|
"status": "connected",
|
||||||
|
"memory_used": info.get("used_memory_human", "unknown"),
|
||||||
|
"total_keys": redis_client.dbsize(),
|
||||||
|
"connected_clients": info.get("connected_clients", 0),
|
||||||
|
"uptime_seconds": info.get("uptime_in_seconds", 0)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting cache stats: {e}")
|
||||||
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def add_cache_headers(
|
||||||
|
max_age: int = 300,
|
||||||
|
cache_type: str = "public",
|
||||||
|
vary: Optional[str] = None,
|
||||||
|
etag: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decorator to add HTTP cache headers to FastAPI responses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_age: Cache duration in seconds (default: 5 minutes)
|
||||||
|
cache_type: "public", "private", or "no-cache" (default: "public")
|
||||||
|
vary: Vary header value (e.g., "Accept-Encoding, Authorization")
|
||||||
|
etag: Whether to generate ETag based on response content
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@add_cache_headers(max_age=1800, cache_type="public")
|
||||||
|
async def get_static_data():
|
||||||
|
return {"data": "static content"}
|
||||||
|
|
||||||
|
@add_cache_headers(max_age=60, cache_type="private", vary="Authorization")
|
||||||
|
async def get_user_data():
|
||||||
|
return {"data": "user specific"}
|
||||||
|
"""
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Handle different response types
|
||||||
|
if isinstance(result, Response):
|
||||||
|
response = result
|
||||||
|
elif isinstance(result, dict) or isinstance(result, list):
|
||||||
|
# Convert to Response with JSON content
|
||||||
|
response = Response(
|
||||||
|
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Handle other response types
|
||||||
|
response = Response(content=str(result))
|
||||||
|
|
||||||
|
# Build Cache-Control header
|
||||||
|
cache_control_parts = [cache_type]
|
||||||
|
if cache_type != "no-cache" and max_age > 0:
|
||||||
|
cache_control_parts.append(f"max-age={max_age}")
|
||||||
|
|
||||||
|
response.headers["Cache-Control"] = ", ".join(cache_control_parts)
|
||||||
|
|
||||||
|
# Add Vary header if specified
|
||||||
|
if vary:
|
||||||
|
response.headers["Vary"] = vary
|
||||||
|
|
||||||
|
# Add ETag if requested
|
||||||
|
if etag and (hasattr(result, '__dict__') or isinstance(result, (dict, list))):
|
||||||
|
content_hash = hashlib.md5(
|
||||||
|
json.dumps(result, default=str, sort_keys=True).encode()
|
||||||
|
).hexdigest()
|
||||||
|
response.headers["ETag"] = f'"{content_hash}"'
|
||||||
|
|
||||||
|
# Add Last-Modified header with current time for dynamic content
|
||||||
|
response.headers["Last-Modified"] = datetime.datetime.now(datetime.timezone.utc).strftime(
|
||||||
|
"%a, %d %b %Y %H:%M:%S GMT"
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def handle_db_errors(func):
|
def handle_db_errors(func):
|
||||||
"""
|
"""
|
||||||
Decorator to handle database connection errors and transaction rollbacks.
|
Decorator to handle database connection errors and transaction rollbacks.
|
||||||
@ -493,7 +728,7 @@ def handle_db_errors(func):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Log sanitized arguments (avoid logging tokens, passwords, etc.)
|
# Log sanitized arguments (avoid logging tokens, passwords, etc.)
|
||||||
for i, arg in enumerate(args):
|
for arg in args:
|
||||||
if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object
|
if hasattr(arg, '__dict__') and hasattr(arg, 'url'): # FastAPI Request object
|
||||||
safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})")
|
safe_args.append(f"Request({getattr(arg, 'method', 'UNKNOWN')} {getattr(arg, 'url', 'unknown')})")
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user