- Add rate_limit.py middleware with per-client throttling and cleanup task
- Add pool_monitor.py for database connection pool health monitoring
- Add custom exceptions module (GameEngineError, DatabaseError, etc.)
- Add config settings for eviction intervals, session timeouts, memory limits
- Add unit tests for rate limiting and pool monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
584 lines
18 KiB
Python
584 lines
18 KiB
Python
"""
|
|
Unit tests for rate limiting middleware.
|
|
|
|
Tests the token bucket algorithm implementation and rate limiter functionality
|
|
for WebSocket and API rate limiting.
|
|
|
|
Author: Claude
|
|
Date: 2025-11-27
|
|
"""
|
|
|
|
import asyncio
|
|
from unittest.mock import patch
|
|
|
|
import pendulum
|
|
import pytest
|
|
|
|
from app.middleware.rate_limit import RateLimitBucket, RateLimiter, rate_limiter
|
|
|
|
|
|
class TestRateLimitBucket:
    """Tests for the RateLimitBucket token bucket implementation.

    Several assertions involve time: a bucket with a non-zero refill_rate can
    gain a tiny fraction of a token between construction and the call under
    test. Such assertions therefore use bounds or ``pytest.approx`` rather
    than exact float equality.
    """

    def test_bucket_initialization(self):
        """
        Test that bucket initializes with correct values.

        The constructor arguments should be stored as given and a
        last_refill timestamp should be recorded.
        """
        bucket = RateLimitBucket(tokens=10, max_tokens=10, refill_rate=1.0)

        assert bucket.tokens == 10
        assert bucket.max_tokens == 10
        assert bucket.refill_rate == 1.0
        assert bucket.last_refill is not None

    def test_consume_success(self):
        """
        Test successful token consumption.

        Should return True and remove one token when enough are available.
        """
        bucket = RateLimitBucket(tokens=10, max_tokens=10, refill_rate=1.0)

        result = bucket.consume(1)

        assert result is True
        # Exactly one token was removed; a refill of a few microseconds' worth
        # may have been added back, but never a whole token, and the bucket
        # was already at max before the consume.
        assert 9 <= bucket.tokens < 10

    def test_consume_multiple_tokens(self):
        """
        Test consuming multiple tokens at once.

        Should consume the requested amount if available.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=0.0  # No refill for test
        )

        result = bucket.consume(5)

        assert result is True
        assert bucket.tokens == 5

    def test_consume_denied_when_empty(self):
        """
        Test that consumption is denied when not enough tokens.

        Should return False and not reduce tokens below zero.
        """
        bucket = RateLimitBucket(
            tokens=0, max_tokens=10, refill_rate=0.0  # No refill
        )

        result = bucket.consume(1)

        assert result is False
        assert bucket.tokens == 0  # Unchanged

    def test_consume_denied_insufficient_tokens(self):
        """
        Test denial when requesting more tokens than available.

        Should return False without modifying token count.
        """
        bucket = RateLimitBucket(tokens=3, max_tokens=10, refill_rate=0.0)

        result = bucket.consume(5)

        assert result is False
        assert bucket.tokens == 3  # Unchanged

    def test_token_refill_over_time(self):
        """
        Test that tokens refill based on elapsed time.

        Tokens should accumulate at refill_rate per second.
        """
        # Start well below max_tokens so the expected total (2 + 5 = 7) stays
        # under the cap. If the expected value equalled max_tokens, capping
        # would hide a wrong refill rate.
        past_time = pendulum.now("UTC").subtract(seconds=5)
        bucket = RateLimitBucket(
            tokens=2,
            max_tokens=10,
            refill_rate=1.0,  # 1 token per second
            last_refill=past_time,
        )

        # consume(0) is used only to trigger the refill computation.
        bucket.consume(0)

        # ~5 seconds * 1 token/sec added to the initial 2 tokens.
        assert bucket.tokens == pytest.approx(7.0, abs=0.5)

    def test_tokens_capped_at_max(self):
        """
        Test that tokens don't exceed max_tokens after refill.

        Token count should never go above the configured maximum.
        """
        past_time = pendulum.now("UTC").subtract(seconds=100)
        bucket = RateLimitBucket(
            tokens=5,
            max_tokens=10,
            refill_rate=1.0,  # Would add 100 tokens
            last_refill=past_time,
        )

        bucket.consume(0)  # Trigger refill

        assert bucket.tokens == 10  # Capped at max

    def test_get_wait_time_when_available(self):
        """
        Test wait time calculation when tokens available.

        Should return 0.0 when at least one token is available.
        """
        bucket = RateLimitBucket(tokens=5, max_tokens=10, refill_rate=1.0)

        wait_time = bucket.get_wait_time()

        assert wait_time == 0.0

    def test_get_wait_time_when_empty(self):
        """
        Test wait time calculation when no tokens available.

        Should return the time needed for one token to refill.
        """
        bucket = RateLimitBucket(
            tokens=0, max_tokens=10, refill_rate=2.0  # 2 tokens per second
        )

        wait_time = bucket.get_wait_time()

        # 1 token / 2 tokens per second. Compared approximately because any
        # refill between construction and this call shortens the wait
        # slightly.
        assert wait_time == pytest.approx(0.5, abs=1e-3)

    def test_get_wait_time_partial_tokens(self):
        """
        Test wait time with partial tokens available.

        Should calculate time to reach 1.0 token.
        """
        bucket = RateLimitBucket(tokens=0.5, max_tokens=10, refill_rate=1.0)

        wait_time = bucket.get_wait_time()

        # Need 0.5 more tokens at 1/sec; approximate for the same reason as above.
        assert wait_time == pytest.approx(0.5, abs=1e-3)
|
|
|
|
|
|
class TestRateLimiter:
    """Behavioral tests for RateLimiter bucket bookkeeping and limit checks."""

    @pytest.fixture
    def limiter(self):
        """Provide a brand-new RateLimiter so no bucket state leaks between tests."""
        fresh = RateLimiter()
        return fresh

    def test_get_connection_bucket_creates_new(self, limiter):
        """
        An unseen sid gets a newly created bucket.

        The bucket must exist and have a positive capacity.
        """
        created = limiter.get_connection_bucket("test_sid_1")

        assert created is not None
        assert created.max_tokens > 0

    def test_get_connection_bucket_returns_existing(self, limiter):
        """
        Repeated lookups for one sid return the very same object.
        """
        first, second = [
            limiter.get_connection_bucket("test_sid_1") for _ in range(2)
        ]

        assert first is second

    def test_get_connection_bucket_different_sids(self, limiter):
        """
        Distinct sids are throttled independently.
        """
        one, other = [
            limiter.get_connection_bucket(conn) for conn in ("sid_1", "sid_2")
        ]

        assert one is not other

    def test_get_game_bucket_creates_new(self, limiter):
        """
        A new (game_id, action) pair gets a newly created bucket.
        """
        game_bucket = limiter.get_game_bucket("game_123", "decision")

        assert game_bucket is not None
        assert game_bucket.max_tokens > 0

    def test_get_game_bucket_different_actions(self, limiter):
        """
        Within one game, each action type owns a separate bucket.
        """
        by_action = {
            action: limiter.get_game_bucket("game_123", action)
            for action in ("decision", "roll")
        }

        assert by_action["decision"] is not by_action["roll"]

    def test_get_game_bucket_different_games(self, limiter):
        """
        The same action in two different games uses two buckets.
        """
        per_game = [
            limiter.get_game_bucket(game, "decision")
            for game in ("game_1", "game_2")
        ]

        assert per_game[0] is not per_game[1]

    def test_get_user_bucket_creates_new(self, limiter):
        """
        A new user id gets a newly created API bucket.
        """
        user_bucket = limiter.get_user_bucket(123)

        assert user_bucket is not None
        assert user_bucket.max_tokens > 0

    @pytest.mark.asyncio
    async def test_check_websocket_limit_allowed(self, limiter):
        """
        A fresh connection is let through.
        """
        allowed = await limiter.check_websocket_limit("test_sid")

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_websocket_limit_denied_when_exhausted(self, limiter):
        """
        A connection with a drained bucket is refused.
        """
        target = "rate_limited_sid"
        limiter.get_connection_bucket(target).tokens = 0  # drain it

        allowed = await limiter.check_websocket_limit(target)

        assert allowed is False

    @pytest.mark.asyncio
    async def test_check_game_limit_allowed(self, limiter):
        """
        A game action is let through while its bucket still has tokens.
        """
        allowed = await limiter.check_game_limit("game_123", "decision")

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_game_limit_denied_when_exhausted(self, limiter):
        """
        A game action whose bucket is drained is refused.
        """
        game = "rate_limited_game"
        limiter.get_game_bucket(game, "roll").tokens = 0  # drain it

        allowed = await limiter.check_game_limit(game, "roll")

        assert allowed is False

    @pytest.mark.asyncio
    async def test_check_api_limit_allowed(self, limiter):
        """
        A user with remaining API budget is let through.
        """
        allowed = await limiter.check_api_limit(123)

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_api_limit_denied_when_exhausted(self, limiter):
        """
        A user whose API bucket is drained is refused.
        """
        uid = 999
        limiter.get_user_bucket(uid).tokens = 0  # drain it

        allowed = await limiter.check_api_limit(uid)

        assert allowed is False

    def test_remove_connection_cleans_up(self, limiter):
        """
        remove_connection drops the bucket tracked for that sid.
        """
        key = "cleanup_test_sid"
        limiter.get_connection_bucket(key)
        assert key in limiter._connection_buckets

        limiter.remove_connection(key)

        assert key not in limiter._connection_buckets

    def test_remove_connection_handles_missing(self, limiter):
        """
        remove_connection on an unknown sid does not raise.
        """
        limiter.remove_connection("nonexistent_sid")  # must not raise

    def test_remove_game_cleans_up_all_actions(self, limiter):
        """
        remove_game drops the decision, roll and substitution buckets of a game.
        """
        game = "cleanup_game"
        tracked_keys = []
        for action in ("decision", "roll", "substitution"):
            limiter.get_game_bucket(game, action)
            tracked_keys.append(f"{game}:{action}")

        for key in tracked_keys:
            assert key in limiter._game_buckets

        limiter.remove_game(game)

        for key in tracked_keys:
            assert key not in limiter._game_buckets

    def test_remove_game_preserves_other_games(self, limiter):
        """
        remove_game leaves other games' buckets untouched.
        """
        for game in ("game_1", "game_2"):
            limiter.get_game_bucket(game, "decision")

        limiter.remove_game("game_1")

        assert "game_1:decision" not in limiter._game_buckets
        assert "game_2:decision" in limiter._game_buckets

    def test_get_stats_returns_counts(self, limiter):
        """
        get_stats reports how many buckets of each kind are tracked.
        """
        for conn in ("sid_1", "sid_2"):
            limiter.get_connection_bucket(conn)
        limiter.get_game_bucket("game_1", "decision")
        limiter.get_user_bucket(123)

        stats = limiter.get_stats()

        expected_counts = {
            "connection_buckets": 2,
            "game_buckets": 1,
            "user_buckets": 1,
        }
        for stat_name, count in expected_counts.items():
            assert stats[stat_name] == count
|
|
|
|
|
|
class TestCleanupTask:
    """Tests for the stale bucket cleanup functionality.

    NOTE(review): these tests apply the staleness rule themselves (via
    ``_prune_stale``) rather than invoking RateLimiter's own cleanup routine.
    They therefore document the expected policy but do not exercise the
    production cleanup code. If rate_limit.py exposes a synchronous cleanup
    entry point, call it here instead — TODO confirm its name and threshold.
    """

    # Idle age in seconds after which a bucket counts as stale (">10 minutes"
    # in the test docstrings below).
    STALE_AFTER_SECONDS = 600

    @staticmethod
    def _prune_stale(buckets, threshold_seconds):
        """Delete entries of ``buckets`` idle for more than ``threshold_seconds``.

        An entry is idle for ``now - bucket.last_refill`` seconds, where
        ``now`` is taken once from ``pendulum.now("UTC")``. Stale keys are
        collected first so the dict is not mutated while being iterated.
        """
        now = pendulum.now("UTC")
        stale_keys = [
            key
            for key, bucket in buckets.items()
            if (now - bucket.last_refill).total_seconds() > threshold_seconds
        ]
        for key in stale_keys:
            del buckets[key]

    def test_cleanup_removes_stale_connection_buckets(self):
        """
        Test that cleanup removes connection buckets older than threshold.

        Buckets unused for >10 minutes should be cleaned up; fresh ones kept.
        """
        limiter = RateLimiter()

        # Backdate one bucket well past the threshold (15 min > 10 min).
        stale_bucket = limiter.get_connection_bucket("stale_sid")
        stale_bucket.last_refill = pendulum.now("UTC").subtract(minutes=15)

        limiter.get_connection_bucket("fresh_sid")

        self._prune_stale(limiter._connection_buckets, self.STALE_AFTER_SECONDS)

        assert "stale_sid" not in limiter._connection_buckets
        assert "fresh_sid" in limiter._connection_buckets

    def test_cleanup_removes_stale_game_buckets(self):
        """
        Test that cleanup removes game buckets older than threshold.

        Abandoned game buckets should be cleaned up; active ones kept.
        """
        limiter = RateLimiter()

        stale_bucket = limiter.get_game_bucket("stale_game", "decision")
        stale_bucket.last_refill = pendulum.now("UTC").subtract(minutes=15)

        limiter.get_game_bucket("fresh_game", "decision")

        self._prune_stale(limiter._game_buckets, self.STALE_AFTER_SECONDS)

        assert "stale_game:decision" not in limiter._game_buckets
        assert "fresh_game:decision" in limiter._game_buckets
|
|
|
|
|
|
class TestGlobalRateLimiter:
    """Tests for the module-level ``rate_limiter`` instance."""

    def test_global_instance_exists(self):
        """
        Test that the global rate_limiter instance is available.

        It should be importable and be a RateLimiter.
        """
        assert rate_limiter is not None
        assert isinstance(rate_limiter, RateLimiter)

    @pytest.mark.asyncio
    async def test_global_instance_functional(self):
        """
        Test that the global rate_limiter admits a request from a new sid.

        The instance is shared across the whole test session, so its bucket
        is removed in ``finally``. This keeps a failing assertion from leaving
        state behind for later tests.
        """
        sid = "global_test_sid"
        try:
            result = await rate_limiter.check_websocket_limit(sid)

            assert result is True
        finally:
            rate_limiter.remove_connection(sid)
|
|
|
|
|
|
class TestRateLimitConfiguration:
    """Verify that every bucket kind takes its capacity from the limiter's settings."""

    def test_connection_bucket_uses_config(self):
        """
        Connection buckets are sized by ``rate_limit_websocket_per_minute``.
        """
        rl = RateLimiter()
        conn_bucket = rl.get_connection_bucket("config_test_sid")

        configured = rl._settings.rate_limit_websocket_per_minute
        assert conn_bucket.max_tokens == configured

        rl.remove_connection("config_test_sid")

    def test_decision_bucket_uses_config(self):
        """
        Decision buckets are sized by ``rate_limit_decision_per_game``.
        """
        rl = RateLimiter()
        decision = rl.get_game_bucket("config_game", "decision")

        configured = rl._settings.rate_limit_decision_per_game
        assert decision.max_tokens == configured

        rl.remove_game("config_game")

    def test_roll_bucket_uses_config(self):
        """
        Roll buckets are sized by ``rate_limit_roll_per_game``.
        """
        rl = RateLimiter()
        roll = rl.get_game_bucket("config_game", "roll")

        configured = rl._settings.rate_limit_roll_per_game
        assert roll.max_tokens == configured

        rl.remove_game("config_game")

    def test_substitution_bucket_uses_config(self):
        """
        Substitution buckets are sized by ``rate_limit_substitution_per_game``.
        """
        rl = RateLimiter()
        substitution = rl.get_game_bucket("config_game", "substitution")

        configured = rl._settings.rate_limit_substitution_per_game
        assert substitution.max_tokens == configured

        rl.remove_game("config_game")

    def test_api_bucket_uses_config(self):
        """
        Per-user API buckets are sized by ``rate_limit_api_per_minute``.
        """
        rl = RateLimiter()
        api_bucket = rl.get_user_bucket(12345)

        configured = rl._settings.rate_limit_api_per_minute
        assert api_bucket.max_tokens == configured

        rl._user_buckets.pop(12345)
|