""" Unit tests for rate limiting middleware. Tests the token bucket algorithm implementation and rate limiter functionality for WebSocket and API rate limiting. Author: Claude Date: 2025-11-27 """ import asyncio from unittest.mock import patch import pendulum import pytest from app.middleware.rate_limit import RateLimitBucket, RateLimiter, rate_limiter class TestRateLimitBucket: """Tests for the RateLimitBucket token bucket implementation.""" def test_bucket_initialization(self): """ Test that bucket initializes with correct values. Bucket should start with max_tokens available and record the creation time as last_refill. """ bucket = RateLimitBucket( tokens=10, max_tokens=10, refill_rate=1.0 ) assert bucket.tokens == 10 assert bucket.max_tokens == 10 assert bucket.refill_rate == 1.0 assert bucket.last_refill is not None def test_consume_success(self): """ Test successful token consumption. Should reduce available tokens and return True when enough tokens are available. """ bucket = RateLimitBucket( tokens=10, max_tokens=10, refill_rate=1.0 ) result = bucket.consume(1) assert result is True assert bucket.tokens < 10 # Reduced (may have refilled slightly) def test_consume_multiple_tokens(self): """ Test consuming multiple tokens at once. Should consume the requested amount if available. """ bucket = RateLimitBucket( tokens=10, max_tokens=10, refill_rate=0.0 # No refill for test ) result = bucket.consume(5) assert result is True assert bucket.tokens == 5 def test_consume_denied_when_empty(self): """ Test that consumption is denied when not enough tokens. Should return False and not reduce tokens below zero. """ bucket = RateLimitBucket( tokens=0, max_tokens=10, refill_rate=0.0 # No refill ) result = bucket.consume(1) assert result is False assert bucket.tokens == 0 # Unchanged def test_consume_denied_insufficient_tokens(self): """ Test denial when requesting more tokens than available. Should return False without modifying token count. """ bucket = RateLimitBucket( tokens=3, max_tokens=10, refill_rate=0.0 ) result = bucket.consume(5) assert result is False assert bucket.tokens == 3 # Unchanged def test_token_refill_over_time(self): """ Test that tokens refill based on elapsed time. Tokens should accumulate at refill_rate per second. """ # Start with 5 tokens past_time = pendulum.now("UTC").subtract(seconds=5) bucket = RateLimitBucket( tokens=5, max_tokens=10, refill_rate=1.0, # 1 token per second last_refill=past_time, ) # Force refill by consuming (which triggers refill check) bucket.consume(0) # Should have refilled ~5 tokens (5 seconds * 1 token/sec) assert bucket.tokens >= 9.5 # Allow small timing variance def test_tokens_capped_at_max(self): """ Test that tokens don't exceed max_tokens after refill. Token count should never go above the configured maximum. """ past_time = pendulum.now("UTC").subtract(seconds=100) bucket = RateLimitBucket( tokens=5, max_tokens=10, refill_rate=1.0, # Would add 100 tokens last_refill=past_time, ) bucket.consume(0) # Trigger refill assert bucket.tokens == 10 # Capped at max def test_get_wait_time_when_available(self): """ Test wait time calculation when tokens available. Should return 0.0 when at least one token is available. """ bucket = RateLimitBucket( tokens=5, max_tokens=10, refill_rate=1.0 ) wait_time = bucket.get_wait_time() assert wait_time == 0.0 def test_get_wait_time_when_empty(self): """ Test wait time calculation when no tokens available. Should return time needed for one token to refill. """ bucket = RateLimitBucket( tokens=0, max_tokens=10, refill_rate=2.0 # 2 tokens per second ) wait_time = bucket.get_wait_time() assert wait_time == 0.5 # 1 token / 2 tokens per second def test_get_wait_time_partial_tokens(self): """ Test wait time with partial tokens available. Should calculate time to reach 1.0 token. """ bucket = RateLimitBucket( tokens=0.5, max_tokens=10, refill_rate=1.0 ) wait_time = bucket.get_wait_time() assert wait_time == 0.5 # Need 0.5 more tokens at 1/sec class TestRateLimiter: """Tests for the RateLimiter class.""" @pytest.fixture def limiter(self): """Create a fresh RateLimiter instance for each test.""" return RateLimiter() def test_get_connection_bucket_creates_new(self, limiter): """ Test that get_connection_bucket creates bucket for new sid. Should create and return a bucket initialized with settings. """ bucket = limiter.get_connection_bucket("test_sid_1") assert bucket is not None assert bucket.max_tokens > 0 def test_get_connection_bucket_returns_existing(self, limiter): """ Test that same bucket is returned for same sid. Should return the same bucket instance for subsequent calls. """ bucket1 = limiter.get_connection_bucket("test_sid_1") bucket2 = limiter.get_connection_bucket("test_sid_1") assert bucket1 is bucket2 def test_get_connection_bucket_different_sids(self, limiter): """ Test that different sids get different buckets. Each connection should have its own rate limit bucket. """ bucket1 = limiter.get_connection_bucket("sid_1") bucket2 = limiter.get_connection_bucket("sid_2") assert bucket1 is not bucket2 def test_get_game_bucket_creates_new(self, limiter): """ Test that get_game_bucket creates bucket for new game+action. Should create bucket keyed by game_id and action type. """ bucket = limiter.get_game_bucket("game_123", "decision") assert bucket is not None assert bucket.max_tokens > 0 def test_get_game_bucket_different_actions(self, limiter): """ Test that different actions get different buckets for same game. Each action type has its own rate limit within a game. """ decision_bucket = limiter.get_game_bucket("game_123", "decision") roll_bucket = limiter.get_game_bucket("game_123", "roll") assert decision_bucket is not roll_bucket def test_get_game_bucket_different_games(self, limiter): """ Test that same action on different games gets different buckets. Each game tracks its own rate limits. """ bucket1 = limiter.get_game_bucket("game_1", "decision") bucket2 = limiter.get_game_bucket("game_2", "decision") assert bucket1 is not bucket2 def test_get_user_bucket_creates_new(self, limiter): """ Test that get_user_bucket creates bucket for new user. Should create bucket for API rate limiting per user. """ bucket = limiter.get_user_bucket(123) assert bucket is not None assert bucket.max_tokens > 0 @pytest.mark.asyncio async def test_check_websocket_limit_allowed(self, limiter): """ Test that check_websocket_limit allows requests with tokens. Should return True when tokens are available. """ result = await limiter.check_websocket_limit("test_sid") assert result is True @pytest.mark.asyncio async def test_check_websocket_limit_denied_when_exhausted(self, limiter): """ Test that check_websocket_limit denies when rate limited. Should return False after tokens are exhausted. """ sid = "rate_limited_sid" # Exhaust all tokens bucket = limiter.get_connection_bucket(sid) bucket.tokens = 0 result = await limiter.check_websocket_limit(sid) assert result is False @pytest.mark.asyncio async def test_check_game_limit_allowed(self, limiter): """ Test that check_game_limit allows requests with tokens. Should return True when game action limit not reached. """ result = await limiter.check_game_limit("game_123", "decision") assert result is True @pytest.mark.asyncio async def test_check_game_limit_denied_when_exhausted(self, limiter): """ Test that check_game_limit denies when rate limited. Should return False after game action tokens exhausted. """ game_id = "rate_limited_game" bucket = limiter.get_game_bucket(game_id, "roll") bucket.tokens = 0 result = await limiter.check_game_limit(game_id, "roll") assert result is False @pytest.mark.asyncio async def test_check_api_limit_allowed(self, limiter): """ Test that check_api_limit allows requests with tokens. Should return True when user API limit not reached. """ result = await limiter.check_api_limit(123) assert result is True @pytest.mark.asyncio async def test_check_api_limit_denied_when_exhausted(self, limiter): """ Test that check_api_limit denies when rate limited. Should return False after user API tokens exhausted. """ user_id = 999 bucket = limiter.get_user_bucket(user_id) bucket.tokens = 0 result = await limiter.check_api_limit(user_id) assert result is False def test_remove_connection_cleans_up(self, limiter): """ Test that remove_connection removes bucket for sid. Should clean up bucket when connection closes. """ sid = "cleanup_test_sid" limiter.get_connection_bucket(sid) # Create bucket assert sid in limiter._connection_buckets limiter.remove_connection(sid) assert sid not in limiter._connection_buckets def test_remove_connection_handles_missing(self, limiter): """ Test that remove_connection handles non-existent sid gracefully. Should not raise error for unknown sid. """ limiter.remove_connection("nonexistent_sid") # Should not raise def test_remove_game_cleans_up_all_actions(self, limiter): """ Test that remove_game removes all buckets for a game. Should clean up decision, roll, substitution buckets for game. """ game_id = "cleanup_game" # Create buckets for different actions limiter.get_game_bucket(game_id, "decision") limiter.get_game_bucket(game_id, "roll") limiter.get_game_bucket(game_id, "substitution") assert f"{game_id}:decision" in limiter._game_buckets assert f"{game_id}:roll" in limiter._game_buckets assert f"{game_id}:substitution" in limiter._game_buckets limiter.remove_game(game_id) assert f"{game_id}:decision" not in limiter._game_buckets assert f"{game_id}:roll" not in limiter._game_buckets assert f"{game_id}:substitution" not in limiter._game_buckets def test_remove_game_preserves_other_games(self, limiter): """ Test that remove_game only removes specified game's buckets. Other games' buckets should remain intact. """ limiter.get_game_bucket("game_1", "decision") limiter.get_game_bucket("game_2", "decision") limiter.remove_game("game_1") assert "game_1:decision" not in limiter._game_buckets assert "game_2:decision" in limiter._game_buckets def test_get_stats_returns_counts(self, limiter): """ Test that get_stats returns bucket counts. Should return counts for all bucket types. """ limiter.get_connection_bucket("sid_1") limiter.get_connection_bucket("sid_2") limiter.get_game_bucket("game_1", "decision") limiter.get_user_bucket(123) stats = limiter.get_stats() assert stats["connection_buckets"] == 2 assert stats["game_buckets"] == 1 assert stats["user_buckets"] == 1 class TestCleanupTask: """Tests for the stale bucket cleanup functionality.""" @pytest.mark.asyncio async def test_cleanup_removes_stale_connection_buckets(self): """ Test that cleanup removes connection buckets older than threshold. Buckets unused for >10 minutes should be cleaned up. """ limiter = RateLimiter() # Create a bucket with old last_refill time (>10 minutes ago) stale_time = pendulum.now("UTC").subtract(minutes=15) bucket = limiter.get_connection_bucket("stale_sid") bucket.last_refill = stale_time # Create a fresh bucket fresh_bucket = limiter.get_connection_bucket("fresh_sid") # Manually trigger cleanup logic (without running full async task) now = pendulum.now("UTC") stale_sids = [ sid for sid, b in limiter._connection_buckets.items() if (now - b.last_refill).total_seconds() > 600 ] for sid in stale_sids: del limiter._connection_buckets[sid] assert "stale_sid" not in limiter._connection_buckets assert "fresh_sid" in limiter._connection_buckets @pytest.mark.asyncio async def test_cleanup_removes_stale_game_buckets(self): """ Test that cleanup removes game buckets older than threshold. Abandoned game buckets should be cleaned up. """ limiter = RateLimiter() stale_time = pendulum.now("UTC").subtract(minutes=15) bucket = limiter.get_game_bucket("stale_game", "decision") bucket.last_refill = stale_time limiter.get_game_bucket("fresh_game", "decision") # Manually trigger cleanup now = pendulum.now("UTC") stale_keys = [ key for key, b in limiter._game_buckets.items() if (now - b.last_refill).total_seconds() > 600 ] for key in stale_keys: del limiter._game_buckets[key] assert "stale_game:decision" not in limiter._game_buckets assert "fresh_game:decision" in limiter._game_buckets class TestGlobalRateLimiter: """Tests for the global rate_limiter instance.""" def test_global_instance_exists(self): """ Test that global rate_limiter instance is available. Should be importable and usable as singleton. """ assert rate_limiter is not None assert isinstance(rate_limiter, RateLimiter) @pytest.mark.asyncio async def test_global_instance_functional(self): """ Test that global rate_limiter works correctly. Should accept and rate limit requests. """ result = await rate_limiter.check_websocket_limit("global_test_sid") assert result is True # Cleanup rate_limiter.remove_connection("global_test_sid") class TestRateLimitConfiguration: """Tests for rate limit configuration integration.""" def test_connection_bucket_uses_config(self): """ Test that connection bucket uses configured limits. max_tokens should match rate_limit_websocket_per_minute setting. """ limiter = RateLimiter() bucket = limiter.get_connection_bucket("config_test_sid") # Default is 120 per minute assert bucket.max_tokens == limiter._settings.rate_limit_websocket_per_minute limiter.remove_connection("config_test_sid") def test_decision_bucket_uses_config(self): """ Test that decision bucket uses configured limits. max_tokens should match rate_limit_decision_per_game setting. """ limiter = RateLimiter() bucket = limiter.get_game_bucket("config_game", "decision") assert bucket.max_tokens == limiter._settings.rate_limit_decision_per_game limiter.remove_game("config_game") def test_roll_bucket_uses_config(self): """ Test that roll bucket uses configured limits. max_tokens should match rate_limit_roll_per_game setting. """ limiter = RateLimiter() bucket = limiter.get_game_bucket("config_game", "roll") assert bucket.max_tokens == limiter._settings.rate_limit_roll_per_game limiter.remove_game("config_game") def test_substitution_bucket_uses_config(self): """ Test that substitution bucket uses configured limits. max_tokens should match rate_limit_substitution_per_game setting. """ limiter = RateLimiter() bucket = limiter.get_game_bucket("config_game", "substitution") assert bucket.max_tokens == limiter._settings.rate_limit_substitution_per_game limiter.remove_game("config_game") def test_api_bucket_uses_config(self): """ Test that API bucket uses configured limits. max_tokens should match rate_limit_api_per_minute setting. """ limiter = RateLimiter() bucket = limiter.get_user_bucket(12345) assert bucket.max_tokens == limiter._settings.rate_limit_api_per_minute del limiter._user_buckets[12345]