# Plan 006: Add Rate Limiting **Priority**: HIGH **Effort**: 2-3 hours **Status**: NOT STARTED **Risk Level**: MEDIUM - DoS vulnerability --- ## Problem Statement No rate limiting exists on WebSocket events or REST API endpoints. A malicious or buggy client can: - Spam decision submissions - Flood dice roll requests - Overwhelm the server with requests - Cause denial of service ## Impact - **Availability**: Server can be overwhelmed - **Fairness**: Spammers can disrupt games - **Cost**: Excessive resource usage ## Files to Modify/Create | File | Action | |------|--------| | `backend/app/middleware/rate_limit.py` | Create rate limiter and handler decorator | | `backend/app/websocket/handlers.py` | Add rate limit checks | | `backend/app/api/routes.py` | Add rate limit dependency | | `backend/app/config.py` | Add rate limit settings | | `backend/app/main.py` | Start/stop bucket cleanup task | | `backend/tests/unit/middleware/test_rate_limit.py` | Create unit tests | ## Implementation Steps ### Step 1: Add Configuration (10 min) Update `backend/app/config.py`: ```python class Settings(BaseSettings): # ... existing settings ... # Rate limiting rate_limit_websocket_per_minute: int = 60 # Events per minute per connection rate_limit_api_per_minute: int = 100 # API calls per minute per user rate_limit_decision_per_game: int = 10 # Decisions per minute per game rate_limit_roll_per_game: int = 20 # Rolls per minute per game ``` ### Step 2: Create Rate Limiter (45 min) Create `backend/app/middleware/rate_limit.py`: ```python """Rate limiting utilities for WebSocket and API endpoints.""" import asyncio from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Callable import logging from app.config import settings logger = logging.getLogger(f"{__name__}.RateLimiter") @dataclass class RateLimitBucket: """Token bucket for rate limiting.""" tokens: int max_tokens: int refill_rate: float # tokens per second last_refill: datetime = field(default_factory=datetime.utcnow) def consume(self, tokens: int = 1) -> bool: """ Try to consume tokens.
Returns True if allowed, False if rate limited. """ self._refill() if self.tokens >= tokens: self.tokens -= tokens return True return False def _refill(self): """ Refill tokens based on time elapsed. Only whole tokens are added, so last_refill is advanced by exactly the time those tokens account for and the leftover fraction carries over to the next call. Resetting last_refill to now would throw that fraction away: at 10 tokens/minute, a client calling every 5 seconds would only earn 6 tokens/minute. """ now = datetime.utcnow() if self.tokens >= self.max_tokens: # Full bucket: restart the clock so idle time is not banked beyond capacity self.last_refill = now return elapsed = (now - self.last_refill).total_seconds() refill_amount = int(elapsed * self.refill_rate) if refill_amount > 0: self.tokens = min(self.max_tokens, self.tokens + refill_amount) self.last_refill += timedelta(seconds=refill_amount / self.refill_rate) class RateLimiter: """ Rate limiter for WebSocket connections and API endpoints. Uses token bucket algorithm for smooth rate limiting. """ def __init__(self): # Per-connection buckets self._connection_buckets: dict[str, RateLimitBucket] = {} # Per-game buckets (for game-specific limits) self._game_buckets: dict[str, RateLimitBucket] = {} # Per-user API buckets self._user_buckets: dict[int, RateLimitBucket] = {} # Cleanup task self._cleanup_task: asyncio.Task | None = None def get_connection_bucket(self, sid: str) -> RateLimitBucket: """Get or create bucket for WebSocket connection.""" if sid not in self._connection_buckets: self._connection_buckets[sid] = RateLimitBucket( tokens=settings.rate_limit_websocket_per_minute, max_tokens=settings.rate_limit_websocket_per_minute, refill_rate=settings.rate_limit_websocket_per_minute / 60 ) return self._connection_buckets[sid] def get_game_bucket(self, game_id: str, action: str) -> RateLimitBucket: """Get or create bucket for game-specific action.""" key = f"{game_id}:{action}" if key not in self._game_buckets: if action == "decision": limit = settings.rate_limit_decision_per_game elif action == "roll": limit = settings.rate_limit_roll_per_game else: limit = 30 # Default self._game_buckets[key] = RateLimitBucket( tokens=limit, max_tokens=limit, refill_rate=limit / 60 ) return self._game_buckets[key] def get_user_bucket(self, user_id: int) -> RateLimitBucket: """Get or create bucket for API user.""" if user_id not in self._user_buckets: self._user_buckets[user_id] = RateLimitBucket(
tokens=settings.rate_limit_api_per_minute, max_tokens=settings.rate_limit_api_per_minute, refill_rate=settings.rate_limit_api_per_minute / 60 ) return self._user_buckets[user_id] async def check_websocket_limit(self, sid: str) -> bool: """Check if WebSocket event is allowed.""" bucket = self.get_connection_bucket(sid) allowed = bucket.consume() if not allowed: logger.warning(f"Rate limited WebSocket connection: {sid}") return allowed async def check_game_limit(self, game_id: str, action: str) -> bool: """Check if game action is allowed.""" bucket = self.get_game_bucket(game_id, action) allowed = bucket.consume() if not allowed: logger.warning(f"Rate limited game action: {game_id} {action}") return allowed async def check_api_limit(self, user_id: int) -> bool: """Check if API call is allowed.""" bucket = self.get_user_bucket(user_id) allowed = bucket.consume() if not allowed: logger.warning(f"Rate limited API user: {user_id}") return allowed def remove_connection(self, sid: str): """Clean up when connection closes.""" self._connection_buckets.pop(sid, None) async def cleanup_stale_buckets(self): """ Periodically clean up stale buckets. Runs until the task is cancelled. Every bucket created above refills to capacity within 60 seconds (refill_rate = limit / 60), so dropping one that has not been refilled for 10 minutes just discards a bucket that would be full anyway. """ while True: await asyncio.sleep(300) # Every 5 minutes now = datetime.utcnow() stale_threshold = timedelta(minutes=10) # Clean connection buckets stale_connections = [ sid for sid, bucket in self._connection_buckets.items() if now - bucket.last_refill > stale_threshold ] for sid in stale_connections: del self._connection_buckets[sid] # Clean game buckets stale_games = [ key for key, bucket in self._game_buckets.items() if now - bucket.last_refill > stale_threshold ] for key in stale_games: del self._game_buckets[key] # Clean per-user API buckets (nothing else removes them, so without this they grow with every user ever seen) stale_users = [ user_id for user_id, bucket in self._user_buckets.items() if now - bucket.last_refill > stale_threshold ] for user_id in stale_users: del self._user_buckets[user_id] logger.debug( f"Cleaned {len(stale_connections)} connection, {len(stale_games)} game, " f"{len(stale_users)} user buckets" ) # Global rate limiter instance rate_limiter = RateLimiter() ``` ### Step 3: Create Decorator for Handlers (20 min) Add to `backend/app/middleware/rate_limit.py`: ```python from functools import wraps def rate_limited(action:
str = "general"): """ Decorator for rate-limited WebSocket handlers. Checks the per-connection limit first, then - when the payload is a dict with a truthy "game_id" and action is not "general" - the per-game limit for that action. Rejected events get an "error" emit and the handler is not called. Game buckets are keyed by game_id and action only, so both participants in a game share one budget per action. Usage: @sio.event @rate_limited(action="decision") async def submit_defensive_decision(sid, data): ... Keep @sio.event outermost: it registers the function it receives under that function's name, which @wraps copies onto the wrapper. """ async def _emit_error(sid, payload): # `sio` is not defined in this module. Import it lazily from handlers.py (which has it in scope for @sio.event); a top-level import would be circular because handlers.py imports this module. Adjust the path if the server object lives elsewhere. from app.websocket.handlers import sio await sio.emit("error", payload, to=sid) def decorator(func: Callable): @wraps(func) async def wrapper(sid, data, *args, **kwargs): # Check connection-level limit if not await rate_limiter.check_websocket_limit(sid): await _emit_error(sid, { "message": "Rate limited. Please slow down.", "code": "RATE_LIMITED" }) return # Check game-level limit if game_id in data game_id = data.get("game_id") if isinstance(data, dict) else None if game_id and action != "general": if not await rate_limiter.check_game_limit(str(game_id), action): await _emit_error(sid, { "message": f"Too many {action} requests for this game.", "code": "GAME_RATE_LIMITED" }) return return await func(sid, data, *args, **kwargs) return wrapper return decorator ``` ### Step 4: Apply to WebSocket Handlers (30 min) Update `backend/app/websocket/handlers.py`: ```python from app.middleware.rate_limit import rate_limited, rate_limiter @sio.event async def connect(sid, environ, auth): # ... existing logic ... pass @sio.event async def disconnect(sid): # Clean up rate limiter rate_limiter.remove_connection(sid) # ... existing logic ... @sio.event @rate_limited(action="decision") async def submit_defensive_decision(sid, data): # ... existing logic (rate limiting handled by decorator) ... pass @sio.event @rate_limited(action="decision") async def submit_offensive_decision(sid, data): # ... existing logic ... pass @sio.event @rate_limited(action="roll") async def roll_dice(sid, data): # ... existing logic ... pass @sio.event @rate_limited(action="substitution") async def request_pinch_hitter(sid, data): # ... existing logic ... pass @sio.event @rate_limited(action="substitution") async def request_defensive_replacement(sid, data): # ... existing logic ... pass @sio.event @rate_limited(action="substitution") async def request_pitching_change(sid, data): # ...
existing logic ... pass # Read-only handlers get general rate limit @sio.event @rate_limited() async def get_lineup(sid, data): # ... existing logic ... pass @sio.event @rate_limited() async def get_box_score(sid, data): # ... existing logic ... pass ``` ### Step 5: Add API Rate Limiting (20 min) Update `backend/app/api/routes.py`: ```python from fastapi import Depends, HTTPException from app.middleware.rate_limit import rate_limiter # NOTE: assumes get_current_user_id is the existing auth dependency that resolves the caller's user ID - confirm it is imported here async def check_rate_limit(user_id: int = Depends(get_current_user_id)): """Dependency for API rate limiting.""" if not await rate_limiter.check_api_limit(user_id): raise HTTPException( status_code=429, detail="Rate limit exceeded. Please try again later." ) return user_id @router.post("/games", dependencies=[Depends(check_rate_limit)]) async def create_game(...): # ... existing logic ... pass @router.get("/games/{game_id}", dependencies=[Depends(check_rate_limit)]) async def get_game(...): # ... existing logic ... pass ``` ### Step 6: Start Cleanup Task (10 min) Update `backend/app/main.py` (merge into the existing lifespan if main.py already defines one): ```python import asyncio from contextlib import asynccontextmanager, suppress from fastapi import FastAPI from app.middleware.rate_limit import rate_limiter @asynccontextmanager async def lifespan(app: FastAPI): # Start rate limiter cleanup cleanup_task = asyncio.create_task(rate_limiter.cleanup_stale_buckets()) try: yield finally: # Stop cleanup task and wait for the cancellation to finish so the task is not left pending when the event loop shuts down cleanup_task.cancel() with suppress(asyncio.CancelledError): await cleanup_task ``` ### Step 7: Write Tests (30 min) Create `backend/tests/unit/middleware/test_rate_limit.py`: ```python import pytest from app.middleware.rate_limit import RateLimiter, RateLimitBucket class TestRateLimiting: """Tests for rate limiting.""" def test_bucket_allows_under_limit(self): """Bucket allows requests under limit.""" bucket = RateLimitBucket(tokens=10, max_tokens=10, refill_rate=1) assert bucket.consume() is True assert bucket.tokens == 9 def test_bucket_denies_over_limit(self): """Bucket denies requests over limit.""" bucket = RateLimitBucket(tokens=1, max_tokens=10, refill_rate=0.1) assert bucket.consume() is True assert bucket.consume() is False def
test_bucket_refills_over_time(self): """Bucket refills tokens over time.""" from datetime import timedelta bucket = RateLimitBucket(tokens=0, max_tokens=10, refill_rate=100) # Simulate one second passing (datetime.replace(second=second - 1) raises ValueError whenever the current second is 0, which made this test flaky) bucket.last_refill -= timedelta(seconds=1) bucket._refill() # 100 tokens/s for 1 s, capped at max_tokens assert bucket.tokens == 10 def test_rate_limiter_tracks_connections(self): """Rate limiter tracks separate connections.""" limiter = RateLimiter() # Different connections get different buckets bucket1 = limiter.get_connection_bucket("sid1") bucket2 = limiter.get_connection_bucket("sid2") assert bucket1 is not bucket2 # The same connection reuses its bucket assert limiter.get_connection_bucket("sid1") is bucket1 def test_rate_limiter_cleans_up_on_disconnect(self): """Rate limiter cleans up on disconnect.""" limiter = RateLimiter() limiter.get_connection_bucket("sid1") assert "sid1" in limiter._connection_buckets limiter.remove_connection("sid1") assert "sid1" not in limiter._connection_buckets ``` ## Verification Checklist - [ ] WebSocket events are rate limited - [ ] Game-specific limits work (decisions, rolls) - [ ] API endpoints are rate limited - [ ] Rate limit errors return clear messages - [ ] Cleanup removes stale buckets - [ ] Tests pass ## Monitoring After deployment, monitor: - Rate limit hit frequency in logs - Log volume: every rejected event logs a WARNING, so a flooding client also floods the logs; consider sampling or aggregating these messages - Memory usage of rate limiter - False positive rate (legitimate users blocked) ## Rollback Plan If issues arise: 1. Increase rate limits in config 2. Disable decorator temporarily 3. Remove rate limit checks from handlers ## Dependencies - None (can be implemented independently) ## Notes - Consider Redis-backed rate limiting for horizontal scaling (buckets live in process memory, so each worker process enforces its own limits) - May want different limits for authenticated vs anonymous - Future: Add configurable rate limits per user tier - WebSocket buckets are keyed by `sid` and removed on disconnect, so a client can reset its budget by reconnecting; consider keying on the authenticated user ID - Game buckets are keyed by `game_id` and action only, so one player spamming can exhaust the opponent's decision/roll budget; consider including the user in the key - `datetime.utcnow()` is deprecated since Python 3.12 and follows the wall clock; `time.monotonic()` is a better fit for measuring refill intervals