Complete OAuth-based authentication with JWT session management:
Core Services:
- JWT service for access/refresh token creation and verification
- Token store with Redis-backed refresh token revocation
- User service for CRUD operations and OAuth-based creation
- Google and Discord OAuth services with full flow support
API Endpoints:
- GET /api/auth/{google,discord} - Start OAuth flows
- GET /api/auth/{google,discord}/callback - Handle OAuth callbacks
- POST /api/auth/refresh - Exchange refresh token for new access token
- POST /api/auth/logout - Revoke single refresh token
- POST /api/auth/logout-all - Revoke all user sessions
- GET/PATCH /api/users/me - User profile management
- GET /api/users/me/linked-accounts - List OAuth providers
- GET /api/users/me/sessions - Count active sessions
Infrastructure:
- Pydantic schemas for auth/user request/response models
- FastAPI dependencies (get_current_user, get_current_premium_user)
- OAuthLinkedAccount model for multi-provider support
- Alembic migration for oauth_linked_accounts table
Dependencies added: email-validator, fakeredis (dev), respx (dev)
84 new tests, 1058 total passing
196 lines
5.6 KiB
Python
196 lines
5.6 KiB
Python
"""Refresh token storage for Mantimon TCG.

This module provides Redis-based storage for refresh token tracking
and revocation. Each refresh token's JTI is stored in Redis with
a TTL matching the token's expiration.

Key Pattern:
    refresh_token:{user_id}:{jti} -> "1" (exists = valid)

Revocation:
    - Single token: Delete the specific key
    - All user tokens: Delete all keys matching refresh_token:{user_id}:*

Example:
    from app.services.token_store import token_store

    # Store a new refresh token
    await token_store.store_refresh_token(user_id, jti, expires_at)

    # Check if token is valid (not revoked)
    if await token_store.is_token_valid(user_id, jti):
        # Issue new access token

    # Revoke on logout
    await token_store.revoke_token(user_id, jti)

    # Logout from all devices
    await token_store.revoke_all_user_tokens(user_id)
"""
|
|
|
|
from datetime import UTC, datetime
|
|
from uuid import UUID
|
|
|
|
from app.db.redis import get_redis
|
|
|
|
|
|
class TokenStore:
|
|
"""Redis-based refresh token storage for revocation support.
|
|
|
|
Tracks valid refresh tokens by storing their JTIs in Redis.
|
|
Tokens can be revoked individually or all at once per user.
|
|
"""
|
|
|
|
KEY_PREFIX = "refresh_token"
|
|
|
|
def _make_key(self, user_id: UUID, jti: str) -> str:
|
|
"""Create Redis key for a refresh token.
|
|
|
|
Args:
|
|
user_id: The user's UUID.
|
|
jti: The token's unique identifier.
|
|
|
|
Returns:
|
|
Redis key string.
|
|
"""
|
|
return f"{self.KEY_PREFIX}:{user_id}:{jti}"
|
|
|
|
def _make_user_pattern(self, user_id: UUID) -> str:
|
|
"""Create Redis key pattern for all user's tokens.
|
|
|
|
Args:
|
|
user_id: The user's UUID.
|
|
|
|
Returns:
|
|
Redis key pattern for SCAN/KEYS.
|
|
"""
|
|
return f"{self.KEY_PREFIX}:{user_id}:*"
|
|
|
|
async def store_refresh_token(
|
|
self,
|
|
user_id: UUID,
|
|
jti: str,
|
|
expires_at: datetime,
|
|
) -> None:
|
|
"""Store a refresh token's JTI in Redis.
|
|
|
|
Args:
|
|
user_id: The user's UUID.
|
|
jti: The token's unique identifier (from JWT).
|
|
expires_at: When the token expires (for TTL calculation).
|
|
|
|
Example:
|
|
expires_at = datetime.now(UTC) + timedelta(days=7)
|
|
await token_store.store_refresh_token(user_id, jti, expires_at)
|
|
"""
|
|
key = self._make_key(user_id, jti)
|
|
|
|
# Calculate TTL in seconds
|
|
now = datetime.now(UTC)
|
|
ttl_seconds = int((expires_at - now).total_seconds())
|
|
|
|
if ttl_seconds <= 0:
|
|
# Token already expired, don't store
|
|
return
|
|
|
|
async with get_redis() as redis:
|
|
await redis.setex(key, ttl_seconds, "1")
|
|
|
|
async def is_token_valid(self, user_id: UUID, jti: str) -> bool:
|
|
"""Check if a refresh token is valid (not revoked).
|
|
|
|
Args:
|
|
user_id: The user's UUID.
|
|
jti: The token's unique identifier.
|
|
|
|
Returns:
|
|
True if token exists in store (valid), False if revoked or expired.
|
|
|
|
Example:
|
|
if await token_store.is_token_valid(user_id, jti):
|
|
# Token is valid, issue new access token
|
|
else:
|
|
# Token was revoked, require re-authentication
|
|
"""
|
|
key = self._make_key(user_id, jti)
|
|
|
|
async with get_redis() as redis:
|
|
result = await redis.exists(key)
|
|
return result > 0
|
|
|
|
async def revoke_token(self, user_id: UUID, jti: str) -> bool:
|
|
"""Revoke a specific refresh token.
|
|
|
|
Args:
|
|
user_id: The user's UUID.
|
|
jti: The token's unique identifier.
|
|
|
|
Returns:
|
|
True if token was revoked, False if it didn't exist.
|
|
|
|
Example:
|
|
# On logout
|
|
await token_store.revoke_token(user_id, jti)
|
|
"""
|
|
key = self._make_key(user_id, jti)
|
|
|
|
async with get_redis() as redis:
|
|
result = await redis.delete(key)
|
|
return result > 0
|
|
|
|
async def revoke_all_user_tokens(self, user_id: UUID) -> int:
|
|
"""Revoke all refresh tokens for a user.
|
|
|
|
Useful for "logout from all devices" or security incidents.
|
|
|
|
Args:
|
|
user_id: The user's UUID.
|
|
|
|
Returns:
|
|
Number of tokens revoked.
|
|
|
|
Example:
|
|
# Logout from all devices
|
|
count = await token_store.revoke_all_user_tokens(user_id)
|
|
print(f"Revoked {count} sessions")
|
|
"""
|
|
pattern = self._make_user_pattern(user_id)
|
|
|
|
async with get_redis() as redis:
|
|
# Use SCAN to find all matching keys (safer than KEYS for large datasets)
|
|
keys_to_delete = []
|
|
async for key in redis.scan_iter(match=pattern):
|
|
keys_to_delete.append(key)
|
|
|
|
if not keys_to_delete:
|
|
return 0
|
|
|
|
# Delete all found keys
|
|
result = await redis.delete(*keys_to_delete)
|
|
return result
|
|
|
|
async def get_active_session_count(self, user_id: UUID) -> int:
|
|
"""Get the number of active sessions (valid refresh tokens) for a user.
|
|
|
|
Args:
|
|
user_id: The user's UUID.
|
|
|
|
Returns:
|
|
Number of active sessions.
|
|
|
|
Example:
|
|
count = await token_store.get_active_session_count(user_id)
|
|
print(f"User has {count} active sessions")
|
|
"""
|
|
pattern = self._make_user_pattern(user_id)
|
|
|
|
async with get_redis() as redis:
|
|
count = 0
|
|
async for _ in redis.scan_iter(match=pattern):
|
|
count += 1
|
|
return count
|
|
|
|
|
|
# Global token store instance, shared by modules that import `token_store`.
token_store = TokenStore()