- Add rate_limit.py middleware with per-client throttling and cleanup task - Add pool_monitor.py for database connection pool health monitoring - Add custom exceptions module (GameEngineError, DatabaseError, etc.) - Add config settings for eviction intervals, session timeouts, memory limits - Add unit tests for rate limiting and pool monitoring 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
329 lines
11 KiB
Python
329 lines
11 KiB
Python
"""
|
|
Rate limiting utilities for WebSocket and API endpoints.
|
|
|
|
Implements token bucket algorithm for smooth rate limiting that allows
|
|
bursts while enforcing long-term rate limits.
|
|
|
|
Key features:
|
|
- Per-connection rate limiting for WebSocket events
|
|
- Per-game rate limiting for game actions (decisions, rolls, substitutions)
|
|
- Per-user rate limiting for REST API endpoints
|
|
- Automatic cleanup of stale buckets
|
|
- Non-blocking async design
|
|
|
|
Author: Claude
|
|
Date: 2025-11-27
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from functools import wraps
|
|
from typing import TYPE_CHECKING, Callable
|
|
|
|
import pendulum
|
|
|
|
from app.config import get_settings
|
|
|
|
if TYPE_CHECKING:
|
|
from socketio import AsyncServer
|
|
|
|
# Child logger named "<module>.RateLimiter" so limiter messages can be filtered
# independently of the rest of this module.
logger = logging.getLogger(__name__ + ".RateLimiter")
|
|
|
|
|
|
@dataclass
class RateLimitBucket:
    """
    Token bucket for rate limiting.

    The token bucket algorithm works by:
    1. Bucket starts with max_tokens
    2. Each request consumes tokens
    3. Tokens refill at refill_rate per second
    4. A request is denied when fewer tokens remain than it asks for

    This allows short bursts while enforcing average rate limits.

    Refill is lazy: ``tokens`` is only topped up (from elapsed wall-clock
    time) when ``consume`` runs, so between calls the stored ``tokens`` value
    can lag behind the bucket's true balance.
    """

    # Stored token balance; fractional because refill is time-proportional.
    tokens: float
    # Capacity; refills never raise ``tokens`` above this.
    max_tokens: int
    refill_rate: float  # tokens per second
    # Time of the last refill that added tokens (UTC).
    last_refill: pendulum.DateTime = field(
        default_factory=lambda: pendulum.now("UTC")
    )

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume ``tokens`` tokens after applying any pending refill.

        Returns True (and deducts the tokens) if enough are available,
        False (leaving the balance untouched) if rate limited.
        """
        self._refill()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def _refill(self) -> None:
        """Refill tokens based on time elapsed since last refill."""
        now = pendulum.now("UTC")
        elapsed = (now - self.last_refill).total_seconds()
        refill_amount = elapsed * self.refill_rate

        # A non-positive amount (no time passed, clock moved backwards, or a
        # zero rate) leaves both the balance and the timestamp untouched.
        if refill_amount > 0:
            self.tokens = min(self.max_tokens, self.tokens + refill_amount)
            self.last_refill = now

    def get_wait_time(self) -> float:
        """
        Get seconds until a single token will be available.

        Useful for informing clients how long to wait. Accounts for tokens
        that have accrued since the last refill without mutating the bucket,
        so calling it does not count as activity.

        Returns:
            0.0 if a token is available now, the number of seconds until one
            will be otherwise, or ``float("inf")`` if no token can ever
            arrive (non-positive ``refill_rate``).
        """
        elapsed = (pendulum.now("UTC") - self.last_refill).total_seconds()
        refill_amount = elapsed * self.refill_rate

        # Project the balance the same way _refill would, but read-only.
        available = self.tokens
        if refill_amount > 0:
            available = min(self.max_tokens, self.tokens + refill_amount)

        if available >= 1:
            return 0.0
        if self.refill_rate <= 0:
            # Previously a ZeroDivisionError; no refill will ever come.
            return float("inf")
        return (1 - available) / self.refill_rate
|
|
|
|
|
|
class RateLimiter:
    """
    Rate limiter for WebSocket connections and API endpoints.

    Uses token bucket algorithm for smooth rate limiting that allows
    bursts while enforcing average rate limits.

    Three independent bucket maps are kept:
      - per WebSocket connection, keyed by ``sid``
      - per game action, keyed by ``"{game_id}:{action}"``
      - per API user, keyed by ``user_id``

    Concurrency: no method awaits between looking a bucket up and mutating
    it, so calls may safely interleave on a single asyncio event loop. The
    dictionaries are not protected by any lock, so one instance must not be
    shared across OS threads.
    """

    # Per-minute limit for game actions that have no dedicated setting.
    DEFAULT_GAME_ACTION_LIMIT = 30

    # Default idle time (seconds) after which cleanup evicts a bucket.
    DEFAULT_STALE_THRESHOLD_SECONDS = 600  # 10 minutes

    # Game action name -> settings attribute holding its per-minute limit.
    _GAME_ACTION_SETTINGS = {
        "decision": "rate_limit_decision_per_game",
        "roll": "rate_limit_roll_per_game",
        "substitution": "rate_limit_substitution_per_game",
    }

    def __init__(self):
        # Per-connection buckets (sid -> bucket)
        self._connection_buckets: dict[str, RateLimitBucket] = {}
        # Per-game buckets ("game_id:action" -> bucket)
        self._game_buckets: dict[str, RateLimitBucket] = {}
        # Per-user API buckets (user_id -> bucket)
        self._user_buckets: dict[int, RateLimitBucket] = {}
        # Cleanup task handle. NOTE(review): nothing in this class assigns
        # it; presumably set by application startup code — confirm.
        self._cleanup_task: asyncio.Task | None = None
        # Settings are read once, at construction time.
        self._settings = get_settings()

    @staticmethod
    def _new_bucket(limit: int) -> RateLimitBucket:
        """Create a full bucket that allows ``limit`` requests per minute."""
        return RateLimitBucket(
            tokens=limit,
            max_tokens=limit,
            refill_rate=limit / 60,  # per minute -> per second
        )

    def get_connection_bucket(self, sid: str) -> RateLimitBucket:
        """Get or create bucket for WebSocket connection."""
        bucket = self._connection_buckets.get(sid)
        if bucket is None:
            bucket = self._new_bucket(
                self._settings.rate_limit_websocket_per_minute
            )
            self._connection_buckets[sid] = bucket
        return bucket

    def get_game_bucket(self, game_id: str, action: str) -> RateLimitBucket:
        """
        Get or create bucket for game-specific action.

        Actions: 'decision', 'roll', 'substitution'. Any other action name
        gets ``DEFAULT_GAME_ACTION_LIMIT`` requests per minute.
        """
        key = f"{game_id}:{action}"
        bucket = self._game_buckets.get(key)
        if bucket is None:
            setting_name = self._GAME_ACTION_SETTINGS.get(action)
            if setting_name is None:
                limit = self.DEFAULT_GAME_ACTION_LIMIT
            else:
                limit = getattr(self._settings, setting_name)
            bucket = self._new_bucket(limit)
            self._game_buckets[key] = bucket
        return bucket

    def get_user_bucket(self, user_id: int) -> RateLimitBucket:
        """Get or create bucket for API user."""
        bucket = self._user_buckets.get(user_id)
        if bucket is None:
            bucket = self._new_bucket(self._settings.rate_limit_api_per_minute)
            self._user_buckets[user_id] = bucket
        return bucket

    async def check_websocket_limit(self, sid: str) -> bool:
        """
        Check if WebSocket event is allowed for connection.

        Consumes one token on success. Returns True if allowed, False if
        rate limited.
        """
        allowed = self.get_connection_bucket(sid).consume()
        if not allowed:
            logger.warning("Rate limited WebSocket connection: %s", sid)
        return allowed

    async def check_game_limit(self, game_id: str, action: str) -> bool:
        """
        Check if game action is allowed.

        Consumes one token on success. Returns True if allowed, False if
        rate limited.
        """
        allowed = self.get_game_bucket(game_id, action).consume()
        if not allowed:
            logger.warning("Rate limited game action: %s %s", game_id, action)
        return allowed

    async def check_api_limit(self, user_id: int) -> bool:
        """
        Check if API call is allowed for user.

        Consumes one token on success. Returns True if allowed, False if
        rate limited.
        """
        allowed = self.get_user_bucket(user_id).consume()
        if not allowed:
            logger.warning("Rate limited API user: %s", user_id)
        return allowed

    def remove_connection(self, sid: str) -> None:
        """Clean up buckets when connection closes."""
        self._connection_buckets.pop(sid, None)

    def remove_game(self, game_id: str) -> None:
        """Clean up all buckets for a game when it ends."""
        prefix = f"{game_id}:"
        # Materialize the keys first; deleting while iterating a dict raises.
        keys_to_remove = [key for key in self._game_buckets if key.startswith(prefix)]
        for key in keys_to_remove:
            del self._game_buckets[key]

    @staticmethod
    def _evict_stale(
        buckets: dict, now: pendulum.DateTime, threshold_seconds: float
    ) -> int:
        """Delete buckets whose last refill is older than the threshold; return the count."""
        stale_keys = [
            key
            for key, bucket in buckets.items()
            if (now - bucket.last_refill).total_seconds() > threshold_seconds
        ]
        for key in stale_keys:
            del buckets[key]
        return len(stale_keys)

    async def cleanup_stale_buckets(
        self, stale_threshold_seconds: float = DEFAULT_STALE_THRESHOLD_SECONDS
    ) -> None:
        """
        Background task to periodically clean up stale buckets.

        Every ``rate_limit_cleanup_interval`` seconds, removes buckets whose
        ``last_refill`` is older than ``stale_threshold_seconds`` (default 10
        minutes). This prevents memory leaks from abandoned connections.

        Buckets created by this class refill their full capacity in 60
        seconds (``refill_rate = limit / 60``). Evicting one that has been
        idle longer than that is therefore equivalent to recreating it full
        on next use.

        Runs until cancelled. Cancellation is logged and the loop returns
        normally (the CancelledError is not re-raised). Any other exception
        is logged and the loop keeps running.

        Args:
            stale_threshold_seconds: Idle time after which a bucket is evicted.
        """
        interval = self._settings.rate_limit_cleanup_interval

        logger.info("Starting rate limiter cleanup task (interval: %ss)", interval)

        while True:
            try:
                await asyncio.sleep(interval)
                now = pendulum.now("UTC")

                evicted_connections = self._evict_stale(
                    self._connection_buckets, now, stale_threshold_seconds
                )
                evicted_games = self._evict_stale(
                    self._game_buckets, now, stale_threshold_seconds
                )
                evicted_users = self._evict_stale(
                    self._user_buckets, now, stale_threshold_seconds
                )

                if evicted_connections or evicted_games or evicted_users:
                    logger.debug(
                        "Rate limiter cleanup: %d connections, %d games, %d users",
                        evicted_connections,
                        evicted_games,
                        evicted_users,
                    )

            except asyncio.CancelledError:
                logger.info("Rate limiter cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Rate limiter cleanup error: {e}", exc_info=True)
                # Continue running despite errors

    def get_stats(self) -> dict:
        """Get rate limiter statistics (bucket counts) for monitoring."""
        return {
            "connection_buckets": len(self._connection_buckets),
            "game_buckets": len(self._game_buckets),
            "user_buckets": len(self._user_buckets),
        }
|
|
|
|
|
|
# Module-wide shared limiter instance, created at import time. The
# rate_limited decorator in this module looks it up each time a wrapped
# handler runs.
rate_limiter: RateLimiter = RateLimiter()
|
|
|
|
|
|
def rate_limited(action: str = "general"):
    """
    Decorator for rate-limited WebSocket handlers.

    Checks the connection-level limit first. For actions other than
    ``'general'``, it also checks the game-level limit keyed by
    ``data["game_id"]``. Only then is the handler allowed to execute. When a
    limit is exceeded the handler is not called and an error dict is
    returned instead.

    Args:
        action: Type of action for game-level limits.
            Options: 'decision', 'roll', 'substitution', 'general'.
            'general' only applies the connection-level limit. The game-level
            check is also skipped when ``data`` is not a dict or has no
            truthy ``game_id``.

    Returns:
        A decorator for async handlers called as ``(sid, data=None, ...)``.
        The wrapped handler returns
        ``{"error": "rate_limited", ...}`` or
        ``{"error": "game_rate_limited", ...}`` when throttled, and the
        original handler's result otherwise.

    Usage:
        @sio.event
        @rate_limited(action="decision")
        async def submit_defensive_decision(sid, data):
            ...

    Note: List ``@rate_limited`` *below* ``@sio.event``. Decorators apply
    bottom-up, so the handler is wrapped first and the wrapper is what gets
    registered. ``functools.wraps`` keeps the handler's original
    ``__name__`` on the wrapper.
    """

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(sid: str, data=None, *args, **kwargs):
            # Connection-level limit applies to every decorated event.
            if not await rate_limiter.check_websocket_limit(sid):
                logger.warning("Rate limited handler call from %s", sid)
                return {"error": "rate_limited", "message": "Rate limited. Please slow down."}

            # Game-level limit applies only when the payload names a game.
            if action != "general" and isinstance(data, dict):
                game_id = data.get("game_id")
                if game_id and not await rate_limiter.check_game_limit(
                    str(game_id), action
                ):
                    logger.warning(
                        "Rate limited game action %s for game %s", action, game_id
                    )
                    return {
                        "error": "game_rate_limited",
                        "message": f"Too many {action} requests for this game.",
                    }

            return await func(sid, data, *args, **kwargs)

        return wrapper

    return decorator
|