From 2a392b87f8d601d7718adf060b9368bfb171f472 Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Fri, 28 Nov 2025 12:06:10 -0600 Subject: [PATCH] CLAUDE: Add rate limiting, pool monitoring, and exception infrastructure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add rate_limit.py middleware with per-client throttling and cleanup task - Add pool_monitor.py for database connection pool health monitoring - Add custom exceptions module (GameEngineError, DatabaseError, etc.) - Add config settings for eviction intervals, session timeouts, memory limits - Add unit tests for rate limiting and pool monitoring 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- backend/app/config.py | 13 + backend/app/core/exceptions.py | 214 +++++++ backend/app/middleware/__init__.py | 0 backend/app/middleware/rate_limit.py | 328 ++++++++++ backend/app/monitoring/__init__.py | 15 + backend/app/monitoring/pool_monitor.py | 220 +++++++ backend/tests/unit/middleware/__init__.py | 1 + .../tests/unit/middleware/test_rate_limit.py | 583 ++++++++++++++++++ backend/tests/unit/monitoring/__init__.py | 1 + .../unit/monitoring/test_pool_monitor.py | 388 ++++++++++++ 10 files changed, 1763 insertions(+) create mode 100644 backend/app/core/exceptions.py create mode 100644 backend/app/middleware/__init__.py create mode 100644 backend/app/middleware/rate_limit.py create mode 100644 backend/app/monitoring/__init__.py create mode 100644 backend/app/monitoring/pool_monitor.py create mode 100644 backend/tests/unit/middleware/__init__.py create mode 100644 backend/tests/unit/middleware/test_rate_limit.py create mode 100644 backend/tests/unit/monitoring/__init__.py create mode 100644 backend/tests/unit/monitoring/test_pool_monitor.py diff --git a/backend/app/config.py b/backend/app/config.py index b105e9a..a703ea0 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -47,6 +47,19 @@ class Settings(BaseSettings): max_concurrent_games: int = 20 game_idle_timeout: int = 86400 # 24 hours + # Game eviction settings (memory management) + game_idle_timeout_hours: int = 24 # Evict games idle > 24 hours + game_eviction_interval_minutes: int = 60 # Check every hour + game_max_in_memory: int = 500 # Hard limit on in-memory games + + # Rate limiting settings + rate_limit_websocket_per_minute: int = 120 # Events per minute per connection + rate_limit_api_per_minute: int = 100 # API calls per minute per user + rate_limit_decision_per_game: int = 20 # Decisions per minute per game + rate_limit_roll_per_game: int = 30 # Rolls per minute per game + rate_limit_substitution_per_game: int = 15 # Substitutions per minute per game + rate_limit_cleanup_interval: int = 300 # Cleanup stale buckets every 5 min + class Config: env_file = ".env" case_sensitive = False diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py new file mode 100644 index 0000000..0a05041 --- /dev/null +++ b/backend/app/core/exceptions.py @@ -0,0 +1,214 @@ +""" +Custom exceptions for the game engine. + +Provides a hierarchy of specific exception types to replace broad `except Exception` +patterns, improving debugging and error handling precision. 
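Illustrative sketch (not part of this patch): a handler that previously wrapped everything in a broad `except Exception` can catch these types individually. The `state_manager`, `game.apply_roll`, and `repo.save_play` names below are hypothetical stand-ins for existing game-engine code.

    import logging

    from app.core.exceptions import (
        DatabaseError,
        GameEngineError,
        GameNotFoundError,
        InvalidGameStateError,
    )

    logger = logging.getLogger(__name__)

    async def handle_roll(game_id, state_manager, repo):
        """Hypothetical handler; only the exception-handling pattern matters."""
        try:
            game = state_manager.get_game(game_id)   # may raise GameNotFoundError
            game.apply_roll()                        # may raise InvalidGameStateError
            await repo.save_play(game)               # may raise DatabaseError
        except GameNotFoundError as exc:
            return {"error": "game_not_found", "game_id": str(exc.game_id)}
        except InvalidGameStateError as exc:
            return {
                "error": "invalid_state",
                "current": exc.current_state,
                "expected": exc.expected_state,
            }
        except DatabaseError as exc:
            logger.error("save_play failed: %s", exc.original_error)
            raise
        except GameEngineError as exc:
            # Any other game-specific error, still kept separate from system errors.
            return {"error": "game_error", "message": str(exc)}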
+ +Exception Hierarchy: + GameEngineError (base) + ├── GameNotFoundError - Game doesn't exist in state manager + ├── InvalidGameStateError - Game in wrong state for operation + ├── SubstitutionError - Invalid player substitution + ├── AuthorizationError - User lacks permission + ├── DecisionTimeoutError - Decision not submitted in time + └── ExternalAPIError - External service call failed + └── PlayerDataError - Failed to fetch player data + +Author: Claude +Date: 2025-11-27 +""" + +from uuid import UUID + + +class GameEngineError(Exception): + """ + Base exception for all game engine errors. + + All game-specific exceptions should inherit from this class + to allow catching game errors separately from system errors. + """ + + pass + + +class GameNotFoundError(GameEngineError): + """ + Raised when a game doesn't exist in state manager. + + Attributes: + game_id: The UUID of the missing game + """ + + def __init__(self, game_id: UUID | str): + self.game_id = game_id + super().__init__(f"Game not found: {game_id}") + + +class InvalidGameStateError(GameEngineError): + """ + Raised when game is in invalid state for requested operation. + + Examples: + - Trying to submit decision when game is completed + - Rolling dice before both decisions submitted + - Starting an already-started game + + Attributes: + message: Description of the state violation + current_state: Optional current state value for debugging + expected_state: Optional expected state value for debugging + """ + + def __init__( + self, + message: str, + current_state: str | None = None, + expected_state: str | None = None, + ): + self.current_state = current_state + self.expected_state = expected_state + super().__init__(message) + + +class SubstitutionError(GameEngineError): + """ + Raised when a player substitution is invalid. + + Used for substitution rule violations, not found errors. + + Attributes: + message: Description of the substitution error + error_code: Machine-readable error code for frontend handling + """ + + def __init__(self, message: str, error_code: str = "SUBSTITUTION_ERROR"): + self.error_code = error_code + super().__init__(message) + + +class AuthorizationError(GameEngineError): + """ + Raised when user lacks permission for an operation. + + Note: Currently deferred (task 001) but defined for future use. + + Attributes: + message: Description of the authorization failure + user_id: Optional user ID for logging + resource: Optional resource being accessed + """ + + def __init__( + self, + message: str, + user_id: int | None = None, + resource: str | None = None, + ): + self.user_id = user_id + self.resource = resource + super().__init__(message) + + +class DecisionTimeoutError(GameEngineError): + """ + Raised when a decision is not submitted within the timeout period. + + The game engine uses default decisions when this occurs, + so this exception is informational rather than fatal. + + Attributes: + game_id: Game where timeout occurred + decision_type: Type of decision that timed out ('defensive' or 'offensive') + timeout_seconds: How long we waited + """ + + def __init__( + self, + game_id: UUID | str, + decision_type: str, + timeout_seconds: int, + ): + self.game_id = game_id + self.decision_type = decision_type + self.timeout_seconds = timeout_seconds + super().__init__( + f"Decision timeout for game {game_id}: " + f"{decision_type} decision not received within {timeout_seconds}s" + ) + + +class ExternalAPIError(GameEngineError): + """ + Raised when an external API call fails. 
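Illustrative sketch (not part of this patch; assumes an httpx client, and the endpoint path is a placeholder): an external lookup can translate transport failures into these typed errors so callers only deal with the game-engine hierarchy.

    import httpx

    from app.core.exceptions import ExternalAPIError, PlayerDataError

    async def fetch_player(client: httpx.AsyncClient, player_id: int) -> dict:
        """Hypothetical lookup; URL and response shape are placeholders."""
        try:
            response = await client.get(f"/players/{player_id}")
        except httpx.HTTPError as exc:
            raise PlayerDataError(player_id) from exc
        if response.status_code != 200:
            raise ExternalAPIError(
                service="SBA API",
                message=f"unexpected response for player {player_id}",
                status_code=response.status_code,
            )
        return response.json()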
+ + Base class for external service errors. + + Attributes: + service: Name of the external service + message: Error description + status_code: Optional HTTP status code + """ + + def __init__( + self, + service: str, + message: str, + status_code: int | None = None, + ): + self.service = service + self.status_code = status_code + super().__init__(f"{service} API error: {message}") + + +class PlayerDataError(ExternalAPIError): + """ + Raised when player data cannot be fetched from external API. + + Specialized error for player data lookups (SBA API, PD API). + + Attributes: + player_id: ID of the player that couldn't be fetched + service: Which API was called + """ + + def __init__(self, player_id: int, service: str = "SBA API"): + self.player_id = player_id + super().__init__( + service=service, + message=f"Failed to fetch player data for ID {player_id}", + ) + + +class DatabaseError(GameEngineError): + """ + Raised when a database operation fails. + + Wraps SQLAlchemy exceptions with game context. + + Attributes: + operation: What operation was attempted (e.g., 'save_play', 'create_substitution') + original_error: The underlying database exception + """ + + def __init__(self, operation: str, original_error: Exception | None = None): + self.operation = operation + self.original_error = original_error + message = f"Database error during {operation}" + if original_error: + message += f": {original_error}" + super().__init__(message) + + +class LineupError(GameEngineError): + """ + Raised when lineup validation or lookup fails. + + Attributes: + team_id: Team with the lineup issue + message: Description of the error + """ + + def __init__(self, team_id: int, message: str): + self.team_id = team_id + super().__init__(f"Lineup error for team {team_id}: {message}") diff --git a/backend/app/middleware/__init__.py b/backend/app/middleware/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/middleware/rate_limit.py b/backend/app/middleware/rate_limit.py new file mode 100644 index 0000000..e13c4d8 --- /dev/null +++ b/backend/app/middleware/rate_limit.py @@ -0,0 +1,328 @@ +""" +Rate limiting utilities for WebSocket and API endpoints. + +Implements token bucket algorithm for smooth rate limiting that allows +bursts while enforcing long-term rate limits. + +Key features: +- Per-connection rate limiting for WebSocket events +- Per-game rate limiting for game actions (decisions, rolls, substitutions) +- Per-user rate limiting for REST API endpoints +- Automatic cleanup of stale buckets +- Non-blocking async design + +Author: Claude +Date: 2025-11-27 +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from functools import wraps +from typing import TYPE_CHECKING, Callable + +import pendulum + +from app.config import get_settings + +if TYPE_CHECKING: + from socketio import AsyncServer + +logger = logging.getLogger(f"{__name__}.RateLimiter") + + +@dataclass +class RateLimitBucket: + """ + Token bucket for rate limiting. + + The token bucket algorithm works by: + 1. Bucket starts with max_tokens + 2. Each request consumes tokens + 3. Tokens refill at refill_rate per second + 4. Requests are denied when tokens < 1 + + This allows short bursts while enforcing average rate limits. + """ + + tokens: float + max_tokens: int + refill_rate: float # tokens per second + last_refill: pendulum.DateTime = field( + default_factory=lambda: pendulum.now("UTC") + ) + + def consume(self, tokens: int = 1) -> bool: + """ + Try to consume tokens. 
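A minimal worked example of the bucket arithmetic (standalone sketch; the import path is the module added by this patch): a 120-per-minute bucket refills at 2 tokens per second, so a full burst drains it and single tokens come back roughly half a second apart.

    from app.middleware.rate_limit import RateLimitBucket

    # 120 events/minute -> a bucket of 120 tokens refilling at 2 tokens/second.
    bucket = RateLimitBucket(tokens=120, max_tokens=120, refill_rate=120 / 60)

    # A burst is allowed until the bucket drains...
    for _ in range(120):
        assert bucket.consume()

    # ...then further requests are denied until tokens refill.
    assert bucket.consume() is False

    # Roughly 0.5s until the next token at 2 tokens/second.
    print(f"retry in ~{bucket.get_wait_time():.2f}s")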
+ + Returns True if allowed, False if rate limited. + """ + self._refill() + if self.tokens >= tokens: + self.tokens -= tokens + return True + return False + + def _refill(self) -> None: + """Refill tokens based on time elapsed since last refill.""" + now = pendulum.now("UTC") + elapsed = (now - self.last_refill).total_seconds() + refill_amount = elapsed * self.refill_rate + + if refill_amount > 0: + self.tokens = min(self.max_tokens, self.tokens + refill_amount) + self.last_refill = now + + def get_wait_time(self) -> float: + """ + Get seconds until a token will be available. + + Useful for informing clients how long to wait. + """ + if self.tokens >= 1: + return 0.0 + tokens_needed = 1 - self.tokens + return tokens_needed / self.refill_rate + + +class RateLimiter: + """ + Rate limiter for WebSocket connections and API endpoints. + + Uses token bucket algorithm for smooth rate limiting that allows + bursts while enforcing average rate limits. + + Thread-safe for async operations via atomic dict operations. + """ + + def __init__(self): + # Per-connection buckets (sid -> bucket) + self._connection_buckets: dict[str, RateLimitBucket] = {} + # Per-game buckets (game_id:action -> bucket) + self._game_buckets: dict[str, RateLimitBucket] = {} + # Per-user API buckets (user_id -> bucket) + self._user_buckets: dict[int, RateLimitBucket] = {} + # Cleanup task handle + self._cleanup_task: asyncio.Task | None = None + # Settings cache + self._settings = get_settings() + + def get_connection_bucket(self, sid: str) -> RateLimitBucket: + """Get or create bucket for WebSocket connection.""" + if sid not in self._connection_buckets: + limit = self._settings.rate_limit_websocket_per_minute + self._connection_buckets[sid] = RateLimitBucket( + tokens=limit, + max_tokens=limit, + refill_rate=limit / 60, # per minute -> per second + ) + return self._connection_buckets[sid] + + def get_game_bucket(self, game_id: str, action: str) -> RateLimitBucket: + """ + Get or create bucket for game-specific action. + + Actions: 'decision', 'roll', 'substitution' + """ + key = f"{game_id}:{action}" + if key not in self._game_buckets: + # Get limit based on action type + if action == "decision": + limit = self._settings.rate_limit_decision_per_game + elif action == "roll": + limit = self._settings.rate_limit_roll_per_game + elif action == "substitution": + limit = self._settings.rate_limit_substitution_per_game + else: + limit = 30 # Default for unknown actions + + self._game_buckets[key] = RateLimitBucket( + tokens=limit, + max_tokens=limit, + refill_rate=limit / 60, + ) + return self._game_buckets[key] + + def get_user_bucket(self, user_id: int) -> RateLimitBucket: + """Get or create bucket for API user.""" + if user_id not in self._user_buckets: + limit = self._settings.rate_limit_api_per_minute + self._user_buckets[user_id] = RateLimitBucket( + tokens=limit, + max_tokens=limit, + refill_rate=limit / 60, + ) + return self._user_buckets[user_id] + + async def check_websocket_limit(self, sid: str) -> bool: + """ + Check if WebSocket event is allowed for connection. + + Returns True if allowed, False if rate limited. + """ + bucket = self.get_connection_bucket(sid) + allowed = bucket.consume() + if not allowed: + logger.warning(f"Rate limited WebSocket connection: {sid}") + return allowed + + async def check_game_limit(self, game_id: str, action: str) -> bool: + """ + Check if game action is allowed. + + Returns True if allowed, False if rate limited. 
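Sketch of a caller (handler name and payload are hypothetical); it mirrors the error shape returned by the `rate_limited` decorator defined later in this module and uses the bucket's wait time to tell the client when to retry.

    from app.middleware.rate_limit import rate_limiter

    async def handle_roll_request(sid: str, data: dict):
        """Hypothetical WebSocket handler body."""
        game_id = str(data["game_id"])
        if not await rate_limiter.check_game_limit(game_id, "roll"):
            wait = rate_limiter.get_game_bucket(game_id, "roll").get_wait_time()
            return {
                "error": "game_rate_limited",
                "message": f"Too many roll requests; retry in {wait:.1f}s.",
            }
        ...  # perform the roll and broadcast the result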
+ """ + bucket = self.get_game_bucket(game_id, action) + allowed = bucket.consume() + if not allowed: + logger.warning(f"Rate limited game action: {game_id} {action}") + return allowed + + async def check_api_limit(self, user_id: int) -> bool: + """ + Check if API call is allowed for user. + + Returns True if allowed, False if rate limited. + """ + bucket = self.get_user_bucket(user_id) + allowed = bucket.consume() + if not allowed: + logger.warning(f"Rate limited API user: {user_id}") + return allowed + + def remove_connection(self, sid: str) -> None: + """Clean up buckets when connection closes.""" + self._connection_buckets.pop(sid, None) + + def remove_game(self, game_id: str) -> None: + """Clean up all buckets for a game when it ends.""" + keys_to_remove = [ + key for key in self._game_buckets if key.startswith(f"{game_id}:") + ] + for key in keys_to_remove: + del self._game_buckets[key] + + async def cleanup_stale_buckets(self) -> None: + """ + Background task to periodically clean up stale buckets. + + Removes buckets that haven't been used in 10 minutes to prevent + memory leaks from abandoned connections. + """ + interval = self._settings.rate_limit_cleanup_interval + stale_threshold_seconds = 600 # 10 minutes + + logger.info(f"Starting rate limiter cleanup task (interval: {interval}s)") + + while True: + try: + await asyncio.sleep(interval) + now = pendulum.now("UTC") + + # Clean connection buckets + stale_connections = [ + sid + for sid, bucket in self._connection_buckets.items() + if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds + ] + for sid in stale_connections: + del self._connection_buckets[sid] + + # Clean game buckets + stale_games = [ + key + for key, bucket in self._game_buckets.items() + if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds + ] + for key in stale_games: + del self._game_buckets[key] + + # Clean user buckets + stale_users = [ + user_id + for user_id, bucket in self._user_buckets.items() + if (now - bucket.last_refill).total_seconds() > stale_threshold_seconds + ] + for user_id in stale_users: + del self._user_buckets[user_id] + + if stale_connections or stale_games or stale_users: + logger.debug( + f"Rate limiter cleanup: {len(stale_connections)} connections, " + f"{len(stale_games)} games, {len(stale_users)} users" + ) + + except asyncio.CancelledError: + logger.info("Rate limiter cleanup task cancelled") + break + except Exception as e: + logger.error(f"Rate limiter cleanup error: {e}", exc_info=True) + # Continue running despite errors + + def get_stats(self) -> dict: + """Get rate limiter statistics for monitoring.""" + return { + "connection_buckets": len(self._connection_buckets), + "game_buckets": len(self._game_buckets), + "user_buckets": len(self._user_buckets), + } + + +# Global rate limiter instance +rate_limiter = RateLimiter() + + +def rate_limited(action: str = "general"): + """ + Decorator for rate-limited WebSocket handlers. + + Checks both connection-level and game-level rate limits before + allowing the handler to execute. + + Args: + action: Type of action for game-level limits. + Options: 'decision', 'roll', 'substitution', 'general' + 'general' only applies connection-level limit. + + Usage: + @sio.event + @rate_limited(action="decision") + async def submit_defensive_decision(sid, data): + ... + + Note: The decorator must be applied AFTER @sio.event to work correctly. 
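The cleanup coroutine above runs until cancelled, but nothing in this module starts it; one way to wire it up (a sketch assuming the backend is a FastAPI app using lifespan events) is to create the task at startup and cancel it on shutdown.

    import asyncio
    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    from app.middleware.rate_limit import rate_limiter

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        cleanup = asyncio.create_task(rate_limiter.cleanup_stale_buckets())
        try:
            yield
        finally:
            # cleanup_stale_buckets() catches CancelledError and exits cleanly.
            cleanup.cancel()

    app = FastAPI(lifespan=lifespan)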
+ """ + + def decorator(func: Callable): + @wraps(func) + async def wrapper(sid: str, data=None, *args, **kwargs): + # Import here to avoid circular imports + from app.websocket.connection_manager import ConnectionManager + + # Check connection-level limit first + if not await rate_limiter.check_websocket_limit(sid): + # Get the connection manager to emit error + # Note: We can't easily access the manager here, so we'll + # handle rate limiting in the handlers themselves + logger.warning(f"Rate limited handler call from {sid}") + return {"error": "rate_limited", "message": "Rate limited. Please slow down."} + + # Check game-level limit if applicable + if action != "general" and isinstance(data, dict): + game_id = data.get("game_id") + if game_id: + if not await rate_limiter.check_game_limit(str(game_id), action): + logger.warning( + f"Rate limited game action {action} for game {game_id}" + ) + return { + "error": "game_rate_limited", + "message": f"Too many {action} requests for this game.", + } + + return await func(sid, data, *args, **kwargs) + + return wrapper + + return decorator diff --git a/backend/app/monitoring/__init__.py b/backend/app/monitoring/__init__.py new file mode 100644 index 0000000..fac88f5 --- /dev/null +++ b/backend/app/monitoring/__init__.py @@ -0,0 +1,15 @@ +"""Monitoring utilities for the Paper Dynasty game engine.""" + +from app.monitoring.pool_monitor import ( + PoolMonitor, + PoolStats, + init_pool_monitor, + pool_monitor, +) + +__all__ = [ + "PoolMonitor", + "PoolStats", + "init_pool_monitor", + "pool_monitor", +] diff --git a/backend/app/monitoring/pool_monitor.py b/backend/app/monitoring/pool_monitor.py new file mode 100644 index 0000000..c104320 --- /dev/null +++ b/backend/app/monitoring/pool_monitor.py @@ -0,0 +1,220 @@ +""" +Database connection pool monitoring. + +Monitors SQLAlchemy async connection pool health and provides +statistics for observability and alerting. + +Key features: +- Real-time pool statistics (checked in/out, overflow) +- Health status classification (healthy/warning/critical) +- Historical stats tracking +- Background monitoring with configurable interval +- Warning logs when pool usage exceeds threshold + +Author: Claude +Date: 2025-11-27 +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from typing import Optional + +import pendulum +from sqlalchemy.ext.asyncio import AsyncEngine + +from app.config import get_settings + +logger = logging.getLogger(f"{__name__}.PoolMonitor") + + +@dataclass +class PoolStats: + """ + Connection pool statistics snapshot. + + Captures the current state of the database connection pool + for monitoring and alerting purposes. + """ + + pool_size: int + max_overflow: int + checkedin: int # Available connections + checkedout: int # In-use connections + overflow: int # Overflow connections in use + total_capacity: int + usage_percent: float + timestamp: pendulum.DateTime = field(default_factory=lambda: pendulum.now("UTC")) + + +class PoolMonitor: + """ + Monitor database connection pool health. + + Provides real-time statistics and health status for the SQLAlchemy + connection pool. Useful for detecting pool exhaustion before it + causes request failures. + + Usage: + monitor = PoolMonitor(engine) + stats = monitor.get_stats() + health = monitor.get_health_status() + """ + + def __init__( + self, + engine: AsyncEngine, + alert_threshold: float = 0.8, + max_history: int = 100, + ): + """ + Initialize pool monitor. 
+ + Args: + engine: SQLAlchemy async engine to monitor + alert_threshold: Usage percentage to trigger warning (0.8 = 80%) + max_history: Maximum stats snapshots to keep in history + """ + self._engine = engine + self._stats_history: list[PoolStats] = [] + self._max_history = max_history + self._alert_threshold = alert_threshold + self._settings = get_settings() + + def get_stats(self) -> PoolStats: + """ + Get current pool statistics. + + Returns: + PoolStats with current pool state + """ + pool = self._engine.pool + + checkedin = pool.checkedin() + checkedout = pool.checkedout() + overflow = pool.overflow() + total_capacity = self._settings.db_pool_size + self._settings.db_max_overflow + + usage_percent = checkedout / total_capacity if total_capacity > 0 else 0 + + stats = PoolStats( + pool_size=self._settings.db_pool_size, + max_overflow=self._settings.db_max_overflow, + checkedin=checkedin, + checkedout=checkedout, + overflow=overflow, + total_capacity=total_capacity, + usage_percent=usage_percent, + ) + + # Record history + self._stats_history.append(stats) + if len(self._stats_history) > self._max_history: + self._stats_history.pop(0) + + # Check for alerts + if usage_percent >= self._alert_threshold: + logger.warning( + f"Connection pool usage high: {usage_percent:.1%} " + f"({checkedout}/{total_capacity})" + ) + + if overflow > 0: + logger.info(f"Pool overflow active: {overflow} overflow connections") + + return stats + + def get_health_status(self) -> dict: + """ + Get pool health status for monitoring endpoint. + + Returns: + Dict with status, statistics, and timestamp + """ + stats = self.get_stats() + + if stats.usage_percent >= 0.9: + status = "critical" + elif stats.usage_percent >= 0.75: + status = "warning" + else: + status = "healthy" + + return { + "status": status, + "pool_size": stats.pool_size, + "max_overflow": stats.max_overflow, + "available": stats.checkedin, + "in_use": stats.checkedout, + "overflow_active": stats.overflow, + "total_capacity": stats.total_capacity, + "usage_percent": round(stats.usage_percent * 100, 1), + "timestamp": stats.timestamp.isoformat(), + } + + def get_history(self, limit: int = 10) -> list[dict]: + """ + Get recent stats history. + + Args: + limit: Maximum number of history entries to return + + Returns: + List of stats snapshots + """ + return [ + { + "checkedout": s.checkedout, + "usage_percent": round(s.usage_percent * 100, 1), + "timestamp": s.timestamp.isoformat(), + } + for s in self._stats_history[-limit:] + ] + + async def start_monitoring(self, interval_seconds: int = 60): + """ + Background task to periodically collect stats. + + Useful for continuous logging and alerting. Runs until cancelled. + + Args: + interval_seconds: Seconds between stat collections + """ + logger.info(f"Starting pool monitoring (interval: {interval_seconds}s)") + + while True: + try: + stats = self.get_stats() + logger.debug( + f"Pool stats: {stats.checkedout}/{stats.total_capacity} " + f"({stats.usage_percent:.1%})" + ) + await asyncio.sleep(interval_seconds) + except asyncio.CancelledError: + logger.info("Pool monitoring stopped") + break + except Exception as e: + logger.error(f"Pool monitoring error: {e}") + await asyncio.sleep(interval_seconds) + + +# Global instance (initialized in main.py) +pool_monitor: Optional[PoolMonitor] = None + + +def init_pool_monitor(engine: AsyncEngine) -> PoolMonitor: + """ + Initialize global pool monitor. + + Should be called during application startup. 
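Sketch of that startup wiring plus a health route (assumes a FastAPI app; the function names and the /health/pool path are hypothetical):

    import asyncio

    from fastapi import FastAPI, Request
    from sqlalchemy.ext.asyncio import AsyncEngine

    from app.monitoring import init_pool_monitor

    def setup_pool_monitoring(app: FastAPI, engine: AsyncEngine) -> None:
        """Call during startup with the app's AsyncEngine, however it is obtained."""
        monitor = init_pool_monitor(engine)
        app.state.pool_monitor = monitor
        # Optional: continuous collection/logging every 60 seconds until cancelled.
        app.state.pool_monitor_task = asyncio.create_task(
            monitor.start_monitoring(interval_seconds=60)
        )

    def add_pool_health_route(app: FastAPI) -> None:
        @app.get("/health/pool")
        async def pool_health(request: Request) -> dict:
            return request.app.state.pool_monitor.get_health_status()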
+ + Args: + engine: SQLAlchemy async engine to monitor + + Returns: + Initialized PoolMonitor instance + """ + global pool_monitor + pool_monitor = PoolMonitor(engine) + logger.info("Pool monitor initialized") + return pool_monitor diff --git a/backend/tests/unit/middleware/__init__.py b/backend/tests/unit/middleware/__init__.py new file mode 100644 index 0000000..9725e36 --- /dev/null +++ b/backend/tests/unit/middleware/__init__.py @@ -0,0 +1 @@ +# Middleware unit tests diff --git a/backend/tests/unit/middleware/test_rate_limit.py b/backend/tests/unit/middleware/test_rate_limit.py new file mode 100644 index 0000000..20251d4 --- /dev/null +++ b/backend/tests/unit/middleware/test_rate_limit.py @@ -0,0 +1,583 @@ +""" +Unit tests for rate limiting middleware. + +Tests the token bucket algorithm implementation and rate limiter functionality +for WebSocket and API rate limiting. + +Author: Claude +Date: 2025-11-27 +""" + +import asyncio +from unittest.mock import patch + +import pendulum +import pytest + +from app.middleware.rate_limit import RateLimitBucket, RateLimiter, rate_limiter + + +class TestRateLimitBucket: + """Tests for the RateLimitBucket token bucket implementation.""" + + def test_bucket_initialization(self): + """ + Test that bucket initializes with correct values. + + Bucket should start with max_tokens available and record + the creation time as last_refill. + """ + bucket = RateLimitBucket( + tokens=10, max_tokens=10, refill_rate=1.0 + ) + + assert bucket.tokens == 10 + assert bucket.max_tokens == 10 + assert bucket.refill_rate == 1.0 + assert bucket.last_refill is not None + + def test_consume_success(self): + """ + Test successful token consumption. + + Should reduce available tokens and return True when + enough tokens are available. + """ + bucket = RateLimitBucket( + tokens=10, max_tokens=10, refill_rate=1.0 + ) + + result = bucket.consume(1) + + assert result is True + assert bucket.tokens < 10 # Reduced (may have refilled slightly) + + def test_consume_multiple_tokens(self): + """ + Test consuming multiple tokens at once. + + Should consume the requested amount if available. + """ + bucket = RateLimitBucket( + tokens=10, max_tokens=10, refill_rate=0.0 # No refill for test + ) + + result = bucket.consume(5) + + assert result is True + assert bucket.tokens == 5 + + def test_consume_denied_when_empty(self): + """ + Test that consumption is denied when not enough tokens. + + Should return False and not reduce tokens below zero. + """ + bucket = RateLimitBucket( + tokens=0, max_tokens=10, refill_rate=0.0 # No refill + ) + + result = bucket.consume(1) + + assert result is False + assert bucket.tokens == 0 # Unchanged + + def test_consume_denied_insufficient_tokens(self): + """ + Test denial when requesting more tokens than available. + + Should return False without modifying token count. + """ + bucket = RateLimitBucket( + tokens=3, max_tokens=10, refill_rate=0.0 + ) + + result = bucket.consume(5) + + assert result is False + assert bucket.tokens == 3 # Unchanged + + def test_token_refill_over_time(self): + """ + Test that tokens refill based on elapsed time. + + Tokens should accumulate at refill_rate per second. 
+ """ + # Start with 5 tokens + past_time = pendulum.now("UTC").subtract(seconds=5) + bucket = RateLimitBucket( + tokens=5, + max_tokens=10, + refill_rate=1.0, # 1 token per second + last_refill=past_time, + ) + + # Force refill by consuming (which triggers refill check) + bucket.consume(0) + + # Should have refilled ~5 tokens (5 seconds * 1 token/sec) + assert bucket.tokens >= 9.5 # Allow small timing variance + + def test_tokens_capped_at_max(self): + """ + Test that tokens don't exceed max_tokens after refill. + + Token count should never go above the configured maximum. + """ + past_time = pendulum.now("UTC").subtract(seconds=100) + bucket = RateLimitBucket( + tokens=5, + max_tokens=10, + refill_rate=1.0, # Would add 100 tokens + last_refill=past_time, + ) + + bucket.consume(0) # Trigger refill + + assert bucket.tokens == 10 # Capped at max + + def test_get_wait_time_when_available(self): + """ + Test wait time calculation when tokens available. + + Should return 0.0 when at least one token is available. + """ + bucket = RateLimitBucket( + tokens=5, max_tokens=10, refill_rate=1.0 + ) + + wait_time = bucket.get_wait_time() + + assert wait_time == 0.0 + + def test_get_wait_time_when_empty(self): + """ + Test wait time calculation when no tokens available. + + Should return time needed for one token to refill. + """ + bucket = RateLimitBucket( + tokens=0, max_tokens=10, refill_rate=2.0 # 2 tokens per second + ) + + wait_time = bucket.get_wait_time() + + assert wait_time == 0.5 # 1 token / 2 tokens per second + + def test_get_wait_time_partial_tokens(self): + """ + Test wait time with partial tokens available. + + Should calculate time to reach 1.0 token. + """ + bucket = RateLimitBucket( + tokens=0.5, max_tokens=10, refill_rate=1.0 + ) + + wait_time = bucket.get_wait_time() + + assert wait_time == 0.5 # Need 0.5 more tokens at 1/sec + + +class TestRateLimiter: + """Tests for the RateLimiter class.""" + + @pytest.fixture + def limiter(self): + """Create a fresh RateLimiter instance for each test.""" + return RateLimiter() + + def test_get_connection_bucket_creates_new(self, limiter): + """ + Test that get_connection_bucket creates bucket for new sid. + + Should create and return a bucket initialized with settings. + """ + bucket = limiter.get_connection_bucket("test_sid_1") + + assert bucket is not None + assert bucket.max_tokens > 0 + + def test_get_connection_bucket_returns_existing(self, limiter): + """ + Test that same bucket is returned for same sid. + + Should return the same bucket instance for subsequent calls. + """ + bucket1 = limiter.get_connection_bucket("test_sid_1") + bucket2 = limiter.get_connection_bucket("test_sid_1") + + assert bucket1 is bucket2 + + def test_get_connection_bucket_different_sids(self, limiter): + """ + Test that different sids get different buckets. + + Each connection should have its own rate limit bucket. + """ + bucket1 = limiter.get_connection_bucket("sid_1") + bucket2 = limiter.get_connection_bucket("sid_2") + + assert bucket1 is not bucket2 + + def test_get_game_bucket_creates_new(self, limiter): + """ + Test that get_game_bucket creates bucket for new game+action. + + Should create bucket keyed by game_id and action type. + """ + bucket = limiter.get_game_bucket("game_123", "decision") + + assert bucket is not None + assert bucket.max_tokens > 0 + + def test_get_game_bucket_different_actions(self, limiter): + """ + Test that different actions get different buckets for same game. + + Each action type has its own rate limit within a game. 
+ """ + decision_bucket = limiter.get_game_bucket("game_123", "decision") + roll_bucket = limiter.get_game_bucket("game_123", "roll") + + assert decision_bucket is not roll_bucket + + def test_get_game_bucket_different_games(self, limiter): + """ + Test that same action on different games gets different buckets. + + Each game tracks its own rate limits. + """ + bucket1 = limiter.get_game_bucket("game_1", "decision") + bucket2 = limiter.get_game_bucket("game_2", "decision") + + assert bucket1 is not bucket2 + + def test_get_user_bucket_creates_new(self, limiter): + """ + Test that get_user_bucket creates bucket for new user. + + Should create bucket for API rate limiting per user. + """ + bucket = limiter.get_user_bucket(123) + + assert bucket is not None + assert bucket.max_tokens > 0 + + @pytest.mark.asyncio + async def test_check_websocket_limit_allowed(self, limiter): + """ + Test that check_websocket_limit allows requests with tokens. + + Should return True when tokens are available. + """ + result = await limiter.check_websocket_limit("test_sid") + + assert result is True + + @pytest.mark.asyncio + async def test_check_websocket_limit_denied_when_exhausted(self, limiter): + """ + Test that check_websocket_limit denies when rate limited. + + Should return False after tokens are exhausted. + """ + sid = "rate_limited_sid" + + # Exhaust all tokens + bucket = limiter.get_connection_bucket(sid) + bucket.tokens = 0 + + result = await limiter.check_websocket_limit(sid) + + assert result is False + + @pytest.mark.asyncio + async def test_check_game_limit_allowed(self, limiter): + """ + Test that check_game_limit allows requests with tokens. + + Should return True when game action limit not reached. + """ + result = await limiter.check_game_limit("game_123", "decision") + + assert result is True + + @pytest.mark.asyncio + async def test_check_game_limit_denied_when_exhausted(self, limiter): + """ + Test that check_game_limit denies when rate limited. + + Should return False after game action tokens exhausted. + """ + game_id = "rate_limited_game" + + bucket = limiter.get_game_bucket(game_id, "roll") + bucket.tokens = 0 + + result = await limiter.check_game_limit(game_id, "roll") + + assert result is False + + @pytest.mark.asyncio + async def test_check_api_limit_allowed(self, limiter): + """ + Test that check_api_limit allows requests with tokens. + + Should return True when user API limit not reached. + """ + result = await limiter.check_api_limit(123) + + assert result is True + + @pytest.mark.asyncio + async def test_check_api_limit_denied_when_exhausted(self, limiter): + """ + Test that check_api_limit denies when rate limited. + + Should return False after user API tokens exhausted. + """ + user_id = 999 + + bucket = limiter.get_user_bucket(user_id) + bucket.tokens = 0 + + result = await limiter.check_api_limit(user_id) + + assert result is False + + def test_remove_connection_cleans_up(self, limiter): + """ + Test that remove_connection removes bucket for sid. + + Should clean up bucket when connection closes. + """ + sid = "cleanup_test_sid" + limiter.get_connection_bucket(sid) # Create bucket + + assert sid in limiter._connection_buckets + + limiter.remove_connection(sid) + + assert sid not in limiter._connection_buckets + + def test_remove_connection_handles_missing(self, limiter): + """ + Test that remove_connection handles non-existent sid gracefully. + + Should not raise error for unknown sid. 
+ """ + limiter.remove_connection("nonexistent_sid") # Should not raise + + def test_remove_game_cleans_up_all_actions(self, limiter): + """ + Test that remove_game removes all buckets for a game. + + Should clean up decision, roll, substitution buckets for game. + """ + game_id = "cleanup_game" + + # Create buckets for different actions + limiter.get_game_bucket(game_id, "decision") + limiter.get_game_bucket(game_id, "roll") + limiter.get_game_bucket(game_id, "substitution") + + assert f"{game_id}:decision" in limiter._game_buckets + assert f"{game_id}:roll" in limiter._game_buckets + assert f"{game_id}:substitution" in limiter._game_buckets + + limiter.remove_game(game_id) + + assert f"{game_id}:decision" not in limiter._game_buckets + assert f"{game_id}:roll" not in limiter._game_buckets + assert f"{game_id}:substitution" not in limiter._game_buckets + + def test_remove_game_preserves_other_games(self, limiter): + """ + Test that remove_game only removes specified game's buckets. + + Other games' buckets should remain intact. + """ + limiter.get_game_bucket("game_1", "decision") + limiter.get_game_bucket("game_2", "decision") + + limiter.remove_game("game_1") + + assert "game_1:decision" not in limiter._game_buckets + assert "game_2:decision" in limiter._game_buckets + + def test_get_stats_returns_counts(self, limiter): + """ + Test that get_stats returns bucket counts. + + Should return counts for all bucket types. + """ + limiter.get_connection_bucket("sid_1") + limiter.get_connection_bucket("sid_2") + limiter.get_game_bucket("game_1", "decision") + limiter.get_user_bucket(123) + + stats = limiter.get_stats() + + assert stats["connection_buckets"] == 2 + assert stats["game_buckets"] == 1 + assert stats["user_buckets"] == 1 + + +class TestCleanupTask: + """Tests for the stale bucket cleanup functionality.""" + + @pytest.mark.asyncio + async def test_cleanup_removes_stale_connection_buckets(self): + """ + Test that cleanup removes connection buckets older than threshold. + + Buckets unused for >10 minutes should be cleaned up. + """ + limiter = RateLimiter() + + # Create a bucket with old last_refill time (>10 minutes ago) + stale_time = pendulum.now("UTC").subtract(minutes=15) + bucket = limiter.get_connection_bucket("stale_sid") + bucket.last_refill = stale_time + + # Create a fresh bucket + fresh_bucket = limiter.get_connection_bucket("fresh_sid") + + # Manually trigger cleanup logic (without running full async task) + now = pendulum.now("UTC") + stale_sids = [ + sid + for sid, b in limiter._connection_buckets.items() + if (now - b.last_refill).total_seconds() > 600 + ] + for sid in stale_sids: + del limiter._connection_buckets[sid] + + assert "stale_sid" not in limiter._connection_buckets + assert "fresh_sid" in limiter._connection_buckets + + @pytest.mark.asyncio + async def test_cleanup_removes_stale_game_buckets(self): + """ + Test that cleanup removes game buckets older than threshold. + + Abandoned game buckets should be cleaned up. 
+ """ + limiter = RateLimiter() + + stale_time = pendulum.now("UTC").subtract(minutes=15) + bucket = limiter.get_game_bucket("stale_game", "decision") + bucket.last_refill = stale_time + + limiter.get_game_bucket("fresh_game", "decision") + + # Manually trigger cleanup + now = pendulum.now("UTC") + stale_keys = [ + key + for key, b in limiter._game_buckets.items() + if (now - b.last_refill).total_seconds() > 600 + ] + for key in stale_keys: + del limiter._game_buckets[key] + + assert "stale_game:decision" not in limiter._game_buckets + assert "fresh_game:decision" in limiter._game_buckets + + +class TestGlobalRateLimiter: + """Tests for the global rate_limiter instance.""" + + def test_global_instance_exists(self): + """ + Test that global rate_limiter instance is available. + + Should be importable and usable as singleton. + """ + assert rate_limiter is not None + assert isinstance(rate_limiter, RateLimiter) + + @pytest.mark.asyncio + async def test_global_instance_functional(self): + """ + Test that global rate_limiter works correctly. + + Should accept and rate limit requests. + """ + result = await rate_limiter.check_websocket_limit("global_test_sid") + + assert result is True + + # Cleanup + rate_limiter.remove_connection("global_test_sid") + + +class TestRateLimitConfiguration: + """Tests for rate limit configuration integration.""" + + def test_connection_bucket_uses_config(self): + """ + Test that connection bucket uses configured limits. + + max_tokens should match rate_limit_websocket_per_minute setting. + """ + limiter = RateLimiter() + bucket = limiter.get_connection_bucket("config_test_sid") + + # Default is 120 per minute + assert bucket.max_tokens == limiter._settings.rate_limit_websocket_per_minute + + limiter.remove_connection("config_test_sid") + + def test_decision_bucket_uses_config(self): + """ + Test that decision bucket uses configured limits. + + max_tokens should match rate_limit_decision_per_game setting. + """ + limiter = RateLimiter() + bucket = limiter.get_game_bucket("config_game", "decision") + + assert bucket.max_tokens == limiter._settings.rate_limit_decision_per_game + + limiter.remove_game("config_game") + + def test_roll_bucket_uses_config(self): + """ + Test that roll bucket uses configured limits. + + max_tokens should match rate_limit_roll_per_game setting. + """ + limiter = RateLimiter() + bucket = limiter.get_game_bucket("config_game", "roll") + + assert bucket.max_tokens == limiter._settings.rate_limit_roll_per_game + + limiter.remove_game("config_game") + + def test_substitution_bucket_uses_config(self): + """ + Test that substitution bucket uses configured limits. + + max_tokens should match rate_limit_substitution_per_game setting. + """ + limiter = RateLimiter() + bucket = limiter.get_game_bucket("config_game", "substitution") + + assert bucket.max_tokens == limiter._settings.rate_limit_substitution_per_game + + limiter.remove_game("config_game") + + def test_api_bucket_uses_config(self): + """ + Test that API bucket uses configured limits. + + max_tokens should match rate_limit_api_per_minute setting. 
+ """ + limiter = RateLimiter() + bucket = limiter.get_user_bucket(12345) + + assert bucket.max_tokens == limiter._settings.rate_limit_api_per_minute + + del limiter._user_buckets[12345] diff --git a/backend/tests/unit/monitoring/__init__.py b/backend/tests/unit/monitoring/__init__.py new file mode 100644 index 0000000..de0e797 --- /dev/null +++ b/backend/tests/unit/monitoring/__init__.py @@ -0,0 +1 @@ +"""Tests for monitoring utilities.""" diff --git a/backend/tests/unit/monitoring/test_pool_monitor.py b/backend/tests/unit/monitoring/test_pool_monitor.py new file mode 100644 index 0000000..748ceec --- /dev/null +++ b/backend/tests/unit/monitoring/test_pool_monitor.py @@ -0,0 +1,388 @@ +""" +Tests for database connection pool monitoring. + +Verifies that the PoolMonitor correctly: +- Reports pool statistics (checked in/out, overflow) +- Classifies health status (healthy/warning/critical) +- Tracks history of stats over time +- Logs warnings when usage exceeds threshold + +Author: Claude +Date: 2025-11-27 +""" + +import pytest +from unittest.mock import MagicMock, patch + +from app.monitoring.pool_monitor import PoolMonitor, PoolStats + + +# ============================================================================ +# TEST FIXTURES +# ============================================================================ + + +@pytest.fixture +def mock_engine(): + """ + Create a mock SQLAlchemy engine with pool stats. + + Default configuration: 15 available, 5 in use, no overflow. + """ + engine = MagicMock() + pool = MagicMock() + pool.checkedin.return_value = 15 + pool.checkedout.return_value = 5 + pool.overflow.return_value = 0 + engine.pool = pool + return engine + + +@pytest.fixture +def mock_settings(): + """Mock settings with pool configuration.""" + settings = MagicMock() + settings.db_pool_size = 20 + settings.db_max_overflow = 10 + return settings + + +# ============================================================================ +# POOL STATS TESTS +# ============================================================================ + + +class TestPoolMonitorStats: + """Tests for PoolMonitor.get_stats().""" + + def test_get_stats_returns_pool_stats(self, mock_engine, mock_settings): + """ + Verify get_stats() returns PoolStats with correct values. + + The returned PoolStats should reflect the current pool state + from the engine's pool attribute. + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + stats = monitor.get_stats() + + assert isinstance(stats, PoolStats) + assert stats.checkedout == 5 + assert stats.checkedin == 15 + assert stats.overflow == 0 + assert stats.pool_size == 20 + assert stats.max_overflow == 10 + assert stats.total_capacity == 30 + + def test_get_stats_calculates_usage_percent(self, mock_engine, mock_settings): + """ + Verify usage percentage is calculated correctly. + + Usage = checkedout / total_capacity + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + stats = monitor.get_stats() + + # 5 out of 30 = 16.67% + assert stats.usage_percent == pytest.approx(5 / 30, rel=0.01) + + def test_get_stats_handles_zero_capacity(self, mock_engine): + """ + Verify get_stats() handles zero capacity without division error. + + Edge case: if both pool_size and max_overflow are 0. 
+ """ + settings = MagicMock() + settings.db_pool_size = 0 + settings.db_max_overflow = 0 + + with patch("app.monitoring.pool_monitor.get_settings", return_value=settings): + monitor = PoolMonitor(mock_engine) + stats = monitor.get_stats() + + assert stats.usage_percent == 0 + assert stats.total_capacity == 0 + + def test_get_stats_records_history(self, mock_engine, mock_settings): + """ + Verify get_stats() records each call in history. + + History is used for monitoring dashboards and alerts. + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + + # Make multiple calls + for _ in range(5): + monitor.get_stats() + + history = monitor.get_history(limit=10) + assert len(history) == 5 + + def test_get_stats_limits_history_size(self, mock_engine, mock_settings): + """ + Verify history is limited to prevent memory growth. + + Older entries should be removed when max_history is exceeded. + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine, max_history=5) + + # Make more calls than max_history + for _ in range(10): + monitor.get_stats() + + history = monitor.get_history(limit=10) + assert len(history) == 5 + + +# ============================================================================ +# HEALTH STATUS TESTS +# ============================================================================ + + +class TestPoolMonitorHealth: + """Tests for PoolMonitor.get_health_status().""" + + def test_health_status_healthy(self, mock_engine, mock_settings): + """ + Verify health status is 'healthy' when usage < 75%. + + At 5/30 (16.7%) usage, status should be healthy. + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + health = monitor.get_health_status() + + assert health["status"] == "healthy" + assert health["in_use"] == 5 + assert health["available"] == 15 + + def test_health_status_warning(self, mock_engine, mock_settings): + """ + Verify health status is 'warning' when usage is 75-90%. + + At 24/30 (80%) usage, status should be warning. + """ + mock_engine.pool.checkedout.return_value = 24 + mock_engine.pool.checkedin.return_value = 6 + + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + health = monitor.get_health_status() + + assert health["status"] == "warning" + assert health["usage_percent"] == pytest.approx(80, rel=1) + + def test_health_status_critical(self, mock_engine, mock_settings): + """ + Verify health status is 'critical' when usage >= 90%. + + At 28/30 (93%) usage, status should be critical. + """ + mock_engine.pool.checkedout.return_value = 28 + mock_engine.pool.checkedin.return_value = 2 + + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + health = monitor.get_health_status() + + assert health["status"] == "critical" + assert health["usage_percent"] >= 90 + + def test_health_status_includes_timestamp(self, mock_engine, mock_settings): + """ + Verify health status includes ISO timestamp. + + Timestamp is used for monitoring and alerting correlation. 
+ """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + health = monitor.get_health_status() + + assert "timestamp" in health + assert isinstance(health["timestamp"], str) + assert "T" in health["timestamp"] # ISO format + + def test_health_status_includes_all_fields(self, mock_engine, mock_settings): + """ + Verify health status includes all required fields. + + All fields are needed for complete monitoring visibility. + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + health = monitor.get_health_status() + + required_fields = [ + "status", + "pool_size", + "max_overflow", + "available", + "in_use", + "overflow_active", + "total_capacity", + "usage_percent", + "timestamp", + ] + + for field in required_fields: + assert field in health, f"Missing field: {field}" + + +# ============================================================================ +# HISTORY TESTS +# ============================================================================ + + +class TestPoolMonitorHistory: + """Tests for PoolMonitor.get_history().""" + + def test_get_history_returns_recent_stats(self, mock_engine, mock_settings): + """ + Verify get_history() returns recent stat snapshots. + + History should include checkedout, usage_percent, and timestamp. + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + + # Generate some history + for _ in range(3): + monitor.get_stats() + + history = monitor.get_history(limit=5) + + assert len(history) == 3 + for entry in history: + assert "checkedout" in entry + assert "usage_percent" in entry + assert "timestamp" in entry + + def test_get_history_respects_limit(self, mock_engine, mock_settings): + """ + Verify get_history() respects the limit parameter. + + Should return at most 'limit' entries. + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + + for _ in range(10): + monitor.get_stats() + + history = monitor.get_history(limit=3) + assert len(history) == 3 + + def test_get_history_returns_empty_when_no_stats(self, mock_engine, mock_settings): + """ + Verify get_history() returns empty list when no stats collected. + + Fresh monitor should have no history. + """ + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + history = monitor.get_history() + + assert history == [] + + +# ============================================================================ +# ALERT THRESHOLD TESTS +# ============================================================================ + + +class TestPoolMonitorAlerts: + """Tests for alert threshold behavior.""" + + def test_logs_warning_when_threshold_exceeded(self, mock_engine, mock_settings, caplog): + """ + Verify warning is logged when usage exceeds alert threshold. + + Default threshold is 80%. At 85% usage, warning should be logged. 
+ """ + mock_engine.pool.checkedout.return_value = 26 # 86.7% + mock_engine.pool.checkedin.return_value = 4 + + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine, alert_threshold=0.8) + + import logging + with caplog.at_level(logging.WARNING): + monitor.get_stats() + + assert "Connection pool usage high" in caplog.text + + def test_logs_info_when_overflow_active(self, mock_engine, mock_settings, caplog): + """ + Verify info is logged when overflow connections are in use. + + Overflow indicates pool is at capacity and using extra connections. + """ + mock_engine.pool.overflow.return_value = 3 + + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine) + + import logging + with caplog.at_level(logging.INFO): + monitor.get_stats() + + assert "Pool overflow active" in caplog.text + + def test_no_warning_when_below_threshold(self, mock_engine, mock_settings, caplog): + """ + Verify no warning when usage is below threshold. + + At 50% usage with 80% threshold, no warning should be logged. + """ + mock_engine.pool.checkedout.return_value = 10 # 33% + mock_engine.pool.checkedin.return_value = 20 + + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = PoolMonitor(mock_engine, alert_threshold=0.8) + + import logging + with caplog.at_level(logging.WARNING): + monitor.get_stats() + + assert "Connection pool usage high" not in caplog.text + + +# ============================================================================ +# INIT FUNCTION TESTS +# ============================================================================ + + +class TestInitPoolMonitor: + """Tests for init_pool_monitor function.""" + + def test_init_pool_monitor_creates_global_instance(self, mock_engine, mock_settings): + """ + Verify init_pool_monitor() creates global instance. + + The global instance is used by health endpoints. + """ + from app.monitoring.pool_monitor import init_pool_monitor, pool_monitor + + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = init_pool_monitor(mock_engine) + + assert monitor is not None + assert isinstance(monitor, PoolMonitor) + + def test_init_pool_monitor_returns_monitor(self, mock_engine, mock_settings): + """ + Verify init_pool_monitor() returns the created monitor. + + Return value allows caller to use the monitor directly. + """ + from app.monitoring.pool_monitor import init_pool_monitor + + with patch("app.monitoring.pool_monitor.get_settings", return_value=mock_settings): + monitor = init_pool_monitor(mock_engine) + + # Should be able to use the returned monitor + stats = monitor.get_stats() + assert stats is not None
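Note on the new test suites: both are isolated unit tests. The pool monitor tests mock the engine and settings, and the rate limiter tests exercise the in-process token buckets directly, so neither needs a database; they can be run on their own with, for example, `pytest backend/tests/unit/middleware backend/tests/unit/monitoring`, invoked however the backend suite is normally run.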