CLAUDE: Add rate limiting, pool monitoring, and exception infrastructure

- Add rate_limit.py middleware with per-client throttling and cleanup task
- Add pool_monitor.py for database connection pool health monitoring
- Add custom exceptions module (GameEngineError, DatabaseError, etc.)
- Add config settings for eviction intervals, session timeouts, memory limits
- Add unit tests for rate limiting and pool monitoring

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2025-11-28 12:06:10 -06:00
parent ae4a92f0e0
commit 2a392b87f8
10 changed files with 1763 additions and 0 deletions

View File

@ -47,6 +47,19 @@ class Settings(BaseSettings):
max_concurrent_games: int = 20
game_idle_timeout: int = 86400  # 24 hours
# NOTE(review): game_idle_timeout (seconds) overlaps with
# game_idle_timeout_hours below — confirm which one eviction actually uses.

# Game eviction settings (memory management)
game_idle_timeout_hours: int = 24  # Evict games idle > 24 hours
game_eviction_interval_minutes: int = 60  # Check every hour
game_max_in_memory: int = 500  # Hard limit on in-memory games

# Rate limiting settings (token-bucket limits consumed by app.middleware.rate_limit)
rate_limit_websocket_per_minute: int = 120  # Events per minute per connection
rate_limit_api_per_minute: int = 100  # API calls per minute per user
rate_limit_decision_per_game: int = 20  # Decisions per minute per game
rate_limit_roll_per_game: int = 30  # Rolls per minute per game
rate_limit_substitution_per_game: int = 15  # Substitutions per minute per game
rate_limit_cleanup_interval: int = 300  # Cleanup stale buckets every 5 min

class Config:
    """Pydantic BaseSettings configuration: load from .env, case-insensitive keys."""
    env_file = ".env"
    case_sensitive = False

View File

@ -0,0 +1,214 @@
"""
Custom exceptions for the game engine.
Provides a hierarchy of specific exception types to replace broad `except Exception`
patterns, improving debugging and error handling precision.
Exception Hierarchy:
    GameEngineError (base)
        GameNotFoundError - Game doesn't exist in state manager
        InvalidGameStateError - Game in wrong state for operation
        SubstitutionError - Invalid player substitution
        AuthorizationError - User lacks permission
        DecisionTimeoutError - Decision not submitted in time
        DatabaseError - Database operation failed
        LineupError - Lineup validation or lookup failed
        ExternalAPIError - External service call failed
            PlayerDataError - Failed to fetch player data
Author: Claude
Date: 2025-11-27
"""
from uuid import UUID
class GameEngineError(Exception):
    """Root of the game-engine exception hierarchy.

    Every game-specific exception inherits from this class, so callers can
    catch game errors as a group without swallowing unrelated system errors.
    """
class GameNotFoundError(GameEngineError):
    """The requested game is not present in the state manager.

    Attributes:
        game_id: Identifier (UUID or string) of the game that was not found.
    """

    def __init__(self, game_id: UUID | str):
        super().__init__(f"Game not found: {game_id}")
        self.game_id = game_id
class InvalidGameStateError(GameEngineError):
    """The game's current state does not permit the requested operation.

    Typical triggers: submitting a decision to a completed game, rolling
    dice before both decisions are in, or starting a game twice.

    Attributes:
        current_state: State the game was actually in (optional, for debugging).
        expected_state: State the operation required (optional, for debugging).
    """

    def __init__(
        self,
        message: str,
        current_state: str | None = None,
        expected_state: str | None = None,
    ):
        super().__init__(message)
        self.current_state = current_state
        self.expected_state = expected_state
class SubstitutionError(GameEngineError):
    """A player substitution violates the substitution rules.

    Covers rule violations only, not lookup ("not found") failures.

    Attributes:
        error_code: Machine-readable code the frontend can branch on.
    """

    def __init__(self, message: str, error_code: str = "SUBSTITUTION_ERROR"):
        super().__init__(message)
        self.error_code = error_code
class AuthorizationError(GameEngineError):
    """A user lacks permission for the attempted operation.

    Note: authorization is currently deferred (task 001); the type exists
    for future use.

    Attributes:
        user_id: Optional user ID, kept for logging.
        resource: Optional identifier of the resource being accessed.
    """

    def __init__(
        self,
        message: str,
        user_id: int | None = None,
        resource: str | None = None,
    ):
        super().__init__(message)
        self.user_id = user_id
        self.resource = resource
class DecisionTimeoutError(GameEngineError):
    """A decision was not submitted within the timeout window.

    The engine substitutes a default decision when this happens, so this
    exception is informational rather than fatal.

    Attributes:
        game_id: Game where the timeout occurred.
        decision_type: Which decision timed out ('defensive' or 'offensive').
        timeout_seconds: How long the engine waited.
    """

    def __init__(
        self,
        game_id: UUID | str,
        decision_type: str,
        timeout_seconds: int,
    ):
        self.game_id = game_id
        self.decision_type = decision_type
        self.timeout_seconds = timeout_seconds
        detail = (
            f"Decision timeout for game {game_id}: "
            f"{decision_type} decision not received within {timeout_seconds}s"
        )
        super().__init__(detail)
class ExternalAPIError(GameEngineError):
    """A call to an external service failed.

    Base class for all external-service errors.

    Attributes:
        service: Name of the external service.
        status_code: Optional HTTP status code from the failed call.
    """

    def __init__(
        self,
        service: str,
        message: str,
        status_code: int | None = None,
    ):
        self.service = service
        self.status_code = status_code
        super().__init__(f"{service} API error: {message}")
class PlayerDataError(ExternalAPIError):
    """Player data could not be fetched from an external API.

    Specialization of ExternalAPIError for player lookups (SBA API, PD API).

    Attributes:
        player_id: ID of the player that could not be fetched.
        service: Which API was called (stored by ExternalAPIError).
    """

    def __init__(self, player_id: int, service: str = "SBA API"):
        self.player_id = player_id
        super().__init__(
            service=service,
            message=f"Failed to fetch player data for ID {player_id}",
        )
class DatabaseError(GameEngineError):
    """A database operation failed.

    Wraps the underlying SQLAlchemy exception together with game context.

    Attributes:
        operation: What was attempted (e.g. 'save_play', 'create_substitution').
        original_error: The underlying database exception, if any.
    """

    def __init__(self, operation: str, original_error: Exception | None = None):
        self.operation = operation
        self.original_error = original_error
        suffix = f": {original_error}" if original_error else ""
        super().__init__(f"Database error during {operation}" + suffix)
class LineupError(GameEngineError):
    """Lineup validation or lookup failed for a team.

    Attributes:
        team_id: Team whose lineup caused the error.
    """

    def __init__(self, team_id: int, message: str):
        super().__init__(f"Lineup error for team {team_id}: {message}")
        self.team_id = team_id

View File

View File

@ -0,0 +1,328 @@
"""
Rate limiting utilities for WebSocket and API endpoints.
Implements token bucket algorithm for smooth rate limiting that allows
bursts while enforcing long-term rate limits.
Key features:
- Per-connection rate limiting for WebSocket events
- Per-game rate limiting for game actions (decisions, rolls, substitutions)
- Per-user rate limiting for REST API endpoints
- Automatic cleanup of stale buckets
- Non-blocking async design
Author: Claude
Date: 2025-11-27
"""
import asyncio
import logging
from dataclasses import dataclass, field
from functools import wraps
from typing import TYPE_CHECKING, Callable
import pendulum
from app.config import get_settings
if TYPE_CHECKING:
from socketio import AsyncServer
logger = logging.getLogger(f"{__name__}.RateLimiter")
@dataclass
class RateLimitBucket:
    """Token bucket used for rate limiting.

    Semantics:
      * the bucket starts full (``max_tokens``),
      * every request drains tokens,
      * tokens trickle back at ``refill_rate`` per second,
      * a request is refused once fewer tokens remain than it asks for.

    Short bursts are therefore tolerated while the long-term average rate
    stays bounded.
    """

    tokens: float
    max_tokens: int
    refill_rate: float  # tokens per second
    last_refill: pendulum.DateTime = field(
        default_factory=lambda: pendulum.now("UTC")
    )

    def consume(self, tokens: int = 1) -> bool:
        """Attempt to take ``tokens`` from the bucket.

        Returns True when the request is within the limit,
        False when it should be rejected.
        """
        self._refill()
        if self.tokens < tokens:
            return False
        self.tokens -= tokens
        return True

    def _refill(self) -> None:
        """Credit tokens for the time elapsed since the previous refill."""
        now = pendulum.now("UTC")
        earned = (now - self.last_refill).total_seconds() * self.refill_rate
        if earned > 0:
            self.tokens = min(self.max_tokens, self.tokens + earned)
            self.last_refill = now

    def get_wait_time(self) -> float:
        """Seconds until one token becomes available (0.0 if one already is).

        Useful for telling clients how long to back off.
        """
        if self.tokens >= 1:
            return 0.0
        return (1 - self.tokens) / self.refill_rate
class RateLimiter:
    """Token-bucket rate limiter for WebSocket connections and API endpoints.

    Maintains three independent bucket maps: per connection (keyed by sid),
    per game action (keyed by "game_id:action"), and per API user (keyed by
    user id). Bursts are allowed up to bucket capacity while the average
    rate stays bounded.

    Safe for asyncio use: every bucket-map mutation is an atomic dict
    operation on a single event loop.
    """

    def __init__(self):
        # sid -> bucket, one per WebSocket connection
        self._connection_buckets: dict[str, RateLimitBucket] = {}
        # "game_id:action" -> bucket, one per game action type
        self._game_buckets: dict[str, RateLimitBucket] = {}
        # user_id -> bucket, one per REST API user
        self._user_buckets: dict[int, RateLimitBucket] = {}
        # Handle for the background cleanup task (started elsewhere)
        self._cleanup_task: asyncio.Task | None = None
        # Settings are read once and cached
        self._settings = get_settings()

    @staticmethod
    def _new_bucket(per_minute: int) -> RateLimitBucket:
        """Build a full bucket refilling at ``per_minute`` tokens per minute."""
        return RateLimitBucket(
            tokens=per_minute,
            max_tokens=per_minute,
            refill_rate=per_minute / 60,  # per minute -> per second
        )

    def get_connection_bucket(self, sid: str) -> RateLimitBucket:
        """Return the bucket for a WebSocket connection, creating it on first use."""
        bucket = self._connection_buckets.get(sid)
        if bucket is None:
            bucket = self._new_bucket(self._settings.rate_limit_websocket_per_minute)
            self._connection_buckets[sid] = bucket
        return bucket

    def get_game_bucket(self, game_id: str, action: str) -> RateLimitBucket:
        """Return the bucket for one (game, action) pair, creating it on first use.

        Known actions: 'decision', 'roll', 'substitution'. Unknown actions
        fall back to a default of 30 per minute.
        """
        key = f"{game_id}:{action}"
        bucket = self._game_buckets.get(key)
        if bucket is None:
            per_action_limits = {
                "decision": self._settings.rate_limit_decision_per_game,
                "roll": self._settings.rate_limit_roll_per_game,
                "substitution": self._settings.rate_limit_substitution_per_game,
            }
            bucket = self._new_bucket(per_action_limits.get(action, 30))
            self._game_buckets[key] = bucket
        return bucket

    def get_user_bucket(self, user_id: int) -> RateLimitBucket:
        """Return the bucket for an API user, creating it on first use."""
        bucket = self._user_buckets.get(user_id)
        if bucket is None:
            bucket = self._new_bucket(self._settings.rate_limit_api_per_minute)
            self._user_buckets[user_id] = bucket
        return bucket

    async def check_websocket_limit(self, sid: str) -> bool:
        """Consume one token for a WebSocket event.

        Returns True if allowed, False (with a warning log) if rate limited.
        """
        allowed = self.get_connection_bucket(sid).consume()
        if not allowed:
            logger.warning(f"Rate limited WebSocket connection: {sid}")
        return allowed

    async def check_game_limit(self, game_id: str, action: str) -> bool:
        """Consume one token for a game action.

        Returns True if allowed, False (with a warning log) if rate limited.
        """
        allowed = self.get_game_bucket(game_id, action).consume()
        if not allowed:
            logger.warning(f"Rate limited game action: {game_id} {action}")
        return allowed

    async def check_api_limit(self, user_id: int) -> bool:
        """Consume one token for a REST API call.

        Returns True if allowed, False (with a warning log) if rate limited.
        """
        allowed = self.get_user_bucket(user_id).consume()
        if not allowed:
            logger.warning(f"Rate limited API user: {user_id}")
        return allowed

    def remove_connection(self, sid: str) -> None:
        """Drop the bucket for a closed connection (no-op if absent)."""
        self._connection_buckets.pop(sid, None)

    def remove_game(self, game_id: str) -> None:
        """Drop every action bucket belonging to a finished game."""
        prefix = f"{game_id}:"
        for key in [k for k in self._game_buckets if k.startswith(prefix)]:
            del self._game_buckets[key]

    @staticmethod
    def _prune_stale(buckets: dict, now, threshold_seconds: float) -> int:
        """Delete buckets idle longer than ``threshold_seconds``; return the count."""
        stale_keys = [
            key
            for key, bucket in buckets.items()
            if (now - bucket.last_refill).total_seconds() > threshold_seconds
        ]
        for key in stale_keys:
            del buckets[key]
        return len(stale_keys)

    async def cleanup_stale_buckets(self) -> None:
        """Background task: periodically drop buckets idle for over 10 minutes.

        Prevents memory leaks from abandoned connections, games, and users.
        Runs until cancelled; unexpected errors are logged and the loop
        continues.
        """
        interval = self._settings.rate_limit_cleanup_interval
        stale_threshold_seconds = 600  # 10 minutes
        logger.info(f"Starting rate limiter cleanup task (interval: {interval}s)")
        while True:
            try:
                await asyncio.sleep(interval)
                now = pendulum.now("UTC")
                n_connections = self._prune_stale(
                    self._connection_buckets, now, stale_threshold_seconds
                )
                n_games = self._prune_stale(
                    self._game_buckets, now, stale_threshold_seconds
                )
                n_users = self._prune_stale(
                    self._user_buckets, now, stale_threshold_seconds
                )
                if n_connections or n_games or n_users:
                    logger.debug(
                        f"Rate limiter cleanup: {n_connections} connections, "
                        f"{n_games} games, {n_users} users"
                    )
            except asyncio.CancelledError:
                logger.info("Rate limiter cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Rate limiter cleanup error: {e}", exc_info=True)
                # Keep the task alive; one failure must not stop cleanup.

    def get_stats(self) -> dict:
        """Return bucket counts per category for monitoring."""
        return {
            "connection_buckets": len(self._connection_buckets),
            "game_buckets": len(self._game_buckets),
            "user_buckets": len(self._user_buckets),
        }
# Global rate limiter instance
# Module-level singleton shared by the rate_limited decorator below and
# imported by WebSocket/API handlers elsewhere in the app.
rate_limiter = RateLimiter()
def rate_limited(action: str = "general"):
    """
    Decorator for rate-limited WebSocket handlers.

    Checks the connection-level limit first and, for non-'general' actions
    whose payload contains a ``game_id``, the game-level limit as well.
    When a limit is hit, the wrapped handler is skipped and an error
    payload is returned to the caller instead.

    Args:
        action: Type of action for game-level limits.
                Options: 'decision', 'roll', 'substitution', 'general'
                'general' only applies connection-level limit.

    Usage:
        @sio.event
        @rate_limited(action="decision")
        async def submit_defensive_decision(sid, data):
            ...

    Note: The decorator must be applied AFTER @sio.event to work correctly.
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(sid: str, data=None, *args, **kwargs):
            # Connection-level limit applies to every decorated handler.
            # (Previously this imported ConnectionManager here but never
            # used it; the dead import has been removed.)
            if not await rate_limiter.check_websocket_limit(sid):
                logger.warning(f"Rate limited handler call from {sid}")
                return {"error": "rate_limited", "message": "Rate limited. Please slow down."}
            # Game-level limit applies only when the payload names a game.
            if action != "general" and isinstance(data, dict):
                game_id = data.get("game_id")
                if game_id:
                    if not await rate_limiter.check_game_limit(str(game_id), action):
                        logger.warning(
                            f"Rate limited game action {action} for game {game_id}"
                        )
                        return {
                            "error": "game_rate_limited",
                            "message": f"Too many {action} requests for this game.",
                        }
            return await func(sid, data, *args, **kwargs)
        return wrapper
    return decorator

View File

@ -0,0 +1,15 @@
"""Monitoring utilities for the Paper Dynasty game engine."""
from app.monitoring.pool_monitor import (
PoolMonitor,
PoolStats,
init_pool_monitor,
pool_monitor,
)
__all__ = [
"PoolMonitor",
"PoolStats",
"init_pool_monitor",
"pool_monitor",
]

View File

@ -0,0 +1,220 @@
"""
Database connection pool monitoring.
Monitors SQLAlchemy async connection pool health and provides
statistics for observability and alerting.
Key features:
- Real-time pool statistics (checked in/out, overflow)
- Health status classification (healthy/warning/critical)
- Historical stats tracking
- Background monitoring with configurable interval
- Warning logs when pool usage exceeds threshold
Author: Claude
Date: 2025-11-27
"""
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Optional
import pendulum
from sqlalchemy.ext.asyncio import AsyncEngine
from app.config import get_settings
logger = logging.getLogger(f"{__name__}.PoolMonitor")
@dataclass
class PoolStats:
    """
    Connection pool statistics snapshot.

    Captures the current state of the database connection pool
    for monitoring and alerting purposes. Instances are created by
    PoolMonitor.get_stats() and appended to its bounded history.
    """

    pool_size: int  # Configured base pool size (from settings)
    max_overflow: int  # Configured overflow allowance (from settings)
    checkedin: int  # Available connections
    checkedout: int  # In-use connections
    overflow: int  # Overflow connections in use
    total_capacity: int  # pool_size + max_overflow
    usage_percent: float  # checkedout / total_capacity (fraction, not percent)
    timestamp: pendulum.DateTime = field(default_factory=lambda: pendulum.now("UTC"))
class PoolMonitor:
    """
    Monitor database connection pool health.

    Provides real-time statistics and health status for the SQLAlchemy
    connection pool. Useful for detecting pool exhaustion before it
    causes request failures.

    Usage:
        monitor = PoolMonitor(engine)
        stats = monitor.get_stats()
        health = monitor.get_health_status()
    """

    def __init__(
        self,
        engine: AsyncEngine,
        alert_threshold: float = 0.8,
        max_history: int = 100,
    ):
        """
        Initialize pool monitor.

        Args:
            engine: SQLAlchemy async engine to monitor
            alert_threshold: Usage percentage to trigger warning (0.8 = 80%)
            max_history: Maximum stats snapshots to keep in history
        """
        self._engine = engine
        self._stats_history: list[PoolStats] = []
        self._max_history = max_history
        self._alert_threshold = alert_threshold
        self._settings = get_settings()

    def get_stats(self) -> PoolStats:
        """
        Get current pool statistics.

        Side effects: appends the snapshot to the bounded history, logs a
        warning when usage >= the configured alert threshold, and logs an
        info line whenever overflow connections are in use.

        Returns:
            PoolStats with current pool state
        """
        pool = self._engine.pool
        checkedin = pool.checkedin()
        checkedout = pool.checkedout()
        overflow = pool.overflow()
        # Capacity is derived from settings rather than the live pool
        # object, so it reflects configuration, not runtime pool state.
        total_capacity = self._settings.db_pool_size + self._settings.db_max_overflow
        usage_percent = checkedout / total_capacity if total_capacity > 0 else 0
        stats = PoolStats(
            pool_size=self._settings.db_pool_size,
            max_overflow=self._settings.db_max_overflow,
            checkedin=checkedin,
            checkedout=checkedout,
            overflow=overflow,
            total_capacity=total_capacity,
            usage_percent=usage_percent,
        )
        # Record history (bounded: oldest snapshot dropped past max_history)
        self._stats_history.append(stats)
        if len(self._stats_history) > self._max_history:
            self._stats_history.pop(0)
        # Check for alerts
        if usage_percent >= self._alert_threshold:
            logger.warning(
                f"Connection pool usage high: {usage_percent:.1%} "
                f"({checkedout}/{total_capacity})"
            )
        if overflow > 0:
            logger.info(f"Pool overflow active: {overflow} overflow connections")
        return stats

    def get_health_status(self) -> dict:
        """
        Get pool health status for monitoring endpoint.

        Calls get_stats(), so it shares that method's history/logging side
        effects.

        NOTE(review): the status thresholds (0.9 critical / 0.75 warning)
        are hard-coded and ignore the constructor's alert_threshold (0.8
        by default) — confirm the divergence is intentional.

        Returns:
            Dict with status, statistics, and timestamp
        """
        stats = self.get_stats()
        if stats.usage_percent >= 0.9:
            status = "critical"
        elif stats.usage_percent >= 0.75:
            status = "warning"
        else:
            status = "healthy"
        return {
            "status": status,
            "pool_size": stats.pool_size,
            "max_overflow": stats.max_overflow,
            "available": stats.checkedin,
            "in_use": stats.checkedout,
            "overflow_active": stats.overflow,
            "total_capacity": stats.total_capacity,
            "usage_percent": round(stats.usage_percent * 100, 1),
            "timestamp": stats.timestamp.isoformat(),
        }

    def get_history(self, limit: int = 10) -> list[dict]:
        """
        Get recent stats history.

        Args:
            limit: Maximum number of history entries to return

        Returns:
            List of stats snapshots, oldest first, most recent last
        """
        return [
            {
                "checkedout": s.checkedout,
                "usage_percent": round(s.usage_percent * 100, 1),
                "timestamp": s.timestamp.isoformat(),
            }
            for s in self._stats_history[-limit:]
        ]

    async def start_monitoring(self, interval_seconds: int = 60):
        """
        Background task to periodically collect stats.

        Useful for continuous logging and alerting. Runs until cancelled;
        errors other than cancellation are logged and the loop continues.

        Args:
            interval_seconds: Seconds between stat collections
        """
        logger.info(f"Starting pool monitoring (interval: {interval_seconds}s)")
        while True:
            try:
                stats = self.get_stats()
                logger.debug(
                    f"Pool stats: {stats.checkedout}/{stats.total_capacity} "
                    f"({stats.usage_percent:.1%})"
                )
                await asyncio.sleep(interval_seconds)
            except asyncio.CancelledError:
                logger.info("Pool monitoring stopped")
                break
            except Exception as e:
                logger.error(f"Pool monitoring error: {e}")
                await asyncio.sleep(interval_seconds)
# Global instance (initialized in main.py)
# Remains None until init_pool_monitor() is called at startup.
pool_monitor: Optional[PoolMonitor] = None


def init_pool_monitor(engine: AsyncEngine) -> PoolMonitor:
    """
    Initialize global pool monitor.

    Should be called during application startup, after the engine exists.

    Args:
        engine: SQLAlchemy async engine to monitor

    Returns:
        Initialized PoolMonitor instance (also bound to the module-level
        ``pool_monitor`` for import elsewhere)
    """
    global pool_monitor
    pool_monitor = PoolMonitor(engine)
    logger.info("Pool monitor initialized")
    return pool_monitor

View File

@ -0,0 +1 @@
# Middleware unit tests

View File

@ -0,0 +1,583 @@
"""
Unit tests for rate limiting middleware.
Tests the token bucket algorithm implementation and rate limiter functionality
for WebSocket and API rate limiting.
Author: Claude
Date: 2025-11-27
"""
import asyncio
from unittest.mock import patch
import pendulum
import pytest
from app.middleware.rate_limit import RateLimitBucket, RateLimiter, rate_limiter
class TestRateLimitBucket:
    """Tests for the RateLimitBucket token bucket implementation.

    Buckets under test are constructed directly; refill_rate=0.0 is used
    wherever time-based refill would make assertions flaky.
    """

    def test_bucket_initialization(self):
        """
        Test that bucket initializes with correct values.

        Bucket should start with max_tokens available and record
        the creation time as last_refill.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=1.0
        )
        assert bucket.tokens == 10
        assert bucket.max_tokens == 10
        assert bucket.refill_rate == 1.0
        assert bucket.last_refill is not None

    def test_consume_success(self):
        """
        Test successful token consumption.

        Should reduce available tokens and return True when
        enough tokens are available.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=1.0
        )
        result = bucket.consume(1)
        assert result is True
        assert bucket.tokens < 10  # Reduced (may have refilled slightly)

    def test_consume_multiple_tokens(self):
        """
        Test consuming multiple tokens at once.

        Should consume the requested amount if available.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=0.0  # No refill for test
        )
        result = bucket.consume(5)
        assert result is True
        assert bucket.tokens == 5

    def test_consume_denied_when_empty(self):
        """
        Test that consumption is denied when not enough tokens.

        Should return False and not reduce tokens below zero.
        """
        bucket = RateLimitBucket(
            tokens=0, max_tokens=10, refill_rate=0.0  # No refill
        )
        result = bucket.consume(1)
        assert result is False
        assert bucket.tokens == 0  # Unchanged

    def test_consume_denied_insufficient_tokens(self):
        """
        Test denial when requesting more tokens than available.

        Should return False without modifying token count.
        """
        bucket = RateLimitBucket(
            tokens=3, max_tokens=10, refill_rate=0.0
        )
        result = bucket.consume(5)
        assert result is False
        assert bucket.tokens == 3  # Unchanged

    def test_token_refill_over_time(self):
        """
        Test that tokens refill based on elapsed time.

        Tokens should accumulate at refill_rate per second.
        """
        # Start with 5 tokens and a last_refill 5 seconds in the past
        past_time = pendulum.now("UTC").subtract(seconds=5)
        bucket = RateLimitBucket(
            tokens=5,
            max_tokens=10,
            refill_rate=1.0,  # 1 token per second
            last_refill=past_time,
        )
        # Force refill by consuming (which triggers refill check);
        # consume(0) always succeeds and leaves the token count intact.
        bucket.consume(0)
        # Should have refilled ~5 tokens (5 seconds * 1 token/sec)
        assert bucket.tokens >= 9.5  # Allow small timing variance

    def test_tokens_capped_at_max(self):
        """
        Test that tokens don't exceed max_tokens after refill.

        Token count should never go above the configured maximum.
        """
        past_time = pendulum.now("UTC").subtract(seconds=100)
        bucket = RateLimitBucket(
            tokens=5,
            max_tokens=10,
            refill_rate=1.0,  # Would add 100 tokens
            last_refill=past_time,
        )
        bucket.consume(0)  # Trigger refill
        assert bucket.tokens == 10  # Capped at max

    def test_get_wait_time_when_available(self):
        """
        Test wait time calculation when tokens available.

        Should return 0.0 when at least one token is available.
        """
        bucket = RateLimitBucket(
            tokens=5, max_tokens=10, refill_rate=1.0
        )
        wait_time = bucket.get_wait_time()
        assert wait_time == 0.0

    def test_get_wait_time_when_empty(self):
        """
        Test wait time calculation when no tokens available.

        Should return time needed for one token to refill.
        """
        bucket = RateLimitBucket(
            tokens=0, max_tokens=10, refill_rate=2.0  # 2 tokens per second
        )
        wait_time = bucket.get_wait_time()
        assert wait_time == 0.5  # 1 token / 2 tokens per second

    def test_get_wait_time_partial_tokens(self):
        """
        Test wait time with partial tokens available.

        Should calculate time to reach 1.0 token.
        """
        bucket = RateLimitBucket(
            tokens=0.5, max_tokens=10, refill_rate=1.0
        )
        wait_time = bucket.get_wait_time()
        assert wait_time == 0.5  # Need 0.5 more tokens at 1/sec
class TestRateLimiter:
    """Tests for the RateLimiter class.

    Note: RateLimiter() reads limits via get_settings(), so these tests
    exercise whatever limits the test environment's settings provide.
    """

    @pytest.fixture
    def limiter(self):
        """Create a fresh RateLimiter instance for each test."""
        return RateLimiter()

    def test_get_connection_bucket_creates_new(self, limiter):
        """
        Test that get_connection_bucket creates bucket for new sid.

        Should create and return a bucket initialized with settings.
        """
        bucket = limiter.get_connection_bucket("test_sid_1")
        assert bucket is not None
        assert bucket.max_tokens > 0

    def test_get_connection_bucket_returns_existing(self, limiter):
        """
        Test that same bucket is returned for same sid.

        Should return the same bucket instance for subsequent calls.
        """
        bucket1 = limiter.get_connection_bucket("test_sid_1")
        bucket2 = limiter.get_connection_bucket("test_sid_1")
        assert bucket1 is bucket2

    def test_get_connection_bucket_different_sids(self, limiter):
        """
        Test that different sids get different buckets.

        Each connection should have its own rate limit bucket.
        """
        bucket1 = limiter.get_connection_bucket("sid_1")
        bucket2 = limiter.get_connection_bucket("sid_2")
        assert bucket1 is not bucket2

    def test_get_game_bucket_creates_new(self, limiter):
        """
        Test that get_game_bucket creates bucket for new game+action.

        Should create bucket keyed by game_id and action type.
        """
        bucket = limiter.get_game_bucket("game_123", "decision")
        assert bucket is not None
        assert bucket.max_tokens > 0

    def test_get_game_bucket_different_actions(self, limiter):
        """
        Test that different actions get different buckets for same game.

        Each action type has its own rate limit within a game.
        """
        decision_bucket = limiter.get_game_bucket("game_123", "decision")
        roll_bucket = limiter.get_game_bucket("game_123", "roll")
        assert decision_bucket is not roll_bucket

    def test_get_game_bucket_different_games(self, limiter):
        """
        Test that same action on different games gets different buckets.

        Each game tracks its own rate limits.
        """
        bucket1 = limiter.get_game_bucket("game_1", "decision")
        bucket2 = limiter.get_game_bucket("game_2", "decision")
        assert bucket1 is not bucket2

    def test_get_user_bucket_creates_new(self, limiter):
        """
        Test that get_user_bucket creates bucket for new user.

        Should create bucket for API rate limiting per user.
        """
        bucket = limiter.get_user_bucket(123)
        assert bucket is not None
        assert bucket.max_tokens > 0

    @pytest.mark.asyncio
    async def test_check_websocket_limit_allowed(self, limiter):
        """
        Test that check_websocket_limit allows requests with tokens.

        Should return True when tokens are available.
        """
        result = await limiter.check_websocket_limit("test_sid")
        assert result is True

    @pytest.mark.asyncio
    async def test_check_websocket_limit_denied_when_exhausted(self, limiter):
        """
        Test that check_websocket_limit denies when rate limited.

        Should return False after tokens are exhausted.
        """
        sid = "rate_limited_sid"
        # Exhaust all tokens by zeroing the bucket directly
        bucket = limiter.get_connection_bucket(sid)
        bucket.tokens = 0
        result = await limiter.check_websocket_limit(sid)
        assert result is False

    @pytest.mark.asyncio
    async def test_check_game_limit_allowed(self, limiter):
        """
        Test that check_game_limit allows requests with tokens.

        Should return True when game action limit not reached.
        """
        result = await limiter.check_game_limit("game_123", "decision")
        assert result is True

    @pytest.mark.asyncio
    async def test_check_game_limit_denied_when_exhausted(self, limiter):
        """
        Test that check_game_limit denies when rate limited.

        Should return False after game action tokens exhausted.
        """
        game_id = "rate_limited_game"
        bucket = limiter.get_game_bucket(game_id, "roll")
        bucket.tokens = 0
        result = await limiter.check_game_limit(game_id, "roll")
        assert result is False

    @pytest.mark.asyncio
    async def test_check_api_limit_allowed(self, limiter):
        """
        Test that check_api_limit allows requests with tokens.

        Should return True when user API limit not reached.
        """
        result = await limiter.check_api_limit(123)
        assert result is True

    @pytest.mark.asyncio
    async def test_check_api_limit_denied_when_exhausted(self, limiter):
        """
        Test that check_api_limit denies when rate limited.

        Should return False after user API tokens exhausted.
        """
        user_id = 999
        bucket = limiter.get_user_bucket(user_id)
        bucket.tokens = 0
        result = await limiter.check_api_limit(user_id)
        assert result is False

    def test_remove_connection_cleans_up(self, limiter):
        """
        Test that remove_connection removes bucket for sid.

        Should clean up bucket when connection closes.
        """
        sid = "cleanup_test_sid"
        limiter.get_connection_bucket(sid)  # Create bucket
        assert sid in limiter._connection_buckets
        limiter.remove_connection(sid)
        assert sid not in limiter._connection_buckets

    def test_remove_connection_handles_missing(self, limiter):
        """
        Test that remove_connection handles non-existent sid gracefully.

        Should not raise error for unknown sid.
        """
        limiter.remove_connection("nonexistent_sid")  # Should not raise

    def test_remove_game_cleans_up_all_actions(self, limiter):
        """
        Test that remove_game removes all buckets for a game.

        Should clean up decision, roll, substitution buckets for game.
        """
        game_id = "cleanup_game"
        # Create buckets for different actions
        limiter.get_game_bucket(game_id, "decision")
        limiter.get_game_bucket(game_id, "roll")
        limiter.get_game_bucket(game_id, "substitution")
        assert f"{game_id}:decision" in limiter._game_buckets
        assert f"{game_id}:roll" in limiter._game_buckets
        assert f"{game_id}:substitution" in limiter._game_buckets
        limiter.remove_game(game_id)
        assert f"{game_id}:decision" not in limiter._game_buckets
        assert f"{game_id}:roll" not in limiter._game_buckets
        assert f"{game_id}:substitution" not in limiter._game_buckets

    def test_remove_game_preserves_other_games(self, limiter):
        """
        Test that remove_game only removes specified game's buckets.

        Other games' buckets should remain intact.
        """
        limiter.get_game_bucket("game_1", "decision")
        limiter.get_game_bucket("game_2", "decision")
        limiter.remove_game("game_1")
        assert "game_1:decision" not in limiter._game_buckets
        assert "game_2:decision" in limiter._game_buckets

    def test_get_stats_returns_counts(self, limiter):
        """
        Test that get_stats returns bucket counts.

        Should return counts for all bucket types.
        """
        limiter.get_connection_bucket("sid_1")
        limiter.get_connection_bucket("sid_2")
        limiter.get_game_bucket("game_1", "decision")
        limiter.get_user_bucket(123)
        stats = limiter.get_stats()
        assert stats["connection_buckets"] == 2
        assert stats["game_buckets"] == 1
        assert stats["user_buckets"] == 1
class TestCleanupTask:
    """Tests for the stale bucket cleanup functionality.

    NOTE(review): these tests replicate the cleanup predicate inline
    instead of running cleanup_stale_buckets() itself (to avoid the
    sleep loop), so they will not catch drift between the two.
    """

    @pytest.mark.asyncio
    async def test_cleanup_removes_stale_connection_buckets(self):
        """
        Test that cleanup removes connection buckets older than threshold.

        Buckets unused for >10 minutes should be cleaned up.
        """
        limiter = RateLimiter()
        # Create a bucket with old last_refill time (>10 minutes ago)
        stale_time = pendulum.now("UTC").subtract(minutes=15)
        bucket = limiter.get_connection_bucket("stale_sid")
        bucket.last_refill = stale_time
        # Create a fresh bucket (variable unused; creation registers it)
        fresh_bucket = limiter.get_connection_bucket("fresh_sid")
        # Manually trigger cleanup logic (without running full async task)
        now = pendulum.now("UTC")
        stale_sids = [
            sid
            for sid, b in limiter._connection_buckets.items()
            if (now - b.last_refill).total_seconds() > 600
        ]
        for sid in stale_sids:
            del limiter._connection_buckets[sid]
        assert "stale_sid" not in limiter._connection_buckets
        assert "fresh_sid" in limiter._connection_buckets

    @pytest.mark.asyncio
    async def test_cleanup_removes_stale_game_buckets(self):
        """
        Test that cleanup removes game buckets older than threshold.

        Abandoned game buckets should be cleaned up.
        """
        limiter = RateLimiter()
        stale_time = pendulum.now("UTC").subtract(minutes=15)
        bucket = limiter.get_game_bucket("stale_game", "decision")
        bucket.last_refill = stale_time
        limiter.get_game_bucket("fresh_game", "decision")
        # Manually trigger cleanup
        now = pendulum.now("UTC")
        stale_keys = [
            key
            for key, b in limiter._game_buckets.items()
            if (now - b.last_refill).total_seconds() > 600
        ]
        for key in stale_keys:
            del limiter._game_buckets[key]
        assert "stale_game:decision" not in limiter._game_buckets
        assert "fresh_game:decision" in limiter._game_buckets
class TestGlobalRateLimiter:
    """Tests for the module-level rate_limiter singleton."""

    def test_global_instance_exists(self):
        """
        The shared rate_limiter instance should be importable.

        It must exist and be a RateLimiter so callers can rely on it as a
        process-wide singleton.
        """
        assert rate_limiter is not None
        assert isinstance(rate_limiter, RateLimiter)

    @pytest.mark.asyncio
    async def test_global_instance_functional(self):
        """
        The shared rate_limiter should accept requests end to end.

        A first check on a fresh connection must be allowed.
        """
        sid = "global_test_sid"
        allowed = await rate_limiter.check_websocket_limit(sid)
        assert allowed is True
        # Clean up so this test leaves no state behind in the singleton.
        rate_limiter.remove_connection(sid)
class TestRateLimitConfiguration:
    """Tests for rate limit configuration integration."""

    def test_connection_bucket_uses_config(self):
        """
        Connection buckets are sized from settings.

        max_tokens should equal rate_limit_websocket_per_minute.
        """
        limiter = RateLimiter()
        sid = "config_test_sid"
        bucket = limiter.get_connection_bucket(sid)
        # Default configuration allows 120 events per minute.
        assert bucket.max_tokens == limiter._settings.rate_limit_websocket_per_minute
        limiter.remove_connection(sid)

    def test_decision_bucket_uses_config(self):
        """
        Decision buckets are sized from settings.

        max_tokens should equal rate_limit_decision_per_game.
        """
        limiter = RateLimiter()
        game_id = "config_game"
        bucket = limiter.get_game_bucket(game_id, "decision")
        assert bucket.max_tokens == limiter._settings.rate_limit_decision_per_game
        limiter.remove_game(game_id)

    def test_roll_bucket_uses_config(self):
        """
        Roll buckets are sized from settings.

        max_tokens should equal rate_limit_roll_per_game.
        """
        limiter = RateLimiter()
        game_id = "config_game"
        bucket = limiter.get_game_bucket(game_id, "roll")
        assert bucket.max_tokens == limiter._settings.rate_limit_roll_per_game
        limiter.remove_game(game_id)

    def test_substitution_bucket_uses_config(self):
        """
        Substitution buckets are sized from settings.

        max_tokens should equal rate_limit_substitution_per_game.
        """
        limiter = RateLimiter()
        game_id = "config_game"
        bucket = limiter.get_game_bucket(game_id, "substitution")
        assert bucket.max_tokens == limiter._settings.rate_limit_substitution_per_game
        limiter.remove_game(game_id)

    def test_api_bucket_uses_config(self):
        """
        API (per-user) buckets are sized from settings.

        max_tokens should equal rate_limit_api_per_minute.
        """
        limiter = RateLimiter()
        user_id = 12345
        bucket = limiter.get_user_bucket(user_id)
        assert bucket.max_tokens == limiter._settings.rate_limit_api_per_minute
        # No public remove helper for user buckets; drop the key directly.
        limiter._user_buckets.pop(user_id)

View File

@ -0,0 +1 @@
"""Tests for monitoring utilities."""

View File

@ -0,0 +1,388 @@
"""
Tests for database connection pool monitoring.
Verifies that the PoolMonitor correctly:
- Reports pool statistics (checked in/out, overflow)
- Classifies health status (healthy/warning/critical)
- Tracks history of stats over time
- Logs warnings when usage exceeds threshold
Author: Claude
Date: 2025-11-27
"""
import pytest
from unittest.mock import MagicMock, patch
from app.monitoring.pool_monitor import PoolMonitor, PoolStats
# ============================================================================
# TEST FIXTURES
# ============================================================================
@pytest.fixture
def mock_engine():
    """
    Build a mock SQLAlchemy engine exposing pool statistics.

    Defaults simulate a lightly loaded pool: 15 connections checked in,
    5 checked out, and no overflow in use.
    """
    engine = MagicMock()
    # MagicMock auto-creates the `pool` child mock on first access.
    engine.pool.checkedin.return_value = 15
    engine.pool.checkedout.return_value = 5
    engine.pool.overflow.return_value = 0
    return engine
@pytest.fixture
def mock_settings():
    """Mock settings object carrying the pool sizing configuration."""
    # Constructor kwargs configure attributes directly on the mock.
    return MagicMock(db_pool_size=20, db_max_overflow=10)
# ============================================================================
# POOL STATS TESTS
# ============================================================================
class TestPoolMonitorStats:
    """Tests for PoolMonitor.get_stats()."""

    def test_get_stats_returns_pool_stats(self, mock_engine, mock_settings):
        """
        get_stats() should snapshot the engine pool's current state.

        The PoolStats result must mirror the checked-in/out counts and the
        configured capacity from settings.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            stats = PoolMonitor(mock_engine).get_stats()
        assert isinstance(stats, PoolStats)
        expected = {
            "checkedout": 5,
            "checkedin": 15,
            "overflow": 0,
            "pool_size": 20,
            "max_overflow": 10,
            "total_capacity": 30,
        }
        for attr, value in expected.items():
            assert getattr(stats, attr) == value

    def test_get_stats_calculates_usage_percent(self, mock_engine, mock_settings):
        """
        Usage is the checked-out fraction of total capacity.

        With 5 of 30 connections in use, usage_percent is 5/30.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            stats = PoolMonitor(mock_engine).get_stats()
        assert stats.usage_percent == pytest.approx(5 / 30, rel=0.01)

    def test_get_stats_handles_zero_capacity(self, mock_engine):
        """
        A zero-capacity pool must not trigger a division error.

        With both pool_size and max_overflow at 0, usage reports as 0.
        """
        empty_settings = MagicMock(db_pool_size=0, db_max_overflow=0)
        with patch("app.monitoring.pool_monitor.get_settings", return_value=empty_settings):
            stats = PoolMonitor(mock_engine).get_stats()
        assert stats.usage_percent == 0
        assert stats.total_capacity == 0

    def test_get_stats_records_history(self, mock_engine, mock_settings):
        """
        Every get_stats() call should append one history entry.

        History feeds monitoring dashboards and alerts.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            for _ in range(5):
                monitor.get_stats()
        assert len(monitor.get_history(limit=10)) == 5

    def test_get_stats_limits_history_size(self, mock_engine, mock_settings):
        """
        History must be capped at max_history entries.

        When the cap is exceeded, older snapshots are dropped.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine, max_history=5)
            for _ in range(10):
                monitor.get_stats()
        assert len(monitor.get_history(limit=10)) == 5
# ============================================================================
# HEALTH STATUS TESTS
# ============================================================================
class TestPoolMonitorHealth:
    """Tests for PoolMonitor.get_health_status()."""

    def test_health_status_healthy(self, mock_engine, mock_settings):
        """
        Verify health status is 'healthy' when usage < 75%.

        At 5/30 (16.7%) usage, status should be healthy.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()
        assert health["status"] == "healthy"
        assert health["in_use"] == 5
        assert health["available"] == 15

    def test_health_status_warning(self, mock_engine, mock_settings):
        """
        Verify health status is 'warning' when usage is 75-90%.

        At 24/30 (80%) usage, status should be warning.
        """
        mock_engine.pool.checkedout.return_value = 24
        mock_engine.pool.checkedin.return_value = 6
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()
        assert health["status"] == "warning"
        # BUG FIX: the previous rel=1 tolerance was ±100%, so any value in
        # [0, 160] passed and the assertion was vacuous.  24/30 is exactly
        # 80%, so a 1% relative tolerance is safe and actually meaningful.
        assert health["usage_percent"] == pytest.approx(80, rel=0.01)

    def test_health_status_critical(self, mock_engine, mock_settings):
        """
        Verify health status is 'critical' when usage >= 90%.

        At 28/30 (93%) usage, status should be critical.
        """
        mock_engine.pool.checkedout.return_value = 28
        mock_engine.pool.checkedin.return_value = 2
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()
        assert health["status"] == "critical"
        assert health["usage_percent"] >= 90

    def test_health_status_includes_timestamp(self, mock_engine, mock_settings):
        """
        Verify health status includes an ISO-format timestamp.

        Timestamp is used for monitoring and alerting correlation.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()
        assert "timestamp" in health
        assert isinstance(health["timestamp"], str)
        assert "T" in health["timestamp"]  # ISO 8601 date/time separator

    def test_health_status_includes_all_fields(self, mock_engine, mock_settings):
        """
        Verify health status includes all required fields.

        All fields are needed for complete monitoring visibility.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()
        required_fields = [
            "status",
            "pool_size",
            "max_overflow",
            "available",
            "in_use",
            "overflow_active",
            "total_capacity",
            "usage_percent",
            "timestamp",
        ]
        for field in required_fields:
            assert field in health, f"Missing field: {field}"
# ============================================================================
# HISTORY TESTS
# ============================================================================
class TestPoolMonitorHistory:
    """Tests for PoolMonitor.get_history()."""

    def test_get_history_returns_recent_stats(self, mock_engine, mock_settings):
        """
        History entries carry the fields dashboards need.

        Each snapshot must include checkedout, usage_percent, and a
        timestamp.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            for _ in range(3):
                monitor.get_stats()
        history = monitor.get_history(limit=5)
        assert len(history) == 3
        for snapshot in history:
            for field in ("checkedout", "usage_percent", "timestamp"):
                assert field in snapshot

    def test_get_history_respects_limit(self, mock_engine, mock_settings):
        """
        No more than `limit` entries may be returned.

        With 10 snapshots recorded and limit=3, only 3 come back.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            for _ in range(10):
                monitor.get_stats()
        assert len(monitor.get_history(limit=3)) == 3

    def test_get_history_returns_empty_when_no_stats(self, mock_engine, mock_settings):
        """
        A freshly created monitor has no history.

        get_history() must return an empty list before any stats call.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
        assert monitor.get_history() == []
# ============================================================================
# ALERT THRESHOLD TESTS
# ============================================================================
class TestPoolMonitorAlerts:
    """Tests for alert threshold behavior.

    Uses caplog.at_level with string level names ("WARNING"/"INFO"),
    which pytest supports, instead of re-importing `logging` inside
    every test as before.
    """

    def test_logs_warning_when_threshold_exceeded(self, mock_engine, mock_settings, caplog):
        """
        Verify warning is logged when usage exceeds alert threshold.

        With an 80% threshold and 26/30 (86.7%) usage, a high-usage
        warning should be logged.  (The docstring previously claimed 85%;
        the actual mocked usage is 86.7%.)
        """
        mock_engine.pool.checkedout.return_value = 26  # 26/30 = 86.7%
        mock_engine.pool.checkedin.return_value = 4
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine, alert_threshold=0.8)
            with caplog.at_level("WARNING"):
                monitor.get_stats()
        assert "Connection pool usage high" in caplog.text

    def test_logs_info_when_overflow_active(self, mock_engine, mock_settings, caplog):
        """
        Verify info is logged when overflow connections are in use.

        Overflow indicates the pool is at capacity and using extra
        connections.
        """
        mock_engine.pool.overflow.return_value = 3
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            with caplog.at_level("INFO"):
                monitor.get_stats()
        assert "Pool overflow active" in caplog.text

    def test_no_warning_when_below_threshold(self, mock_engine, mock_settings, caplog):
        """
        Verify no warning when usage is below threshold.

        At 10/30 (33%) usage with an 80% threshold, no warning should be
        logged.  (The docstring previously claimed 50%; the actual mocked
        usage is 33%.)
        """
        mock_engine.pool.checkedout.return_value = 10  # 10/30 = 33%
        mock_engine.pool.checkedin.return_value = 20
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine, alert_threshold=0.8)
            with caplog.at_level("WARNING"):
                monitor.get_stats()
        assert "Connection pool usage high" not in caplog.text
# ============================================================================
# INIT FUNCTION TESTS
# ============================================================================
class TestInitPoolMonitor:
"""Tests for init_pool_monitor function."""
def test_init_pool_monitor_creates_global_instance(self, mock_engine, mock_settings):
"""
Verify init_pool_monitor() creates global instance.
The global instance is used by health endpoints.
"""
from app.monitoring.pool_monitor import init_pool_monitor, pool_monitor
with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
monitor = init_pool_monitor(mock_engine)
assert monitor is not None
assert isinstance(monitor, PoolMonitor)
def test_init_pool_monitor_returns_monitor(self, mock_engine, mock_settings):
"""
Verify init_pool_monitor() returns the created monitor.
Return value allows caller to use the monitor directly.
"""
from app.monitoring.pool_monitor import init_pool_monitor
with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
monitor = init_pool_monitor(mock_engine)
# Should be able to use the returned monitor
stats = monitor.get_stats()
assert stats is not None