""" Rate limiting utilities for WebSocket and API endpoints. Implements token bucket algorithm for smooth rate limiting that allows bursts while enforcing long-term rate limits. Key features: - Per-connection rate limiting for WebSocket events - Per-game rate limiting for game actions (decisions, rolls, substitutions) - Per-user rate limiting for REST API endpoints - Automatic cleanup of stale buckets - Non-blocking async design Author: Claude Date: 2025-11-27 """ import asyncio import logging from dataclasses import dataclass, field from functools import wraps from typing import TYPE_CHECKING, Callable import pendulum from app.config import get_settings if TYPE_CHECKING: from socketio import AsyncServer logger = logging.getLogger(f"{__name__}.RateLimiter") @dataclass class RateLimitBucket: """ Token bucket for rate limiting. The token bucket algorithm works by: 1. Bucket starts with max_tokens 2. Each request consumes tokens 3. Tokens refill at refill_rate per second 4. Requests are denied when tokens < 1 This allows short bursts while enforcing average rate limits. """ tokens: float max_tokens: int refill_rate: float # tokens per second last_refill: pendulum.DateTime = field( default_factory=lambda: pendulum.now("UTC") ) def consume(self, tokens: int = 1) -> bool: """ Try to consume tokens. Returns True if allowed, False if rate limited. """ self._refill() if self.tokens >= tokens: self.tokens -= tokens return True return False def _refill(self) -> None: """Refill tokens based on time elapsed since last refill.""" now = pendulum.now("UTC") elapsed = (now - self.last_refill).total_seconds() refill_amount = elapsed * self.refill_rate if refill_amount > 0: self.tokens = min(self.max_tokens, self.tokens + refill_amount) self.last_refill = now def get_wait_time(self) -> float: """ Get seconds until a token will be available. Useful for informing clients how long to wait. """ if self.tokens >= 1: return 0.0 tokens_needed = 1 - self.tokens return tokens_needed / self.refill_rate class RateLimiter: """ Rate limiter for WebSocket connections and API endpoints. Uses token bucket algorithm for smooth rate limiting that allows bursts while enforcing average rate limits. Thread-safe for async operations via atomic dict operations. """ def __init__(self): # Per-connection buckets (sid -> bucket) self._connection_buckets: dict[str, RateLimitBucket] = {} # Per-game buckets (game_id:action -> bucket) self._game_buckets: dict[str, RateLimitBucket] = {} # Per-user API buckets (user_id -> bucket) self._user_buckets: dict[int, RateLimitBucket] = {} # Cleanup task handle self._cleanup_task: asyncio.Task | None = None # Settings cache self._settings = get_settings() def get_connection_bucket(self, sid: str) -> RateLimitBucket: """Get or create bucket for WebSocket connection.""" if sid not in self._connection_buckets: limit = self._settings.rate_limit_websocket_per_minute self._connection_buckets[sid] = RateLimitBucket( tokens=limit, max_tokens=limit, refill_rate=limit / 60, # per minute -> per second ) return self._connection_buckets[sid] def get_game_bucket(self, game_id: str, action: str) -> RateLimitBucket: """ Get or create bucket for game-specific action. Actions: 'decision', 'roll', 'substitution' """ key = f"{game_id}:{action}" if key not in self._game_buckets: # Get limit based on action type if action == "decision": limit = self._settings.rate_limit_decision_per_game elif action == "roll": limit = self._settings.rate_limit_roll_per_game elif action == "substitution": limit = self._settings.rate_limit_substitution_per_game else: limit = 30 # Default for unknown actions self._game_buckets[key] = RateLimitBucket( tokens=limit, max_tokens=limit, refill_rate=limit / 60, ) return self._game_buckets[key] def get_user_bucket(self, user_id: int) -> RateLimitBucket: """Get or create bucket for API user.""" if user_id not in self._user_buckets: limit = self._settings.rate_limit_api_per_minute self._user_buckets[user_id] = RateLimitBucket( tokens=limit, max_tokens=limit, refill_rate=limit / 60, ) return self._user_buckets[user_id] async def check_websocket_limit(self, sid: str) -> bool: """ Check if WebSocket event is allowed for connection. Returns True if allowed, False if rate limited. """ bucket = self.get_connection_bucket(sid) allowed = bucket.consume() if not allowed: logger.warning(f"Rate limited WebSocket connection: {sid}") return allowed async def check_game_limit(self, game_id: str, action: str) -> bool: """ Check if game action is allowed. Returns True if allowed, False if rate limited. """ bucket = self.get_game_bucket(game_id, action) allowed = bucket.consume() if not allowed: logger.warning(f"Rate limited game action: {game_id} {action}") return allowed async def check_api_limit(self, user_id: int) -> bool: """ Check if API call is allowed for user. Returns True if allowed, False if rate limited. """ bucket = self.get_user_bucket(user_id) allowed = bucket.consume() if not allowed: logger.warning(f"Rate limited API user: {user_id}") return allowed def remove_connection(self, sid: str) -> None: """Clean up buckets when connection closes.""" self._connection_buckets.pop(sid, None) def remove_game(self, game_id: str) -> None: """Clean up all buckets for a game when it ends.""" keys_to_remove = [ key for key in self._game_buckets if key.startswith(f"{game_id}:") ] for key in keys_to_remove: del self._game_buckets[key] async def cleanup_stale_buckets(self) -> None: """ Background task to periodically clean up stale buckets. Removes buckets that haven't been used in 10 minutes to prevent memory leaks from abandoned connections. """ interval = self._settings.rate_limit_cleanup_interval stale_threshold_seconds = 600 # 10 minutes logger.info(f"Starting rate limiter cleanup task (interval: {interval}s)") while True: try: await asyncio.sleep(interval) now = pendulum.now("UTC") # Clean connection buckets stale_connections = [ sid for sid, bucket in self._connection_buckets.items() if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds ] for sid in stale_connections: del self._connection_buckets[sid] # Clean game buckets stale_games = [ key for key, bucket in self._game_buckets.items() if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds ] for key in stale_games: del self._game_buckets[key] # Clean user buckets stale_users = [ user_id for user_id, bucket in self._user_buckets.items() if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds ] for user_id in stale_users: del self._user_buckets[user_id] if stale_connections or stale_games or stale_users: logger.debug( f"Rate limiter cleanup: {len(stale_connections)} connections, " f"{len(stale_games)} games, {len(stale_users)} users" ) except asyncio.CancelledError: logger.info("Rate limiter cleanup task cancelled") break except Exception as e: logger.error(f"Rate limiter cleanup error: {e}", exc_info=True) # Continue running despite errors def get_stats(self) -> dict: """Get rate limiter statistics for monitoring.""" return { "connection_buckets": len(self._connection_buckets), "game_buckets": len(self._game_buckets), "user_buckets": len(self._user_buckets), } # Global rate limiter instance rate_limiter = RateLimiter() def rate_limited(action: str = "general"): """ Decorator for rate-limited WebSocket handlers. Checks both connection-level and game-level rate limits before allowing the handler to execute. Args: action: Type of action for game-level limits. Options: 'decision', 'roll', 'substitution', 'general' 'general' only applies connection-level limit. Usage: @sio.event @rate_limited(action="decision") async def submit_defensive_decision(sid, data): ... Note: The decorator must be applied AFTER @sio.event to work correctly. """ def decorator(func: Callable): @wraps(func) async def wrapper(sid: str, data=None, *args, **kwargs): # Import here to avoid circular imports from app.websocket.connection_manager import ConnectionManager # Check connection-level limit first if not await rate_limiter.check_websocket_limit(sid): # Get the connection manager to emit error # Note: We can't easily access the manager here, so we'll # handle rate limiting in the handlers themselves logger.warning(f"Rate limited handler call from {sid}") return {"error": "rate_limited", "message": "Rate limited. Please slow down."} # Check game-level limit if applicable if action != "general" and isinstance(data, dict): game_id = data.get("game_id") if game_id: if not await rate_limiter.check_game_limit(str(game_id), action): logger.warning( f"Rate limited game action {action} for game {game_id}" ) return { "error": "game_rate_limited", "message": f"Too many {action} requests for this game.", } return await func(sid, data, *args, **kwargs) return wrapper return decorator