mantimon-tcg/backend/app/services/token_store.py
Cal Corum 996c43fbd9 Implement Phase 2: Authentication system
Complete OAuth-based authentication with JWT session management:

Core Services:
- JWT service for access/refresh token creation and verification
- Token store with Redis-backed refresh token revocation
- User service for CRUD operations and OAuth-based creation
- Google and Discord OAuth services with full flow support

API Endpoints:
- GET /api/auth/{google,discord} - Start OAuth flows
- GET /api/auth/{google,discord}/callback - Handle OAuth callbacks
- POST /api/auth/refresh - Exchange refresh token for new access token
- POST /api/auth/logout - Revoke single refresh token
- POST /api/auth/logout-all - Revoke all user sessions
- GET/PATCH /api/users/me - User profile management
- GET /api/users/me/linked-accounts - List OAuth providers
- GET /api/users/me/sessions - Count active sessions

Infrastructure:
- Pydantic schemas for auth/user request/response models
- FastAPI dependencies (get_current_user, get_current_premium_user)
- OAuthLinkedAccount model for multi-provider support
- Alembic migration for oauth_linked_accounts table

Dependencies added: email-validator, fakeredis (dev), respx (dev)

84 new tests, 1058 total passing
2026-01-27 21:49:59 -06:00

196 lines
5.6 KiB
Python

"""Refresh token storage for Mantimon TCG.

This module provides Redis-based storage for refresh token tracking
and revocation. Each refresh token's JTI is stored in Redis with
a TTL matching the token's expiration.

Key Pattern:
    refresh_token:{user_id}:{jti} -> "1" (key exists = token valid)

Revocation:
    - Single token: delete the specific key.
    - All user tokens: delete all keys matching refresh_token:{user_id}:*

Example:
    from app.services.token_store import token_store

    # Store a new refresh token
    await token_store.store_refresh_token(user_id, jti, expires_at)

    # Check if token is valid (not revoked)
    if await token_store.is_token_valid(user_id, jti):
        ...  # Issue new access token

    # Revoke on logout
    await token_store.revoke_token(user_id, jti)

    # Log out from all devices
    await token_store.revoke_all_user_tokens(user_id)
"""
from datetime import UTC, datetime
from uuid import UUID
from app.db.redis import get_redis
class TokenStore:
"""Redis-based refresh token storage for revocation support.
Tracks valid refresh tokens by storing their JTIs in Redis.
Tokens can be revoked individually or all at once per user.
"""
KEY_PREFIX = "refresh_token"
def _make_key(self, user_id: UUID, jti: str) -> str:
"""Create Redis key for a refresh token.
Args:
user_id: The user's UUID.
jti: The token's unique identifier.
Returns:
Redis key string.
"""
return f"{self.KEY_PREFIX}:{user_id}:{jti}"
def _make_user_pattern(self, user_id: UUID) -> str:
"""Create Redis key pattern for all user's tokens.
Args:
user_id: The user's UUID.
Returns:
Redis key pattern for SCAN/KEYS.
"""
return f"{self.KEY_PREFIX}:{user_id}:*"
async def store_refresh_token(
self,
user_id: UUID,
jti: str,
expires_at: datetime,
) -> None:
"""Store a refresh token's JTI in Redis.
Args:
user_id: The user's UUID.
jti: The token's unique identifier (from JWT).
expires_at: When the token expires (for TTL calculation).
Example:
expires_at = datetime.now(UTC) + timedelta(days=7)
await token_store.store_refresh_token(user_id, jti, expires_at)
"""
key = self._make_key(user_id, jti)
# Calculate TTL in seconds
now = datetime.now(UTC)
ttl_seconds = int((expires_at - now).total_seconds())
if ttl_seconds <= 0:
# Token already expired, don't store
return
async with get_redis() as redis:
await redis.setex(key, ttl_seconds, "1")
async def is_token_valid(self, user_id: UUID, jti: str) -> bool:
"""Check if a refresh token is valid (not revoked).
Args:
user_id: The user's UUID.
jti: The token's unique identifier.
Returns:
True if token exists in store (valid), False if revoked or expired.
Example:
if await token_store.is_token_valid(user_id, jti):
# Token is valid, issue new access token
else:
# Token was revoked, require re-authentication
"""
key = self._make_key(user_id, jti)
async with get_redis() as redis:
result = await redis.exists(key)
return result > 0
async def revoke_token(self, user_id: UUID, jti: str) -> bool:
"""Revoke a specific refresh token.
Args:
user_id: The user's UUID.
jti: The token's unique identifier.
Returns:
True if token was revoked, False if it didn't exist.
Example:
# On logout
await token_store.revoke_token(user_id, jti)
"""
key = self._make_key(user_id, jti)
async with get_redis() as redis:
result = await redis.delete(key)
return result > 0
async def revoke_all_user_tokens(self, user_id: UUID) -> int:
"""Revoke all refresh tokens for a user.
Useful for "logout from all devices" or security incidents.
Args:
user_id: The user's UUID.
Returns:
Number of tokens revoked.
Example:
# Logout from all devices
count = await token_store.revoke_all_user_tokens(user_id)
print(f"Revoked {count} sessions")
"""
pattern = self._make_user_pattern(user_id)
async with get_redis() as redis:
# Use SCAN to find all matching keys (safer than KEYS for large datasets)
keys_to_delete = []
async for key in redis.scan_iter(match=pattern):
keys_to_delete.append(key)
if not keys_to_delete:
return 0
# Delete all found keys
result = await redis.delete(*keys_to_delete)
return result
async def get_active_session_count(self, user_id: UUID) -> int:
"""Get the number of active sessions (valid refresh tokens) for a user.
Args:
user_id: The user's UUID.
Returns:
Number of active sessions.
Example:
count = await token_store.get_active_session_count(user_id)
print(f"User has {count} active sessions")
"""
pattern = self._make_user_pattern(user_id)
async with get_redis() as redis:
count = 0
async for _ in redis.scan_iter(match=pattern):
count += 1
return count
# Module-level TokenStore shared by importers; the class keeps no
# per-instance state, so a single instance is sufficient.
token_store: TokenStore = TokenStore()