CLAUDE: Add rate limiting, pool monitoring, and exception infrastructure
- Add rate_limit.py middleware with per-client throttling and cleanup task - Add pool_monitor.py for database connection pool health monitoring - Add custom exceptions module (GameEngineError, DatabaseError, etc.) - Add config settings for eviction intervals, session timeouts, memory limits - Add unit tests for rate limiting and pool monitoring 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
ae4a92f0e0
commit
2a392b87f8
@ -47,6 +47,19 @@ class Settings(BaseSettings):
|
|||||||
max_concurrent_games: int = 20
|
max_concurrent_games: int = 20
|
||||||
game_idle_timeout: int = 86400 # 24 hours
|
game_idle_timeout: int = 86400 # 24 hours
|
||||||
|
|
||||||
|
# Game eviction settings (memory management)
|
||||||
|
game_idle_timeout_hours: int = 24 # Evict games idle > 24 hours
|
||||||
|
game_eviction_interval_minutes: int = 60 # Check every hour
|
||||||
|
game_max_in_memory: int = 500 # Hard limit on in-memory games
|
||||||
|
|
||||||
|
# Rate limiting settings
|
||||||
|
rate_limit_websocket_per_minute: int = 120 # Events per minute per connection
|
||||||
|
rate_limit_api_per_minute: int = 100 # API calls per minute per user
|
||||||
|
rate_limit_decision_per_game: int = 20 # Decisions per minute per game
|
||||||
|
rate_limit_roll_per_game: int = 30 # Rolls per minute per game
|
||||||
|
rate_limit_substitution_per_game: int = 15 # Substitutions per minute per game
|
||||||
|
rate_limit_cleanup_interval: int = 300 # Cleanup stale buckets every 5 min
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
case_sensitive = False
|
case_sensitive = False
|
||||||
|
|||||||
214
backend/app/core/exceptions.py
Normal file
214
backend/app/core/exceptions.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
"""
|
||||||
|
Custom exceptions for the game engine.
|
||||||
|
|
||||||
|
Provides a hierarchy of specific exception types to replace broad `except Exception`
|
||||||
|
patterns, improving debugging and error handling precision.
|
||||||
|
|
||||||
|
Exception Hierarchy:
|
||||||
|
GameEngineError (base)
|
||||||
|
├── GameNotFoundError - Game doesn't exist in state manager
|
||||||
|
├── InvalidGameStateError - Game in wrong state for operation
|
||||||
|
├── SubstitutionError - Invalid player substitution
|
||||||
|
├── AuthorizationError - User lacks permission
|
||||||
|
├── DecisionTimeoutError - Decision not submitted in time
|
||||||
|
└── ExternalAPIError - External service call failed
|
||||||
|
└── PlayerDataError - Failed to fetch player data
|
||||||
|
|
||||||
|
Author: Claude
|
||||||
|
Date: 2025-11-27
|
||||||
|
"""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
class GameEngineError(Exception):
    """
    Root of the game-engine exception hierarchy.

    Every game-specific exception derives from this class, so callers can
    catch game failures as a group while letting unrelated system errors
    propagate.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class GameNotFoundError(GameEngineError):
    """
    Signals that the requested game is absent from the state manager.

    Attributes:
        game_id: UUID (or string form) of the game that was looked up
    """

    def __init__(self, game_id: UUID | str):
        self.game_id = game_id
        message = f"Game not found: {game_id}"
        super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidGameStateError(GameEngineError):
    """
    Signals that the game's current state forbids the requested operation.

    Typical triggers:
        - submitting a decision after the game completed
        - rolling dice before both decisions are in
        - starting a game that is already running

    Attributes:
        current_state: optional current state value, kept for debugging
        expected_state: optional expected state value, kept for debugging
    """

    def __init__(
        self,
        message: str,
        current_state: str | None = None,
        expected_state: str | None = None,
    ):
        super().__init__(message)
        self.current_state = current_state
        self.expected_state = expected_state
|
||||||
|
|
||||||
|
|
||||||
|
class SubstitutionError(GameEngineError):
    """
    Signals an invalid player substitution.

    Used for substitution *rule* violations, not not-found conditions.

    Attributes:
        error_code: machine-readable code the frontend can branch on
    """

    def __init__(self, message: str, error_code: str = "SUBSTITUTION_ERROR"):
        super().__init__(message)
        self.error_code = error_code
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationError(GameEngineError):
    """
    Signals that the user lacks permission for an operation.

    Note: enforcement is currently deferred (task 001); the type exists
    so future authorization checks have a stable error to raise.

    Attributes:
        user_id: optional user ID, useful for logging
        resource: optional resource that was being accessed
    """

    def __init__(
        self,
        message: str,
        user_id: int | None = None,
        resource: str | None = None,
    ):
        super().__init__(message)
        self.user_id = user_id
        self.resource = resource
|
||||||
|
|
||||||
|
|
||||||
|
class DecisionTimeoutError(GameEngineError):
    """
    Signals that a decision was not submitted before the deadline.

    The engine substitutes a default decision when this happens, so the
    error is informational rather than fatal.

    Attributes:
        game_id: game in which the timeout occurred
        decision_type: which decision timed out ('defensive' or 'offensive')
        timeout_seconds: how long the engine waited
    """

    def __init__(
        self,
        game_id: UUID | str,
        decision_type: str,
        timeout_seconds: int,
    ):
        self.game_id = game_id
        self.decision_type = decision_type
        self.timeout_seconds = timeout_seconds
        detail = (
            f"Decision timeout for game {game_id}: "
            f"{decision_type} decision not received within {timeout_seconds}s"
        )
        super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalAPIError(GameEngineError):
    """
    Base type for failures of calls to external services.

    Attributes:
        service: name of the external service that failed
        status_code: optional HTTP status code returned by the service
    """

    def __init__(
        self,
        service: str,
        message: str,
        status_code: int | None = None,
    ):
        full_message = f"{service} API error: {message}"
        self.service = service
        self.status_code = status_code
        super().__init__(full_message)
|
||||||
|
|
||||||
|
|
||||||
|
class PlayerDataError(ExternalAPIError):
    """
    Signals that player data could not be fetched from an external API.

    Specialization of ExternalAPIError for player lookups (SBA API, PD API).

    Attributes:
        player_id: ID of the player whose data could not be fetched
        service: which API was called (defaults to "SBA API")
    """

    def __init__(self, player_id: int, service: str = "SBA API"):
        detail = f"Failed to fetch player data for ID {player_id}"
        self.player_id = player_id
        super().__init__(service=service, message=detail)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseError(GameEngineError):
    """
    Signals a failed database operation, wrapping the underlying
    SQLAlchemy error with game context.

    Attributes:
        operation: attempted operation (e.g. 'save_play', 'create_substitution')
        original_error: the underlying database exception, if any
    """

    def __init__(self, operation: str, original_error: Exception | None = None):
        self.operation = operation
        self.original_error = original_error
        # Append the wrapped error's text only when one was supplied.
        if original_error:
            detail = f"Database error during {operation}: {original_error}"
        else:
            detail = f"Database error during {operation}"
        super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
|
class LineupError(GameEngineError):
    """
    Signals a lineup validation or lookup failure for a team.

    Attributes:
        team_id: team whose lineup triggered the error
    """

    def __init__(self, team_id: int, message: str):
        detail = f"Lineup error for team {team_id}: {message}"
        self.team_id = team_id
        super().__init__(detail)
|
||||||
0
backend/app/middleware/__init__.py
Normal file
0
backend/app/middleware/__init__.py
Normal file
328
backend/app/middleware/rate_limit.py
Normal file
328
backend/app/middleware/rate_limit.py
Normal file
@ -0,0 +1,328 @@
|
|||||||
|
"""
|
||||||
|
Rate limiting utilities for WebSocket and API endpoints.
|
||||||
|
|
||||||
|
Implements token bucket algorithm for smooth rate limiting that allows
|
||||||
|
bursts while enforcing long-term rate limits.
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- Per-connection rate limiting for WebSocket events
|
||||||
|
- Per-game rate limiting for game actions (decisions, rolls, substitutions)
|
||||||
|
- Per-user rate limiting for REST API endpoints
|
||||||
|
- Automatic cleanup of stale buckets
|
||||||
|
- Non-blocking async design
|
||||||
|
|
||||||
|
Author: Claude
|
||||||
|
Date: 2025-11-27
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import wraps
|
||||||
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
|
import pendulum
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from socketio import AsyncServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(f"{__name__}.RateLimiter")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RateLimitBucket:
    """
    Token bucket for rate limiting.

    The token bucket algorithm works by:
    1. Bucket starts with max_tokens
    2. Each request consumes tokens
    3. Tokens refill at refill_rate per second
    4. Requests are denied when tokens < 1

    This allows short bursts while enforcing average rate limits.
    """

    tokens: float  # currently available tokens (may be fractional after refill)
    max_tokens: int  # capacity the bucket refills toward
    refill_rate: float  # tokens per second (0 disables refill entirely)
    last_refill: pendulum.DateTime = field(
        default_factory=lambda: pendulum.now("UTC")
    )

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens.

        Returns True if allowed, False if rate limited (tokens unchanged
        when denied).
        """
        self._refill()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def _refill(self) -> None:
        """Refill tokens based on time elapsed since last refill."""
        now = pendulum.now("UTC")
        elapsed = (now - self.last_refill).total_seconds()
        refill_amount = elapsed * self.refill_rate

        # Only advance last_refill when tokens were actually added, so a
        # zero refill_rate never resets the staleness clock.
        if refill_amount > 0:
            self.tokens = min(self.max_tokens, self.tokens + refill_amount)
            self.last_refill = now

    def get_wait_time(self) -> float:
        """
        Get seconds until a token will be available.

        Useful for informing clients how long to wait.

        Fixes two defects in the original:
        - refills first, so the answer reflects tokens accrued since the
          last consume() instead of a stale snapshot;
        - returns infinity for an empty, never-refilling bucket
          (refill_rate == 0) instead of raising ZeroDivisionError — a
          configuration the unit tests construct explicitly.
        """
        self._refill()
        if self.tokens >= 1:
            return 0.0
        if self.refill_rate <= 0:
            return float("inf")  # empty and never refills
        tokens_needed = 1 - self.tokens
        return tokens_needed / self.refill_rate
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
    """
    Rate limiter for WebSocket connections and API endpoints.

    Uses token bucket algorithm for smooth rate limiting that allows
    bursts while enforcing average rate limits.

    Thread-safe for async operations via atomic dict operations.
    NOTE(review): this holds only when all callers share a single asyncio
    event loop; the dicts are not protected against OS threads — confirm
    no multi-threaded access exists.
    """

    def __init__(self):
        # Per-connection buckets (sid -> bucket)
        self._connection_buckets: dict[str, RateLimitBucket] = {}
        # Per-game buckets (game_id:action -> bucket)
        self._game_buckets: dict[str, RateLimitBucket] = {}
        # Per-user API buckets (user_id -> bucket)
        self._user_buckets: dict[int, RateLimitBucket] = {}
        # Cleanup task handle (assigned externally; not started here)
        self._cleanup_task: asyncio.Task | None = None
        # Settings cache — limits are read once at construction, so later
        # settings changes do not affect an existing limiter.
        self._settings = get_settings()

    def get_connection_bucket(self, sid: str) -> RateLimitBucket:
        """Get or create bucket for WebSocket connection.

        New buckets start full (tokens == limit) so a fresh connection
        can burst immediately.
        """
        if sid not in self._connection_buckets:
            limit = self._settings.rate_limit_websocket_per_minute
            self._connection_buckets[sid] = RateLimitBucket(
                tokens=limit,
                max_tokens=limit,
                refill_rate=limit / 60,  # per minute -> per second
            )
        return self._connection_buckets[sid]

    def get_game_bucket(self, game_id: str, action: str) -> RateLimitBucket:
        """
        Get or create bucket for game-specific action.

        Actions: 'decision', 'roll', 'substitution'

        Buckets are keyed by "game_id:action" so each action type is
        throttled independently per game.
        """
        key = f"{game_id}:{action}"
        if key not in self._game_buckets:
            # Get limit based on action type
            if action == "decision":
                limit = self._settings.rate_limit_decision_per_game
            elif action == "roll":
                limit = self._settings.rate_limit_roll_per_game
            elif action == "substitution":
                limit = self._settings.rate_limit_substitution_per_game
            else:
                limit = 30  # Default for unknown actions

            self._game_buckets[key] = RateLimitBucket(
                tokens=limit,
                max_tokens=limit,
                refill_rate=limit / 60,
            )
        return self._game_buckets[key]

    def get_user_bucket(self, user_id: int) -> RateLimitBucket:
        """Get or create bucket for API user."""
        if user_id not in self._user_buckets:
            limit = self._settings.rate_limit_api_per_minute
            self._user_buckets[user_id] = RateLimitBucket(
                tokens=limit,
                max_tokens=limit,
                refill_rate=limit / 60,
            )
        return self._user_buckets[user_id]

    # NOTE(review): the three check_* methods contain no awaits; they are
    # declared async only so handlers can await them uniformly.

    async def check_websocket_limit(self, sid: str) -> bool:
        """
        Check if WebSocket event is allowed for connection.

        Returns True if allowed, False if rate limited.
        """
        bucket = self.get_connection_bucket(sid)
        allowed = bucket.consume()
        if not allowed:
            logger.warning(f"Rate limited WebSocket connection: {sid}")
        return allowed

    async def check_game_limit(self, game_id: str, action: str) -> bool:
        """
        Check if game action is allowed.

        Returns True if allowed, False if rate limited.
        """
        bucket = self.get_game_bucket(game_id, action)
        allowed = bucket.consume()
        if not allowed:
            logger.warning(f"Rate limited game action: {game_id} {action}")
        return allowed

    async def check_api_limit(self, user_id: int) -> bool:
        """
        Check if API call is allowed for user.

        Returns True if allowed, False if rate limited.
        """
        bucket = self.get_user_bucket(user_id)
        allowed = bucket.consume()
        if not allowed:
            logger.warning(f"Rate limited API user: {user_id}")
        return allowed

    def remove_connection(self, sid: str) -> None:
        """Clean up buckets when connection closes."""
        # pop with default: safe even if the sid never acquired a bucket.
        self._connection_buckets.pop(sid, None)

    def remove_game(self, game_id: str) -> None:
        """Clean up all buckets for a game when it ends."""
        # Collect first, then delete — avoids mutating the dict while
        # iterating it.
        keys_to_remove = [
            key for key in self._game_buckets if key.startswith(f"{game_id}:")
        ]
        for key in keys_to_remove:
            del self._game_buckets[key]

    async def cleanup_stale_buckets(self) -> None:
        """
        Background task to periodically clean up stale buckets.

        Removes buckets that haven't been used in 10 minutes to prevent
        memory leaks from abandoned connections.

        Staleness is judged by bucket.last_refill, which consume() updates
        on any refill; a bucket untouched for the threshold window has had
        no (successful) refill activity.
        """
        interval = self._settings.rate_limit_cleanup_interval
        stale_threshold_seconds = 600  # 10 minutes (hardcoded, not a setting)

        logger.info(f"Starting rate limiter cleanup task (interval: {interval}s)")

        while True:
            try:
                await asyncio.sleep(interval)
                now = pendulum.now("UTC")

                # Clean connection buckets
                stale_connections = [
                    sid
                    for sid, bucket in self._connection_buckets.items()
                    if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds
                ]
                for sid in stale_connections:
                    del self._connection_buckets[sid]

                # Clean game buckets
                stale_games = [
                    key
                    for key, bucket in self._game_buckets.items()
                    if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds
                ]
                for key in stale_games:
                    del self._game_buckets[key]

                # Clean user buckets
                stale_users = [
                    user_id
                    for user_id, bucket in self._user_buckets.items()
                    if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds
                ]
                for user_id in stale_users:
                    del self._user_buckets[user_id]

                if stale_connections or stale_games or stale_users:
                    logger.debug(
                        f"Rate limiter cleanup: {len(stale_connections)} connections, "
                        f"{len(stale_games)} games, {len(stale_users)} users"
                    )

            except asyncio.CancelledError:
                # Normal shutdown path: stop the loop and let the task end.
                logger.info("Rate limiter cleanup task cancelled")
                break
            except Exception as e:
                logger.error(f"Rate limiter cleanup error: {e}", exc_info=True)
                # Continue running despite errors

    def get_stats(self) -> dict:
        """Get rate limiter statistics for monitoring.

        Returns bucket counts only, not token levels.
        """
        return {
            "connection_buckets": len(self._connection_buckets),
            "game_buckets": len(self._game_buckets),
            "user_buckets": len(self._user_buckets),
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Global rate limiter instance — module-level singleton shared by every
# handler in this process (cleanup task must be started separately).
rate_limiter = RateLimiter()
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limited(action: str = "general"):
    """
    Decorator for rate-limited WebSocket handlers.

    Checks both connection-level and game-level rate limits before
    allowing the handler to execute. When a limit is hit, the handler is
    not called and an error dict is returned instead.

    Args:
        action: Type of action for game-level limits.
                Options: 'decision', 'roll', 'substitution', 'general'
                'general' only applies connection-level limit.

    Usage:
        @sio.event
        @rate_limited(action="decision")
        async def submit_defensive_decision(sid, data):
            ...

    Note: The decorator must be applied AFTER @sio.event to work correctly.

    Fix: removed the unused local import of ConnectionManager — it was
    imported "to avoid circular imports" but never referenced, so it only
    added import cost on every handler call.
    """

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(sid: str, data=None, *args, **kwargs):
            # Connection-level limit applies to every decorated event.
            if not await rate_limiter.check_websocket_limit(sid):
                logger.warning(f"Rate limited handler call from {sid}")
                return {"error": "rate_limited", "message": "Rate limited. Please slow down."}

            # Game-level limit applies only to game actions that carry a
            # game_id in their payload.
            if action != "general" and isinstance(data, dict):
                game_id = data.get("game_id")
                if game_id:
                    if not await rate_limiter.check_game_limit(str(game_id), action):
                        logger.warning(
                            f"Rate limited game action {action} for game {game_id}"
                        )
                        return {
                            "error": "game_rate_limited",
                            "message": f"Too many {action} requests for this game.",
                        }

            return await func(sid, data, *args, **kwargs)

        return wrapper

    return decorator
|
||||||
15
backend/app/monitoring/__init__.py
Normal file
15
backend/app/monitoring/__init__.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
"""Monitoring utilities for the Paper Dynasty game engine."""
|
||||||
|
|
||||||
|
from app.monitoring.pool_monitor import (
|
||||||
|
PoolMonitor,
|
||||||
|
PoolStats,
|
||||||
|
init_pool_monitor,
|
||||||
|
pool_monitor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Names re-exported as the monitoring package's public API.
__all__ = [
    "PoolMonitor",
    "PoolStats",
    "init_pool_monitor",
    "pool_monitor",
]
|
||||||
220
backend/app/monitoring/pool_monitor.py
Normal file
220
backend/app/monitoring/pool_monitor.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Database connection pool monitoring.
|
||||||
|
|
||||||
|
Monitors SQLAlchemy async connection pool health and provides
|
||||||
|
statistics for observability and alerting.
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- Real-time pool statistics (checked in/out, overflow)
|
||||||
|
- Health status classification (healthy/warning/critical)
|
||||||
|
- Historical stats tracking
|
||||||
|
- Background monitoring with configurable interval
|
||||||
|
- Warning logs when pool usage exceeds threshold
|
||||||
|
|
||||||
|
Author: Claude
|
||||||
|
Date: 2025-11-27
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pendulum
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(f"{__name__}.PoolMonitor")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PoolStats:
    """
    Connection pool statistics snapshot.

    Captures the current state of the database connection pool
    for monitoring and alerting purposes. Field order is part of the
    positional-construction interface — do not reorder.
    """

    pool_size: int  # configured base pool size (taken from settings)
    max_overflow: int  # configured overflow allowance (taken from settings)
    checkedin: int  # Available connections
    checkedout: int  # In-use connections
    overflow: int  # Overflow connections in use
    total_capacity: int  # pool_size + max_overflow
    usage_percent: float  # checkedout / total_capacity, as a 0..1 fraction
    timestamp: pendulum.DateTime = field(default_factory=lambda: pendulum.now("UTC"))
|
||||||
|
|
||||||
|
|
||||||
|
class PoolMonitor:
    """
    Monitor database connection pool health.

    Provides real-time statistics and health status for the SQLAlchemy
    connection pool. Useful for detecting pool exhaustion before it
    causes request failures.

    Usage:
        monitor = PoolMonitor(engine)
        stats = monitor.get_stats()
        health = monitor.get_health_status()
    """

    def __init__(
        self,
        engine: AsyncEngine,
        alert_threshold: float = 0.8,
        max_history: int = 100,
    ):
        """
        Initialize pool monitor.

        Args:
            engine: SQLAlchemy async engine to monitor
            alert_threshold: Usage percentage to trigger warning (0.8 = 80%)
            max_history: Maximum stats snapshots to keep in history
        """
        self._engine = engine
        # Bounded ring of recent snapshots (oldest dropped in get_stats).
        self._stats_history: list[PoolStats] = []
        self._max_history = max_history
        self._alert_threshold = alert_threshold
        self._settings = get_settings()

    def get_stats(self) -> PoolStats:
        """
        Get current pool statistics.

        Side effects: appends the snapshot to the bounded history and
        emits warning/info logs when usage or overflow is elevated.

        Returns:
            PoolStats with current pool state
        """
        pool = self._engine.pool

        checkedin = pool.checkedin()
        checkedout = pool.checkedout()
        overflow = pool.overflow()
        # NOTE(review): capacity is derived from settings, not from the
        # engine's actual pool configuration — confirm the engine is
        # created with these same db_pool_size / db_max_overflow values.
        total_capacity = self._settings.db_pool_size + self._settings.db_max_overflow

        # Guard against a zero-capacity misconfiguration.
        usage_percent = checkedout / total_capacity if total_capacity > 0 else 0

        stats = PoolStats(
            pool_size=self._settings.db_pool_size,
            max_overflow=self._settings.db_max_overflow,
            checkedin=checkedin,
            checkedout=checkedout,
            overflow=overflow,
            total_capacity=total_capacity,
            usage_percent=usage_percent,
        )

        # Record history
        self._stats_history.append(stats)
        if len(self._stats_history) > self._max_history:
            self._stats_history.pop(0)

        # Check for alerts
        if usage_percent >= self._alert_threshold:
            logger.warning(
                f"Connection pool usage high: {usage_percent:.1%} "
                f"({checkedout}/{total_capacity})"
            )

        if overflow > 0:
            logger.info(f"Pool overflow active: {overflow} overflow connections")

        return stats

    def get_health_status(self) -> dict:
        """
        Get pool health status for monitoring endpoint.

        Status thresholds are fixed (critical >= 90%, warning >= 75%)
        and independent of the constructor's alert_threshold.

        Returns:
            Dict with status, statistics, and timestamp
        """
        stats = self.get_stats()

        if stats.usage_percent >= 0.9:
            status = "critical"
        elif stats.usage_percent >= 0.75:
            status = "warning"
        else:
            status = "healthy"

        return {
            "status": status,
            "pool_size": stats.pool_size,
            "max_overflow": stats.max_overflow,
            "available": stats.checkedin,
            "in_use": stats.checkedout,
            "overflow_active": stats.overflow,
            "total_capacity": stats.total_capacity,
            # Fraction converted to a 0-100 percentage for the endpoint.
            "usage_percent": round(stats.usage_percent * 100, 1),
            "timestamp": stats.timestamp.isoformat(),
        }

    def get_history(self, limit: int = 10) -> list[dict]:
        """
        Get recent stats history.

        Args:
            limit: Maximum number of history entries to return

        Returns:
            List of stats snapshots, oldest first within the window
        """
        return [
            {
                "checkedout": s.checkedout,
                "usage_percent": round(s.usage_percent * 100, 1),
                "timestamp": s.timestamp.isoformat(),
            }
            for s in self._stats_history[-limit:]
        ]

    async def start_monitoring(self, interval_seconds: int = 60):
        """
        Background task to periodically collect stats.

        Useful for continuous logging and alerting. Runs until cancelled.
        Collects one snapshot immediately, then every interval_seconds.

        Args:
            interval_seconds: Seconds between stat collections
        """
        logger.info(f"Starting pool monitoring (interval: {interval_seconds}s)")

        while True:
            try:
                stats = self.get_stats()
                logger.debug(
                    f"Pool stats: {stats.checkedout}/{stats.total_capacity} "
                    f"({stats.usage_percent:.1%})"
                )
                await asyncio.sleep(interval_seconds)
            except asyncio.CancelledError:
                # Normal shutdown path.
                logger.info("Pool monitoring stopped")
                break
            except Exception as e:
                # Keep monitoring alive across transient failures.
                logger.error(f"Pool monitoring error: {e}")
                await asyncio.sleep(interval_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance (initialized in main.py); None until init_pool_monitor
# runs, so importers must handle the un-initialized case.
pool_monitor: Optional[PoolMonitor] = None


def init_pool_monitor(engine: AsyncEngine) -> PoolMonitor:
    """
    Initialize global pool monitor.

    Should be called during application startup, before anything reads
    the module-level `pool_monitor`.

    Args:
        engine: SQLAlchemy async engine to monitor

    Returns:
        Initialized PoolMonitor instance
    """
    global pool_monitor
    pool_monitor = PoolMonitor(engine)
    logger.info("Pool monitor initialized")
    return pool_monitor
|
||||||
1
backend/tests/unit/middleware/__init__.py
Normal file
1
backend/tests/unit/middleware/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Middleware unit tests
|
||||||
583
backend/tests/unit/middleware/test_rate_limit.py
Normal file
583
backend/tests/unit/middleware/test_rate_limit.py
Normal file
@ -0,0 +1,583 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for rate limiting middleware.
|
||||||
|
|
||||||
|
Tests the token bucket algorithm implementation and rate limiter functionality
|
||||||
|
for WebSocket and API rate limiting.
|
||||||
|
|
||||||
|
Author: Claude
|
||||||
|
Date: 2025-11-27
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pendulum
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.middleware.rate_limit import RateLimitBucket, RateLimiter, rate_limiter
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitBucket:
    """Tests for the RateLimitBucket token bucket implementation.

    NOTE: refill is driven by real wall-clock time (pendulum.now), so
    several tests pin refill_rate=0.0 for determinism or allow a small
    timing variance in their assertions.
    """

    def test_bucket_initialization(self):
        """
        Test that bucket initializes with correct values.

        Bucket should start with max_tokens available and record
        the creation time as last_refill.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=1.0
        )

        assert bucket.tokens == 10
        assert bucket.max_tokens == 10
        assert bucket.refill_rate == 1.0
        # last_refill defaults to creation time when not passed explicitly.
        assert bucket.last_refill is not None

    def test_consume_success(self):
        """
        Test successful token consumption.

        Should reduce available tokens and return True when
        enough tokens are available.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=1.0
        )

        result = bucket.consume(1)

        assert result is True
        # Strict < 10 (not == 9) because consume() refills from elapsed
        # wall-clock time before deducting.
        assert bucket.tokens < 10  # Reduced (may have refilled slightly)

    def test_consume_multiple_tokens(self):
        """
        Test consuming multiple tokens at once.

        Should consume the requested amount if available.
        """
        bucket = RateLimitBucket(
            tokens=10, max_tokens=10, refill_rate=0.0  # No refill for test
        )

        result = bucket.consume(5)

        assert result is True
        assert bucket.tokens == 5

    def test_consume_denied_when_empty(self):
        """
        Test that consumption is denied when not enough tokens.

        Should return False and not reduce tokens below zero.
        """
        bucket = RateLimitBucket(
            tokens=0, max_tokens=10, refill_rate=0.0  # No refill
        )

        result = bucket.consume(1)

        assert result is False
        assert bucket.tokens == 0  # Unchanged

    def test_consume_denied_insufficient_tokens(self):
        """
        Test denial when requesting more tokens than available.

        Should return False without modifying token count.
        """
        bucket = RateLimitBucket(
            tokens=3, max_tokens=10, refill_rate=0.0
        )

        result = bucket.consume(5)

        assert result is False
        # A denied request must be all-or-nothing: no partial deduction.
        assert bucket.tokens == 3  # Unchanged

    def test_token_refill_over_time(self):
        """
        Test that tokens refill based on elapsed time.

        Tokens should accumulate at refill_rate per second.
        """
        # Start with 5 tokens
        past_time = pendulum.now("UTC").subtract(seconds=5)
        bucket = RateLimitBucket(
            tokens=5,
            max_tokens=10,
            refill_rate=1.0,  # 1 token per second
            last_refill=past_time,
        )

        # Force refill by consuming (which triggers refill check)
        bucket.consume(0)

        # Should have refilled ~5 tokens (5 seconds * 1 token/sec)
        assert bucket.tokens >= 9.5  # Allow small timing variance

    def test_tokens_capped_at_max(self):
        """
        Test that tokens don't exceed max_tokens after refill.

        Token count should never go above the configured maximum.
        """
        past_time = pendulum.now("UTC").subtract(seconds=100)
        bucket = RateLimitBucket(
            tokens=5,
            max_tokens=10,
            refill_rate=1.0,  # Would add 100 tokens
            last_refill=past_time,
        )

        bucket.consume(0)  # Trigger refill

        assert bucket.tokens == 10  # Capped at max

    def test_get_wait_time_when_available(self):
        """
        Test wait time calculation when tokens available.

        Should return 0.0 when at least one token is available.
        """
        bucket = RateLimitBucket(
            tokens=5, max_tokens=10, refill_rate=1.0
        )

        wait_time = bucket.get_wait_time()

        assert wait_time == 0.0

    def test_get_wait_time_when_empty(self):
        """
        Test wait time calculation when no tokens available.

        Should return time needed for one token to refill.
        """
        bucket = RateLimitBucket(
            tokens=0, max_tokens=10, refill_rate=2.0  # 2 tokens per second
        )

        wait_time = bucket.get_wait_time()

        # NOTE(review): exact float equality assumes get_wait_time() does
        # not refill first; a wall-clock refill would make this flaky —
        # confirm against the implementation.
        assert wait_time == 0.5  # 1 token / 2 tokens per second

    def test_get_wait_time_partial_tokens(self):
        """
        Test wait time with partial tokens available.

        Should calculate time to reach 1.0 token.
        """
        bucket = RateLimitBucket(
            tokens=0.5, max_tokens=10, refill_rate=1.0
        )

        wait_time = bucket.get_wait_time()

        assert wait_time == 0.5  # Need 0.5 more tokens at 1/sec
|
|
||||||
|
class TestRateLimiter:
    """Unit tests covering bucket management and limit checks on RateLimiter."""

    @pytest.fixture
    def limiter(self):
        """Provide an isolated RateLimiter instance per test."""
        return RateLimiter()

    def test_get_connection_bucket_creates_new(self, limiter):
        """An unseen sid lazily gets a freshly built, settings-backed bucket."""
        created = limiter.get_connection_bucket("test_sid_1")

        assert created is not None
        assert created.max_tokens > 0

    def test_get_connection_bucket_returns_existing(self, limiter):
        """Repeated lookups of one sid must yield the identical bucket object."""
        first = limiter.get_connection_bucket("test_sid_1")

        assert limiter.get_connection_bucket("test_sid_1") is first

    def test_get_connection_bucket_different_sids(self, limiter):
        """Distinct connections are throttled independently of each other."""
        first = limiter.get_connection_bucket("sid_1")
        second = limiter.get_connection_bucket("sid_2")

        assert first is not second

    def test_get_game_bucket_creates_new(self, limiter):
        """A new (game, action) pair gets its own initialized bucket."""
        created = limiter.get_game_bucket("game_123", "decision")

        assert created is not None
        assert created.max_tokens > 0

    def test_get_game_bucket_different_actions(self, limiter):
        """Within one game, each action type carries a separate limit."""
        by_decision = limiter.get_game_bucket("game_123", "decision")
        by_roll = limiter.get_game_bucket("game_123", "roll")

        assert by_decision is not by_roll

    def test_get_game_bucket_different_games(self, limiter):
        """The same action on two games must not share a bucket."""
        in_game_one = limiter.get_game_bucket("game_1", "decision")
        in_game_two = limiter.get_game_bucket("game_2", "decision")

        assert in_game_one is not in_game_two

    def test_get_user_bucket_creates_new(self, limiter):
        """An unseen user id gets a bucket for per-user API throttling."""
        created = limiter.get_user_bucket(123)

        assert created is not None
        assert created.max_tokens > 0

    @pytest.mark.asyncio
    async def test_check_websocket_limit_allowed(self, limiter):
        """A fresh connection with a full bucket is allowed through."""
        allowed = await limiter.check_websocket_limit("test_sid")

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_websocket_limit_denied_when_exhausted(self, limiter):
        """Once a connection's tokens hit zero, further events are rejected."""
        sid = "rate_limited_sid"

        # Drain the bucket directly to simulate exhaustion.
        limiter.get_connection_bucket(sid).tokens = 0

        assert await limiter.check_websocket_limit(sid) is False

    @pytest.mark.asyncio
    async def test_check_game_limit_allowed(self, limiter):
        """A game action under its per-minute cap is allowed."""
        allowed = await limiter.check_game_limit("game_123", "decision")

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_game_limit_denied_when_exhausted(self, limiter):
        """A game action whose bucket is drained is rejected."""
        game_id = "rate_limited_game"

        limiter.get_game_bucket(game_id, "roll").tokens = 0

        assert await limiter.check_game_limit(game_id, "roll") is False

    @pytest.mark.asyncio
    async def test_check_api_limit_allowed(self, limiter):
        """A user under the API per-minute cap is allowed."""
        allowed = await limiter.check_api_limit(123)

        assert allowed is True

    @pytest.mark.asyncio
    async def test_check_api_limit_denied_when_exhausted(self, limiter):
        """A user whose API bucket is drained is rejected."""
        user_id = 999

        limiter.get_user_bucket(user_id).tokens = 0

        assert await limiter.check_api_limit(user_id) is False

    def test_remove_connection_cleans_up(self, limiter):
        """Closing a connection drops its bucket from the internal map."""
        sid = "cleanup_test_sid"
        limiter.get_connection_bucket(sid)  # Create bucket
        assert sid in limiter._connection_buckets

        limiter.remove_connection(sid)

        assert sid not in limiter._connection_buckets

    def test_remove_connection_handles_missing(self, limiter):
        """Removing an unknown sid is a harmless no-op (no exception)."""
        limiter.remove_connection("nonexistent_sid")  # Should not raise

    def test_remove_game_cleans_up_all_actions(self, limiter):
        """Removing a game purges every per-action bucket keyed to it."""
        game_id = "cleanup_game"
        actions = ("decision", "roll", "substitution")

        for action in actions:
            limiter.get_game_bucket(game_id, action)
        for action in actions:
            assert f"{game_id}:{action}" in limiter._game_buckets

        limiter.remove_game(game_id)

        for action in actions:
            assert f"{game_id}:{action}" not in limiter._game_buckets

    def test_remove_game_preserves_other_games(self, limiter):
        """Purging one game must leave other games' buckets untouched."""
        limiter.get_game_bucket("game_1", "decision")
        limiter.get_game_bucket("game_2", "decision")

        limiter.remove_game("game_1")

        assert "game_1:decision" not in limiter._game_buckets
        assert "game_2:decision" in limiter._game_buckets

    def test_get_stats_returns_counts(self, limiter):
        """get_stats() reports the size of each bucket registry."""
        limiter.get_connection_bucket("sid_1")
        limiter.get_connection_bucket("sid_2")
        limiter.get_game_bucket("game_1", "decision")
        limiter.get_user_bucket(123)

        stats = limiter.get_stats()

        assert stats["connection_buckets"] == 2
        assert stats["game_buckets"] == 1
        assert stats["user_buckets"] == 1
||||||
|
|
||||||
|
class TestCleanupTask:
    """Tests for the stale bucket cleanup functionality.

    NOTE(review): both tests re-implement the 600-second staleness sweep
    inline instead of invoking the limiter's own cleanup routine, so they
    exercise a copy of the logic rather than production code — confirm the
    duplicated predicate matches the real cleanup task.
    """

    @pytest.mark.asyncio
    async def test_cleanup_removes_stale_connection_buckets(self):
        """
        Test that cleanup removes connection buckets older than threshold.

        Buckets unused for >10 minutes should be cleaned up.
        """
        limiter = RateLimiter()

        # Create a bucket with old last_refill time (>10 minutes ago)
        stale_time = pendulum.now("UTC").subtract(minutes=15)
        bucket = limiter.get_connection_bucket("stale_sid")
        bucket.last_refill = stale_time

        # Create a fresh bucket
        # (local binding is unused; the call registers the bucket)
        fresh_bucket = limiter.get_connection_bucket("fresh_sid")

        # Manually trigger cleanup logic (without running full async task)
        now = pendulum.now("UTC")
        stale_sids = [
            sid
            for sid, b in limiter._connection_buckets.items()
            if (now - b.last_refill).total_seconds() > 600
        ]
        for sid in stale_sids:
            del limiter._connection_buckets[sid]

        assert "stale_sid" not in limiter._connection_buckets
        assert "fresh_sid" in limiter._connection_buckets

    @pytest.mark.asyncio
    async def test_cleanup_removes_stale_game_buckets(self):
        """
        Test that cleanup removes game buckets older than threshold.

        Abandoned game buckets should be cleaned up.
        """
        limiter = RateLimiter()

        # Back-date one bucket past the 10-minute staleness threshold.
        stale_time = pendulum.now("UTC").subtract(minutes=15)
        bucket = limiter.get_game_bucket("stale_game", "decision")
        bucket.last_refill = stale_time

        limiter.get_game_bucket("fresh_game", "decision")

        # Manually trigger cleanup
        now = pendulum.now("UTC")
        stale_keys = [
            key
            for key, b in limiter._game_buckets.items()
            if (now - b.last_refill).total_seconds() > 600
        ]
        for key in stale_keys:
            del limiter._game_buckets[key]

        assert "stale_game:decision" not in limiter._game_buckets
        assert "fresh_game:decision" in limiter._game_buckets
|
||||||
|
|
||||||
|
class TestGlobalRateLimiter:
    """Tests for the module-level rate_limiter singleton."""

    def test_global_instance_exists(self):
        """The module exposes a usable RateLimiter singleton."""
        assert rate_limiter is not None
        assert isinstance(rate_limiter, RateLimiter)

    @pytest.mark.asyncio
    async def test_global_instance_functional(self):
        """The singleton accepts a request for a fresh connection."""
        allowed = await rate_limiter.check_websocket_limit("global_test_sid")
        assert allowed is True

        # Drop the bucket so the shared singleton carries no test state.
        rate_limiter.remove_connection("global_test_sid")
|
|
||||||
|
class TestRateLimitConfiguration:
    """Verify each bucket type is sized from the corresponding setting.

    NOTE: reads the limiter's private ``_settings`` to compare against the
    live configuration rather than hard-coding defaults.
    """

    def test_connection_bucket_uses_config(self):
        """Connection buckets are capped by rate_limit_websocket_per_minute."""
        limiter = RateLimiter()
        bucket = limiter.get_connection_bucket("config_test_sid")

        # Default is 120 per minute
        expected = limiter._settings.rate_limit_websocket_per_minute
        assert bucket.max_tokens == expected

        limiter.remove_connection("config_test_sid")

    def test_decision_bucket_uses_config(self):
        """Decision buckets are capped by rate_limit_decision_per_game."""
        limiter = RateLimiter()
        bucket = limiter.get_game_bucket("config_game", "decision")

        expected = limiter._settings.rate_limit_decision_per_game
        assert bucket.max_tokens == expected

        limiter.remove_game("config_game")

    def test_roll_bucket_uses_config(self):
        """Roll buckets are capped by rate_limit_roll_per_game."""
        limiter = RateLimiter()
        bucket = limiter.get_game_bucket("config_game", "roll")

        expected = limiter._settings.rate_limit_roll_per_game
        assert bucket.max_tokens == expected

        limiter.remove_game("config_game")

    def test_substitution_bucket_uses_config(self):
        """Substitution buckets are capped by rate_limit_substitution_per_game."""
        limiter = RateLimiter()
        bucket = limiter.get_game_bucket("config_game", "substitution")

        expected = limiter._settings.rate_limit_substitution_per_game
        assert bucket.max_tokens == expected

        limiter.remove_game("config_game")

    def test_api_bucket_uses_config(self):
        """User API buckets are capped by rate_limit_api_per_minute."""
        limiter = RateLimiter()
        bucket = limiter.get_user_bucket(12345)

        expected = limiter._settings.rate_limit_api_per_minute
        assert bucket.max_tokens == expected

        del limiter._user_buckets[12345]
||||||
1
backend/tests/unit/monitoring/__init__.py
Normal file
1
backend/tests/unit/monitoring/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Tests for monitoring utilities."""
|
||||||
388
backend/tests/unit/monitoring/test_pool_monitor.py
Normal file
388
backend/tests/unit/monitoring/test_pool_monitor.py
Normal file
@ -0,0 +1,388 @@
|
|||||||
|
"""
|
||||||
|
Tests for database connection pool monitoring.
|
||||||
|
|
||||||
|
Verifies that the PoolMonitor correctly:
|
||||||
|
- Reports pool statistics (checked in/out, overflow)
|
||||||
|
- Classifies health status (healthy/warning/critical)
|
||||||
|
- Tracks history of stats over time
|
||||||
|
- Logs warnings when usage exceeds threshold
|
||||||
|
|
||||||
|
Author: Claude
|
||||||
|
Date: 2025-11-27
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from app.monitoring.pool_monitor import PoolMonitor, PoolStats
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# TEST FIXTURES
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_engine():
    """
    Build a mock SQLAlchemy engine exposing pool statistics.

    Default configuration: 15 available, 5 in use, no overflow.
    """
    pool = MagicMock()
    pool.checkedin.return_value = 15
    pool.checkedout.return_value = 5
    pool.overflow.return_value = 0

    engine = MagicMock()
    engine.pool = pool
    return engine
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_settings():
    """Build mock settings carrying the pool sizing configuration."""
    settings = MagicMock()
    # 20 base connections + 10 overflow => total capacity of 30.
    settings.db_pool_size = 20
    settings.db_max_overflow = 10
    return settings
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# POOL STATS TESTS
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPoolMonitorStats:
    """Tests for PoolMonitor.get_stats()."""

    def test_get_stats_returns_pool_stats(self, mock_engine, mock_settings):
        """
        get_stats() mirrors the engine pool's current state in a PoolStats.

        Checked-in/out and overflow come from the pool; sizing comes from
        settings (20 base + 10 overflow = 30 total).
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            stats = PoolMonitor(mock_engine).get_stats()

        assert isinstance(stats, PoolStats)
        assert stats.checkedout == 5
        assert stats.checkedin == 15
        assert stats.overflow == 0
        assert stats.pool_size == 20
        assert stats.max_overflow == 10
        assert stats.total_capacity == 30

    def test_get_stats_calculates_usage_percent(self, mock_engine, mock_settings):
        """
        Usage is checkedout / total_capacity.

        NOTE(review): this asserts a fraction (5/30 ~= 0.167) while the
        health tests treat usage_percent as a percentage — confirm the
        units are intentional on both paths.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            stats = PoolMonitor(mock_engine).get_stats()

        # 5 out of 30 = 16.67%
        assert stats.usage_percent == pytest.approx(5 / 30, rel=0.01)

    def test_get_stats_handles_zero_capacity(self, mock_engine):
        """
        Zero pool_size and max_overflow must not trigger division by zero.
        """
        settings = MagicMock()
        settings.db_pool_size = 0
        settings.db_max_overflow = 0

        with patch("app.monitoring.pool_monitor.get_settings", return_value=settings):
            stats = PoolMonitor(mock_engine).get_stats()

        assert stats.usage_percent == 0
        assert stats.total_capacity == 0

    def test_get_stats_records_history(self, mock_engine, mock_settings):
        """
        Every get_stats() call appends a snapshot to the history buffer.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)

            for _ in range(5):
                monitor.get_stats()

            assert len(monitor.get_history(limit=10)) == 5

    def test_get_stats_limits_history_size(self, mock_engine, mock_settings):
        """
        History is bounded by max_history; oldest entries are discarded.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine, max_history=5)

            # Exceed the cap and confirm the buffer stays at max_history.
            for _ in range(10):
                monitor.get_stats()

            assert len(monitor.get_history(limit=10)) == 5
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# HEALTH STATUS TESTS
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPoolMonitorHealth:
    """Tests for PoolMonitor.get_health_status().

    Health classification thresholds: healthy < 75%, warning 75-90%,
    critical >= 90% of total capacity (pool_size + max_overflow = 30).
    """

    def test_health_status_healthy(self, mock_engine, mock_settings):
        """
        Verify health status is 'healthy' when usage < 75%.

        At 5/30 (16.7%) usage, status should be healthy.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()

        assert health["status"] == "healthy"
        assert health["in_use"] == 5
        assert health["available"] == 15

    def test_health_status_warning(self, mock_engine, mock_settings):
        """
        Verify health status is 'warning' when usage is 75-90%.

        At 24/30 (80%) usage, status should be warning.
        """
        mock_engine.pool.checkedout.return_value = 24
        mock_engine.pool.checkedin.return_value = 6

        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()

        assert health["status"] == "warning"
        # BUGFIX: rel=1 meant a 100% relative tolerance (accepting anything
        # in 0..160), making the assertion vacuous. 24/30 is exactly 80.0,
        # so a 1% tolerance pins the value without flakiness.
        assert health["usage_percent"] == pytest.approx(80, rel=0.01)

    def test_health_status_critical(self, mock_engine, mock_settings):
        """
        Verify health status is 'critical' when usage >= 90%.

        At 28/30 (93%) usage, status should be critical.
        """
        mock_engine.pool.checkedout.return_value = 28
        mock_engine.pool.checkedin.return_value = 2

        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()

        assert health["status"] == "critical"
        assert health["usage_percent"] >= 90

    def test_health_status_includes_timestamp(self, mock_engine, mock_settings):
        """
        Verify health status includes ISO timestamp.

        Timestamp is used for monitoring and alerting correlation.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()

        assert "timestamp" in health
        assert isinstance(health["timestamp"], str)
        assert "T" in health["timestamp"]  # ISO format

    def test_health_status_includes_all_fields(self, mock_engine, mock_settings):
        """
        Verify health status includes all required fields.

        All fields are needed for complete monitoring visibility.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)
            health = monitor.get_health_status()

        required_fields = [
            "status",
            "pool_size",
            "max_overflow",
            "available",
            "in_use",
            "overflow_active",
            "total_capacity",
            "usage_percent",
            "timestamp",
        ]

        for field in required_fields:
            assert field in health, f"Missing field: {field}"
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# HISTORY TESTS
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPoolMonitorHistory:
    """Tests for PoolMonitor.get_history()."""

    def test_get_history_returns_recent_stats(self, mock_engine, mock_settings):
        """
        get_history() yields recent snapshots, each carrying the
        checkedout count, usage_percent, and a timestamp.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)

            # Generate three snapshots.
            for _ in range(3):
                monitor.get_stats()

            snapshots = monitor.get_history(limit=5)

        assert len(snapshots) == 3
        for snapshot in snapshots:
            assert "checkedout" in snapshot
            assert "usage_percent" in snapshot
            assert "timestamp" in snapshot

    def test_get_history_respects_limit(self, mock_engine, mock_settings):
        """
        get_history(limit=n) returns at most n entries even when more
        snapshots have been recorded.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)

            for _ in range(10):
                monitor.get_stats()

            assert len(monitor.get_history(limit=3)) == 3

    def test_get_history_returns_empty_when_no_stats(self, mock_engine, mock_settings):
        """
        A fresh monitor with no get_stats() calls has an empty history.
        """
        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)

            assert monitor.get_history() == []
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# ALERT THRESHOLD TESTS
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPoolMonitorAlerts:
    """Tests for alert threshold behavior."""

    def test_logs_warning_when_threshold_exceeded(self, mock_engine, mock_settings, caplog):
        """
        Verify warning is logged when usage exceeds alert threshold.

        Threshold is set to 80% here. At ~87% usage, warning should be logged.
        """
        mock_engine.pool.checkedout.return_value = 26  # 86.7%
        mock_engine.pool.checkedin.return_value = 4

        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine, alert_threshold=0.8)

            import logging
            with caplog.at_level(logging.WARNING):
                monitor.get_stats()

            assert "Connection pool usage high" in caplog.text

    def test_logs_info_when_overflow_active(self, mock_engine, mock_settings, caplog):
        """
        Verify info is logged when overflow connections are in use.

        Overflow indicates pool is at capacity and using extra connections.
        """
        # 3 overflow connections in use on top of the base pool.
        mock_engine.pool.overflow.return_value = 3

        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine)

            import logging
            with caplog.at_level(logging.INFO):
                monitor.get_stats()

            assert "Pool overflow active" in caplog.text

    def test_no_warning_when_below_threshold(self, mock_engine, mock_settings, caplog):
        """
        Verify no warning when usage is below threshold.

        At 10/30 (~33%) usage with an 80% threshold, no warning should
        be logged.  (Docstring previously said 50%, which did not match
        the 10/30 values set below.)
        """
        mock_engine.pool.checkedout.return_value = 10  # 33%
        mock_engine.pool.checkedin.return_value = 20

        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = PoolMonitor(mock_engine, alert_threshold=0.8)

            import logging
            with caplog.at_level(logging.WARNING):
                monitor.get_stats()

            assert "Connection pool usage high" not in caplog.text
|
||||||
|
|
||||||
|
# ============================================================================
# INIT FUNCTION TESTS
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitPoolMonitor:
    """Tests for the init_pool_monitor() factory function."""

    def test_init_pool_monitor_creates_global_instance(self, mock_engine, mock_settings):
        """
        Verify init_pool_monitor() creates a PoolMonitor instance.

        The global instance is used by health endpoints.
        """
        # NOTE(review): the previous version also imported the module-level
        # `pool_monitor` name but never used it — a `from ... import` binds a
        # stale reference, so any rebinding done inside init_pool_monitor()
        # cannot be observed that way. To assert the global is actually set,
        # check the module attribute (e.g. `pm.pool_monitor is monitor`) —
        # TODO confirm init_pool_monitor() rebinds the module global.
        from app.monitoring.pool_monitor import init_pool_monitor

        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = init_pool_monitor(mock_engine)

        assert monitor is not None
        assert isinstance(monitor, PoolMonitor)

    def test_init_pool_monitor_returns_monitor(self, mock_engine, mock_settings):
        """
        Verify init_pool_monitor() returns the created monitor.

        The return value allows the caller to use the monitor directly
        without going through the module-level global.
        """
        from app.monitoring.pool_monitor import init_pool_monitor

        with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings):
            monitor = init_pool_monitor(mock_engine)

            # The returned monitor must be immediately usable.
            stats = monitor.get_stats()
            assert stats is not None
|
||||||
Loading…
Reference in New Issue
Block a user