"""Refresh token storage for Mantimon TCG.

This module provides Redis-based storage for refresh token tracking and
revocation. Each refresh token's JTI is stored in Redis with a TTL matching
the token's expiration.

Key Pattern:
    refresh_token:{user_id}:{jti} -> "1" (exists = valid)

Revocation:
    - Single token: Delete the specific key
    - All user tokens: Delete all keys matching refresh_token:{user_id}:*

Example:
    from app.services.token_store import token_store

    # Store a new refresh token
    await token_store.store_refresh_token(user_id, jti, expires_at)

    # Check if token is valid (not revoked)
    if await token_store.is_token_valid(user_id, jti):
        # Issue new access token

    # Revoke on logout
    await token_store.revoke_token(user_id, jti)

    # Logout from all devices
    await token_store.revoke_all_user_tokens(user_id)
"""

from datetime import UTC, datetime
from uuid import UUID

from app.db.redis import get_redis


class TokenStore:
    """Redis-based refresh token storage for revocation support.

    Tracks valid refresh tokens by storing their JTIs in Redis. Tokens can be
    revoked individually or all at once per user.
    """

    # First segment of every Redis key written by this store.
    KEY_PREFIX = "refresh_token"

    def _make_key(self, user_id: UUID, jti: str) -> str:
        """Create Redis key for a refresh token.

        Args:
            user_id: The user's UUID.
            jti: The token's unique identifier.

        Returns:
            Redis key string of the form ``refresh_token:{user_id}:{jti}``.
        """
        return f"{self.KEY_PREFIX}:{user_id}:{jti}"

    def _make_user_pattern(self, user_id: UUID) -> str:
        """Create Redis key pattern for all user's tokens.

        Args:
            user_id: The user's UUID.

        Returns:
            Redis glob pattern (for SCAN/KEYS) matching every key built by
            ``_make_key`` for this user.
        """
        return f"{self.KEY_PREFIX}:{user_id}:*"

    async def store_refresh_token(
        self,
        user_id: UUID,
        jti: str,
        expires_at: datetime,
    ) -> None:
        """Store a refresh token's JTI in Redis.

        The key is written with a TTL equal to the whole number of seconds
        left until ``expires_at``. Fractional seconds are truncated, so a
        token with less than one second of lifetime left is treated as
        already expired and is not stored.

        Args:
            user_id: The user's UUID.
            jti: The token's unique identifier (from JWT).
            expires_at: When the token expires (for TTL calculation). Must be
                timezone-aware, because it is compared against an aware UTC
                "now".

        Raises:
            TypeError: If ``expires_at`` is a naive datetime (Python refuses
                to subtract naive and aware datetimes).

        Example:
            expires_at = datetime.now(UTC) + timedelta(days=7)
            await token_store.store_refresh_token(user_id, jti, expires_at)
        """
        key = self._make_key(user_id, jti)

        # Calculate TTL in seconds
        now = datetime.now(UTC)
        ttl_seconds = int((expires_at - now).total_seconds())

        if ttl_seconds <= 0:
            # Token already expired, don't store
            return

        async with get_redis() as redis:
            await redis.setex(key, ttl_seconds, "1")

    async def is_token_valid(self, user_id: UUID, jti: str) -> bool:
        """Check if a refresh token is valid (not revoked).

        Args:
            user_id: The user's UUID.
            jti: The token's unique identifier.

        Returns:
            True if token exists in store (valid), False if revoked or
            expired.

        Example:
            if await token_store.is_token_valid(user_id, jti):
                # Token is valid, issue new access token
            else:
                # Token was revoked, require re-authentication
        """
        key = self._make_key(user_id, jti)

        async with get_redis() as redis:
            result = await redis.exists(key)
            return result > 0

    async def revoke_token(self, user_id: UUID, jti: str) -> bool:
        """Revoke a specific refresh token.

        Args:
            user_id: The user's UUID.
            jti: The token's unique identifier.

        Returns:
            True if token was revoked, False if it didn't exist.

        Example:
            # On logout
            await token_store.revoke_token(user_id, jti)
        """
        key = self._make_key(user_id, jti)

        async with get_redis() as redis:
            result = await redis.delete(key)
            return result > 0

    async def revoke_all_user_tokens(self, user_id: UUID) -> int:
        """Revoke all refresh tokens for a user.

        Useful for "logout from all devices" or security incidents.

        Args:
            user_id: The user's UUID.

        Returns:
            Number of tokens revoked, as reported by the Redis delete call.

        Example:
            # Logout from all devices
            count = await token_store.revoke_all_user_tokens(user_id)
            print(f"Revoked {count} sessions")
        """
        pattern = self._make_user_pattern(user_id)

        async with get_redis() as redis:
            # Use SCAN to find all matching keys (safer than KEYS for large
            # datasets).
            keys_to_delete = [key async for key in redis.scan_iter(match=pattern)]

            if not keys_to_delete:
                return 0

            # Delete all found keys in a single call; the reply is the number
            # of keys actually removed.
            return await redis.delete(*keys_to_delete)

    async def get_active_session_count(self, user_id: UUID) -> int:
        """Get the number of active sessions (valid refresh tokens) for a user.

        Args:
            user_id: The user's UUID.

        Returns:
            Number of distinct refresh-token keys currently stored for the
            user.

        Example:
            count = await token_store.get_active_session_count(user_id)
            print(f"User has {count} active sessions")
        """
        pattern = self._make_user_pattern(user_id)

        async with get_redis() as redis:
            # Redis documents that SCAN may return the same key more than once
            # during a full iteration, so collect into a set to count each
            # session exactly once instead of tallying raw results.
            session_keys = {key async for key in redis.scan_iter(match=pattern)}

        return len(session_keys)


# Global token store instance
token_store = TokenStore()