strat-gameplay-webapp/backend/tests/unit/middleware/test_rate_limit.py
Cal Corum 2a392b87f8 CLAUDE: Add rate limiting, pool monitoring, and exception infrastructure
- Add rate_limit.py middleware with per-client throttling and cleanup task
- Add pool_monitor.py for database connection pool health monitoring
- Add custom exceptions module (GameEngineError, DatabaseError, etc.)
- Add config settings for eviction intervals, session timeouts, memory limits
- Add unit tests for rate limiting and pool monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-28 12:06:10 -06:00

584 lines
18 KiB
Python

"""
Unit tests for rate limiting middleware.
Tests the token bucket algorithm implementation and rate limiter functionality
for WebSocket and API rate limiting.
Author: Claude
Date: 2025-11-27
"""
import asyncio
from unittest.mock import patch
import pendulum
import pytest
from app.middleware.rate_limit import RateLimitBucket, RateLimiter, rate_limiter
class TestRateLimitBucket:
    """Tests for the RateLimitBucket token bucket implementation."""

    def test_bucket_initialization(self):
        """
        Test that bucket initializes with correct values.

        Bucket should keep the constructor's token settings and record a
        last_refill timestamp (presumably the creation time).
        """
        bucket = RateLimitBucket(tokens=10, max_tokens=10, refill_rate=1.0)

        assert bucket.tokens == 10
        assert bucket.max_tokens == 10
        assert bucket.refill_rate == 1.0
        assert bucket.last_refill is not None

    def test_consume_success(self):
        """
        Test successful token consumption.

        Should return True and reduce the available tokens by exactly the
        amount consumed. Refill is disabled so the remaining count is
        deterministic; with refill enabled the test could only check that
        *some* tokens were consumed, not that exactly one was.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=0.0  # No refill for test
        )

        result = bucket.consume(1)

        assert result is True
        assert bucket.tokens == 9

    def test_consume_multiple_tokens(self):
        """
        Test consuming multiple tokens at once.

        Should consume the requested amount if available.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=0.0  # No refill for test
        )

        result = bucket.consume(5)

        assert result is True
        assert bucket.tokens == 5

    def test_consume_denied_when_empty(self):
        """
        Test that consumption is denied when not enough tokens.

        Should return False and not reduce tokens below zero.
        """
        bucket = RateLimitBucket(
            tokens=0, max_tokens=10, refill_rate=0.0  # No refill
        )

        result = bucket.consume(1)

        assert result is False
        assert bucket.tokens == 0  # Unchanged

    def test_consume_denied_insufficient_tokens(self):
        """
        Test denial when requesting more tokens than available.

        Should return False without modifying token count.
        """
        bucket = RateLimitBucket(tokens=3, max_tokens=10, refill_rate=0.0)

        result = bucket.consume(5)

        assert result is False
        assert bucket.tokens == 3  # Unchanged

    def test_token_refill_over_time(self):
        """
        Test that tokens refill based on elapsed time.

        Tokens should accumulate at refill_rate per second without ever
        exceeding max_tokens.
        """
        # Start with 5 tokens and a last refill 5 seconds in the past.
        past_time = pendulum.now("UTC").subtract(seconds=5)
        bucket = RateLimitBucket(
            tokens=5,
            max_tokens=10,
            refill_rate=1.0,  # 1 token per second
            last_refill=past_time,
        )

        # Consuming zero tokens still runs the refill step.
        bucket.consume(0)

        # Should have refilled ~5 tokens (5 seconds * 1 token/sec).
        assert bucket.tokens >= 9.5  # Allow small timing variance
        assert bucket.tokens <= bucket.max_tokens

    def test_tokens_capped_at_max(self):
        """
        Test that tokens don't exceed max_tokens after refill.

        Token count should never go above the configured maximum.
        """
        past_time = pendulum.now("UTC").subtract(seconds=100)
        bucket = RateLimitBucket(
            tokens=5,
            max_tokens=10,
            refill_rate=1.0,  # Would add 100 tokens
            last_refill=past_time,
        )

        bucket.consume(0)  # Trigger refill

        assert bucket.tokens == 10  # Capped at max

    def test_get_wait_time_when_available(self):
        """
        Test wait time calculation when tokens available.

        Should return 0.0 when at least one token is available.
        """
        bucket = RateLimitBucket(tokens=5, max_tokens=10, refill_rate=1.0)

        wait_time = bucket.get_wait_time()

        assert wait_time == 0.0

    def test_get_wait_time_when_empty(self):
        """
        Test wait time calculation when no tokens available.

        Should return time needed for one token to refill.
        """
        bucket = RateLimitBucket(
            tokens=0, max_tokens=10, refill_rate=2.0  # 2 tokens per second
        )

        wait_time = bucket.get_wait_time()

        # 1 token / 2 tokens per second; 0.5 is exact in binary floating point.
        assert wait_time == 0.5

    def test_get_wait_time_partial_tokens(self):
        """
        Test wait time with partial tokens available.

        Should calculate time to reach 1.0 token.
        """
        bucket = RateLimitBucket(tokens=0.5, max_tokens=10, refill_rate=1.0)

        wait_time = bucket.get_wait_time()

        assert wait_time == 0.5  # Need 0.5 more tokens at 1/sec
class TestRateLimiter:
    """Behavioral tests for RateLimiter bucket management and limit checks."""

    @pytest.fixture
    def limiter(self):
        """Provide a brand-new RateLimiter so tests never share bucket state."""
        fresh = RateLimiter()
        return fresh

    # --- bucket creation / lookup -------------------------------------------

    def test_get_connection_bucket_creates_new(self, limiter):
        """
        An unseen sid gets a newly created bucket.

        The bucket must exist and have a positive capacity taken from settings.
        """
        created = limiter.get_connection_bucket("test_sid_1")

        assert created is not None
        assert created.max_tokens > 0

    def test_get_connection_bucket_returns_existing(self, limiter):
        """
        Repeated lookups for one sid return the very same bucket object.
        """
        first, second = [
            limiter.get_connection_bucket("test_sid_1") for _ in range(2)
        ]

        assert first is second

    def test_get_connection_bucket_different_sids(self, limiter):
        """
        Distinct sids are throttled independently, so their buckets differ.
        """
        bucket_a = limiter.get_connection_bucket("sid_1")
        bucket_b = limiter.get_connection_bucket("sid_2")

        assert bucket_a is not bucket_b

    def test_get_game_bucket_creates_new(self, limiter):
        """
        A new (game_id, action) pair gets its own freshly created bucket.
        """
        created = limiter.get_game_bucket("game_123", "decision")

        assert created is not None
        assert created.max_tokens > 0

    def test_get_game_bucket_different_actions(self, limiter):
        """
        Within one game, every action type has a separate bucket.
        """
        per_action = {
            action: limiter.get_game_bucket("game_123", action)
            for action in ("decision", "roll")
        }

        assert per_action["decision"] is not per_action["roll"]

    def test_get_game_bucket_different_games(self, limiter):
        """
        The same action in two different games maps to two different buckets.
        """
        game_one = limiter.get_game_bucket("game_1", "decision")
        game_two = limiter.get_game_bucket("game_2", "decision")

        assert game_one is not game_two

    def test_get_user_bucket_creates_new(self, limiter):
        """
        A new user id gets an API bucket with a positive capacity.
        """
        created = limiter.get_user_bucket(123)

        assert created is not None
        assert created.max_tokens > 0

    # --- async limit checks --------------------------------------------------

    @pytest.mark.asyncio
    async def test_check_websocket_limit_allowed(self, limiter):
        """
        A connection that still has tokens is allowed through.
        """
        allowed = await limiter.check_websocket_limit("test_sid")

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_websocket_limit_denied_when_exhausted(self, limiter):
        """
        Once a connection's bucket is drained, further messages are rejected.
        """
        sid = "rate_limited_sid"
        limiter.get_connection_bucket(sid).tokens = 0  # drain the bucket

        allowed = await limiter.check_websocket_limit(sid)

        assert allowed is False

    @pytest.mark.asyncio
    async def test_check_game_limit_allowed(self, limiter):
        """
        A game action whose bucket still has tokens is allowed.
        """
        allowed = await limiter.check_game_limit("game_123", "decision")

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_game_limit_denied_when_exhausted(self, limiter):
        """
        A game action is rejected after its bucket has been drained.
        """
        game_id = "rate_limited_game"
        limiter.get_game_bucket(game_id, "roll").tokens = 0  # drain the bucket

        allowed = await limiter.check_game_limit(game_id, "roll")

        assert allowed is False

    @pytest.mark.asyncio
    async def test_check_api_limit_allowed(self, limiter):
        """
        An API call is allowed while the user's bucket has tokens.
        """
        allowed = await limiter.check_api_limit(123)

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_api_limit_denied_when_exhausted(self, limiter):
        """
        An API call is rejected once the user's bucket has been drained.
        """
        user_id = 999
        limiter.get_user_bucket(user_id).tokens = 0  # drain the bucket

        allowed = await limiter.check_api_limit(user_id)

        assert allowed is False

    # --- cleanup and stats ---------------------------------------------------

    def test_remove_connection_cleans_up(self, limiter):
        """
        Closing a connection discards the bucket that was tracking it.
        """
        sid = "cleanup_test_sid"
        limiter.get_connection_bucket(sid)
        assert sid in limiter._connection_buckets

        limiter.remove_connection(sid)

        assert sid not in limiter._connection_buckets

    def test_remove_connection_handles_missing(self, limiter):
        """
        Removing an unknown sid is a no-op; any exception fails the test.
        """
        limiter.remove_connection("nonexistent_sid")

    def test_remove_game_cleans_up_all_actions(self, limiter):
        """
        remove_game drops the decision, roll and substitution buckets of a game.
        """
        game_id = "cleanup_game"
        actions = ("decision", "roll", "substitution")
        keys = [f"{game_id}:{action}" for action in actions]

        for action in actions:
            limiter.get_game_bucket(game_id, action)
        for key in keys:
            assert key in limiter._game_buckets

        limiter.remove_game(game_id)

        for key in keys:
            assert key not in limiter._game_buckets

    def test_remove_game_preserves_other_games(self, limiter):
        """
        remove_game leaves buckets belonging to other games untouched.
        """
        for game in ("game_1", "game_2"):
            limiter.get_game_bucket(game, "decision")

        limiter.remove_game("game_1")

        assert "game_1:decision" not in limiter._game_buckets
        assert "game_2:decision" in limiter._game_buckets

    def test_get_stats_returns_counts(self, limiter):
        """
        get_stats reports how many buckets of each kind are being tracked.
        """
        for sid in ("sid_1", "sid_2"):
            limiter.get_connection_bucket(sid)
        limiter.get_game_bucket("game_1", "decision")
        limiter.get_user_bucket(123)

        stats = limiter.get_stats()

        expected_counts = {
            "connection_buckets": 2,
            "game_buckets": 1,
            "user_buckets": 1,
        }
        for bucket_kind, count in expected_counts.items():
            assert stats[bucket_kind] == count
class TestCleanupTask:
    """
    Tests for the stale bucket cleanup functionality.

    NOTE(review): these tests apply the eviction rule locally (see
    ``_evict_stale``) instead of running RateLimiter's background cleanup
    task. They therefore pin down the staleness rule and the bucket-key
    layout, but they do not exercise the production cleanup code.
    TODO: drive the limiter's own cleanup routine here once it can be
    invoked deterministically from a test.
    """

    # A bucket whose last_refill is more than this many seconds old counts as
    # stale (the ">10 minutes" rule described in the test docstrings).
    STALE_AFTER_SECONDS = 600

    @classmethod
    def _evict_stale(cls, buckets):
        """
        Delete entries of *buckets* whose last_refill is older than the threshold.

        Age is measured against a single "now" snapshot. Stale keys are
        collected first so the dict is not mutated while it is being iterated.
        """
        now = pendulum.now("UTC")
        stale_keys = [
            key
            for key, bucket in buckets.items()
            if (now - bucket.last_refill).total_seconds() > cls.STALE_AFTER_SECONDS
        ]
        for key in stale_keys:
            del buckets[key]

    def test_cleanup_removes_stale_connection_buckets(self):
        """
        Test that cleanup removes connection buckets older than threshold.

        Buckets unused for >10 minutes should be cleaned up; fresh ones stay.
        """
        limiter = RateLimiter()
        # Backdate one bucket well past the threshold (15 min > 10 min).
        stale_bucket = limiter.get_connection_bucket("stale_sid")
        stale_bucket.last_refill = pendulum.now("UTC").subtract(minutes=15)
        limiter.get_connection_bucket("fresh_sid")

        self._evict_stale(limiter._connection_buckets)

        assert "stale_sid" not in limiter._connection_buckets
        assert "fresh_sid" in limiter._connection_buckets

    def test_cleanup_removes_stale_game_buckets(self):
        """
        Test that cleanup removes game buckets older than threshold.

        Abandoned game buckets should be cleaned up; active ones stay.
        """
        limiter = RateLimiter()
        # Backdate one bucket well past the threshold (15 min > 10 min).
        stale_bucket = limiter.get_game_bucket("stale_game", "decision")
        stale_bucket.last_refill = pendulum.now("UTC").subtract(minutes=15)
        limiter.get_game_bucket("fresh_game", "decision")

        self._evict_stale(limiter._game_buckets)

        assert "stale_game:decision" not in limiter._game_buckets
        assert "fresh_game:decision" in limiter._game_buckets
class TestGlobalRateLimiter:
    """Tests for the global rate_limiter instance."""

    def test_global_instance_exists(self):
        """
        Test that global rate_limiter instance is available.

        Should be importable and usable as singleton.
        """
        assert rate_limiter is not None
        assert isinstance(rate_limiter, RateLimiter)

    @pytest.mark.asyncio
    async def test_global_instance_functional(self):
        """
        Test that global rate_limiter works correctly.

        Should accept and rate limit requests. rate_limiter is a shared
        module-level instance, so the bucket this test creates is removed in
        a ``finally`` block. That way a failed assertion cannot leave state
        behind for later tests.
        """
        sid = "global_test_sid"
        try:
            result = await rate_limiter.check_websocket_limit(sid)
            assert result is True
        finally:
            rate_limiter.remove_connection(sid)
class TestRateLimitConfiguration:
    """Checks that every bucket type takes its capacity from the settings."""

    @staticmethod
    def _assert_game_bucket_sized_by(action, setting_name):
        """
        Check that a game bucket's capacity matches a named setting.

        Builds a "config_game" bucket for *action* on a fresh limiter and
        asserts that its max_tokens equals the limiter setting named
        *setting_name*. The game's buckets are removed afterwards.
        """
        limiter = RateLimiter()
        game_bucket = limiter.get_game_bucket("config_game", action)
        configured = getattr(limiter._settings, setting_name)
        assert game_bucket.max_tokens == configured
        limiter.remove_game("config_game")

    def test_connection_bucket_uses_config(self):
        """
        Connection buckets are sized by rate_limit_websocket_per_minute.
        """
        limiter = RateLimiter()
        conn_bucket = limiter.get_connection_bucket("config_test_sid")
        configured = limiter._settings.rate_limit_websocket_per_minute
        assert conn_bucket.max_tokens == configured
        limiter.remove_connection("config_test_sid")

    def test_decision_bucket_uses_config(self):
        """
        Decision buckets are sized by rate_limit_decision_per_game.
        """
        self._assert_game_bucket_sized_by(
            "decision", "rate_limit_decision_per_game"
        )

    def test_roll_bucket_uses_config(self):
        """
        Roll buckets are sized by rate_limit_roll_per_game.
        """
        self._assert_game_bucket_sized_by("roll", "rate_limit_roll_per_game")

    def test_substitution_bucket_uses_config(self):
        """
        Substitution buckets are sized by rate_limit_substitution_per_game.
        """
        self._assert_game_bucket_sized_by(
            "substitution", "rate_limit_substitution_per_game"
        )

    def test_api_bucket_uses_config(self):
        """
        Per-user API buckets are sized by rate_limit_api_per_minute.
        """
        limiter = RateLimiter()
        user_bucket = limiter.get_user_bucket(12345)
        configured = limiter._settings.rate_limit_api_per_minute
        assert user_bucket.max_tokens == configured
        limiter._user_buckets.pop(12345)