strat-gameplay-webapp/backend/app/middleware/rate_limit.py
Cal Corum 2a392b87f8 CLAUDE: Add rate limiting, pool monitoring, and exception infrastructure
- Add rate_limit.py middleware with per-client throttling and cleanup task
- Add pool_monitor.py for database connection pool health monitoring
- Add custom exceptions module (GameEngineError, DatabaseError, etc.)
- Add config settings for eviction intervals, session timeouts, memory limits
- Add unit tests for rate limiting and pool monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-28 12:06:10 -06:00

329 lines
11 KiB
Python

"""
Rate limiting utilities for WebSocket and API endpoints.
Implements token bucket algorithm for smooth rate limiting that allows
bursts while enforcing long-term rate limits.
Key features:
- Per-connection rate limiting for WebSocket events
- Per-game rate limiting for game actions (decisions, rolls, substitutions)
- Per-user rate limiting for REST API endpoints
- Automatic cleanup of stale buckets
- Non-blocking async design
Author: Claude
Date: 2025-11-27
"""
import asyncio
import logging
from dataclasses import dataclass, field
from functools import wraps
from typing import TYPE_CHECKING, Callable
import pendulum
from app.config import get_settings
if TYPE_CHECKING:
from socketio import AsyncServer
# Module-level logger. The ".RateLimiter" suffix gives it a dotted name under
# this module's own logger name (__name__).
logger = logging.getLogger(f"{__name__}.RateLimiter")
@dataclass
class RateLimitBucket:
    """
    Token bucket for rate limiting.

    The token bucket algorithm works by:
    1. Bucket starts with max_tokens
    2. Each request consumes tokens
    3. Tokens refill at refill_rate per second
    4. Requests are denied when fewer tokens remain than were requested

    This allows short bursts while enforcing average rate limits.
    """

    # Tokens currently available; fractional because refills are proportional
    # to elapsed time.
    tokens: float
    # Upper bound that refills never exceed.
    max_tokens: int
    refill_rate: float  # tokens per second
    # Time of the last refill that actually added tokens (see _refill).
    last_refill: pendulum.DateTime = field(
        default_factory=lambda: pendulum.now("UTC")
    )

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens.

        Refills the bucket for the time elapsed first, then deducts
        ``tokens`` if enough are available.

        Args:
            tokens: Number of tokens the request costs.

        Returns:
            True if allowed, False if rate limited (nothing is deducted).
        """
        self._refill()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def _refill(self) -> None:
        """Refill tokens based on time elapsed since last refill."""
        now = pendulum.now("UTC")
        elapsed = (now - self.last_refill).total_seconds()
        refill_amount = elapsed * self.refill_rate
        # A zero or negative amount (no time elapsed, a zero refill rate, or a
        # `now` that is earlier than last_refill) leaves both the token count
        # and the timestamp untouched.
        if refill_amount > 0:
            self.tokens = min(self.max_tokens, self.tokens + refill_amount)
            self.last_refill = now

    def get_wait_time(self, tokens: int = 1) -> float:
        """
        Get seconds until ``tokens`` tokens will be available.

        Useful for informing clients how long to wait. The estimate is based
        on the token count as of the most recent refill; this method does not
        refill the bucket itself.

        Args:
            tokens: Number of tokens the caller wants to consume.

        Returns:
            0.0 if enough tokens are already available, the number of seconds
            to wait otherwise, or ``inf`` if the request can never be
            satisfied (the bucket does not refill, or ``tokens`` exceeds
            ``max_tokens``).
        """
        if self.tokens >= tokens:
            return 0.0
        # A bucket built from a limit of 0 has refill_rate == 0; dividing by
        # it would raise ZeroDivisionError, and no finite wait would help.
        if self.refill_rate <= 0 or tokens > self.max_tokens:
            return float("inf")
        tokens_needed = tokens - self.tokens
        return tokens_needed / self.refill_rate
class RateLimiter:
    """
    Rate limiter for WebSocket connections and API endpoints.

    Uses token bucket algorithm for smooth rate limiting that allows
    bursts while enforcing average rate limits.

    Buckets are kept in three plain dicts (per connection, per game action,
    per user). No method awaits between looking a bucket up and updating it,
    so coroutines sharing one event loop do not interleave inside a single
    check. NOTE(review): nothing here locks against access from multiple OS
    threads - confirm that is not required.
    """

    # Settings attribute that holds the per-minute limit for each known game
    # action. Looked up by name so only the requested setting is read.
    _GAME_LIMIT_SETTINGS = {
        "decision": "rate_limit_decision_per_game",
        "roll": "rate_limit_roll_per_game",
        "substitution": "rate_limit_substitution_per_game",
    }
    # Limit applied to game actions that are not listed above.
    _DEFAULT_GAME_LIMIT = 30
    # Buckets untouched for longer than this are dropped by the cleanup task.
    _STALE_THRESHOLD_SECONDS = 600  # 10 minutes

    def __init__(self):
        # Per-connection buckets (sid -> bucket)
        self._connection_buckets: dict[str, RateLimitBucket] = {}
        # Per-game buckets (game_id:action -> bucket)
        self._game_buckets: dict[str, RateLimitBucket] = {}
        # Per-user API buckets (user_id -> bucket)
        self._user_buckets: dict[int, RateLimitBucket] = {}
        # Cleanup task handle. This class never assigns it; whoever schedules
        # cleanup_stale_buckets() may store the task here.
        self._cleanup_task: asyncio.Task | None = None
        # Settings cache (read once, at construction time)
        self._settings = get_settings()

    @staticmethod
    def _new_bucket(limit: int) -> RateLimitBucket:
        """Build a full bucket that allows ``limit`` requests per minute."""
        return RateLimitBucket(
            tokens=limit,
            max_tokens=limit,
            refill_rate=limit / 60,  # per minute -> per second
        )

    @staticmethod
    def _evict_stale(
        buckets: dict, now: pendulum.DateTime, threshold_seconds: float
    ) -> list:
        """Delete buckets idle for more than ``threshold_seconds``; return their keys."""
        # Collect first, delete second: a dict must not change size while it
        # is being iterated.
        stale_keys = [
            key
            for key, bucket in buckets.items()
            if (now - bucket.last_refill).total_seconds() > threshold_seconds
        ]
        for key in stale_keys:
            del buckets[key]
        return stale_keys

    def get_connection_bucket(self, sid: str) -> RateLimitBucket:
        """Get or create bucket for WebSocket connection."""
        if sid not in self._connection_buckets:
            self._connection_buckets[sid] = self._new_bucket(
                self._settings.rate_limit_websocket_per_minute
            )
        return self._connection_buckets[sid]

    def get_game_bucket(self, game_id: str, action: str) -> RateLimitBucket:
        """
        Get or create bucket for game-specific action.

        Actions: 'decision', 'roll', 'substitution'. Any other action gets a
        bucket with the default limit.
        """
        key = f"{game_id}:{action}"
        if key not in self._game_buckets:
            setting_name = self._GAME_LIMIT_SETTINGS.get(action)
            if setting_name is None:
                limit = self._DEFAULT_GAME_LIMIT  # Default for unknown actions
            else:
                limit = getattr(self._settings, setting_name)
            self._game_buckets[key] = self._new_bucket(limit)
        return self._game_buckets[key]

    def get_user_bucket(self, user_id: int) -> RateLimitBucket:
        """Get or create bucket for API user."""
        if user_id not in self._user_buckets:
            self._user_buckets[user_id] = self._new_bucket(
                self._settings.rate_limit_api_per_minute
            )
        return self._user_buckets[user_id]

    async def check_websocket_limit(self, sid: str) -> bool:
        """
        Check if WebSocket event is allowed for connection.

        Consumes one token from the connection's bucket.

        Returns True if allowed, False if rate limited.
        """
        bucket = self.get_connection_bucket(sid)
        allowed = bucket.consume()
        if not allowed:
            logger.warning(f"Rate limited WebSocket connection: {sid}")
        return allowed

    async def check_game_limit(self, game_id: str, action: str) -> bool:
        """
        Check if game action is allowed.

        Consumes one token from the (game, action) bucket.

        Returns True if allowed, False if rate limited.
        """
        bucket = self.get_game_bucket(game_id, action)
        allowed = bucket.consume()
        if not allowed:
            logger.warning(f"Rate limited game action: {game_id} {action}")
        return allowed

    async def check_api_limit(self, user_id: int) -> bool:
        """
        Check if API call is allowed for user.

        Consumes one token from the user's bucket.

        Returns True if allowed, False if rate limited.
        """
        bucket = self.get_user_bucket(user_id)
        allowed = bucket.consume()
        if not allowed:
            logger.warning(f"Rate limited API user: {user_id}")
        return allowed

    def remove_connection(self, sid: str) -> None:
        """Clean up buckets when connection closes."""
        # pop() with a default: removing an unknown sid is a no-op.
        self._connection_buckets.pop(sid, None)

    def remove_game(self, game_id: str) -> None:
        """Clean up all buckets for a game when it ends."""
        # Keys are "game_id:action"; including the ":" in the prefix keeps
        # game "1" from also matching game "11".
        prefix = f"{game_id}:"
        keys_to_remove = [key for key in self._game_buckets if key.startswith(prefix)]
        for key in keys_to_remove:
            del self._game_buckets[key]

    async def cleanup_stale_buckets(self) -> None:
        """
        Background task to periodically clean up stale buckets.

        Removes buckets that haven't been used in 10 minutes to prevent
        memory leaks from abandoned connections. Runs until cancelled;
        cancellation is logged and ends the loop, and any other error is
        logged before the loop continues.
        """
        interval = self._settings.rate_limit_cleanup_interval
        stale_threshold_seconds = self._STALE_THRESHOLD_SECONDS

        logger.info(f"Starting rate limiter cleanup task (interval: {interval}s)")

        while True:
            try:
                await asyncio.sleep(interval)
                now = pendulum.now("UTC")

                stale_connections = self._evict_stale(
                    self._connection_buckets, now, stale_threshold_seconds
                )
                stale_games = self._evict_stale(
                    self._game_buckets, now, stale_threshold_seconds
                )
                stale_users = self._evict_stale(
                    self._user_buckets, now, stale_threshold_seconds
                )

                if stale_connections or stale_games or stale_users:
                    logger.debug(
                        f"Rate limiter cleanup: {len(stale_connections)} connections, "
                        f"{len(stale_games)} games, {len(stale_users)} users"
                    )
            except asyncio.CancelledError:
                logger.info("Rate limiter cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Rate limiter cleanup error: {e}", exc_info=True)
                # Continue running despite errors

    def get_stats(self) -> dict:
        """Get rate limiter statistics (bucket counts per category) for monitoring."""
        return {
            "connection_buckets": len(self._connection_buckets),
            "game_buckets": len(self._game_buckets),
            "user_buckets": len(self._user_buckets),
        }
# Global rate limiter instance.
# Constructed when this module is first imported, so RateLimiter.__init__
# (and its get_settings() call) runs at import time.
rate_limiter = RateLimiter()
def rate_limited(action: str = "general"):
    """
    Decorator for rate-limited WebSocket handlers.

    Checks both connection-level and game-level rate limits before
    allowing the handler to execute.

    Args:
        action: Type of action for game-level limits.
            Options: 'decision', 'roll', 'substitution', 'general'
            'general' only applies connection-level limit.

    Returns:
        A decorator. The wrapped handler returns a dict with "error" and
        "message" keys when a limit is hit; otherwise it returns whatever
        the original handler returns.

    Usage:
        @sio.event
        @rate_limited(action="decision")
        async def submit_defensive_decision(sid, data):
            ...

    Note: Place this decorator BELOW @sio.event, as shown above. Decorators
    are applied bottom-up, so @sio.event then receives the rate-limited
    wrapper rather than the raw handler.
    """

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(sid: str, data=None, *args, **kwargs):
            # Connection-level limit applies to every decorated handler.
            if not await rate_limiter.check_websocket_limit(sid):
                logger.warning(f"Rate limited handler call from {sid}")
                # The error payload is returned to the caller; nothing is
                # emitted to the client from here.
                return {"error": "rate_limited", "message": "Rate limited. Please slow down."}

            # Game-level limit only applies when an action type was given and
            # the payload is a dict carrying a truthy "game_id".
            if action != "general" and isinstance(data, dict):
                game_id = data.get("game_id")
                if game_id and not await rate_limiter.check_game_limit(
                    str(game_id), action
                ):
                    logger.warning(
                        f"Rate limited game action {action} for game {game_id}"
                    )
                    return {
                        "error": "game_rate_limited",
                        "message": f"Too many {action} requests for this game.",
                    }

            # `data` is always forwarded positionally (None when the event
            # arrived without a payload).
            return await func(sid, data, *args, **kwargs)

        return wrapper

    return decorator