Merge branch 'backend-phase2' - Complete Phase 2 Authentication
This commit is contained in:
commit
4cdb544162
14
backend/app/api/__init__.py
Normal file
14
backend/app/api/__init__.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
"""API routers and dependencies for Mantimon TCG.
|
||||||
|
|
||||||
|
This package contains FastAPI routers and common dependencies
|
||||||
|
for the REST API.
|
||||||
|
|
||||||
|
Routers:
|
||||||
|
- auth: OAuth login, token refresh, logout
|
||||||
|
- users: User profile management
|
||||||
|
|
||||||
|
Dependencies:
|
||||||
|
- get_current_user: Extract and validate user from JWT
|
||||||
|
- get_current_active_user: Ensure user exists
|
||||||
|
- get_current_premium_user: Require premium subscription
|
||||||
|
"""
|
||||||
651
backend/app/api/auth.py
Normal file
651
backend/app/api/auth.py
Normal file
@ -0,0 +1,651 @@
|
|||||||
|
"""Authentication API router for Mantimon TCG.
|
||||||
|
|
||||||
|
This module provides endpoints for OAuth authentication:
|
||||||
|
- OAuth login redirects (Google, Discord)
|
||||||
|
- OAuth callbacks (token exchange)
|
||||||
|
- Token refresh
|
||||||
|
- Logout
|
||||||
|
|
||||||
|
OAuth Flow:
|
||||||
|
1. Client calls GET /auth/{provider} to get redirect URL
|
||||||
|
2. Client redirects user to OAuth provider
|
||||||
|
3. Provider redirects back to GET /auth/{provider}/callback
|
||||||
|
4. Server exchanges code for tokens, creates/fetches user
|
||||||
|
5. Server returns JWT access + refresh tokens
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Start OAuth flow
|
||||||
|
GET /api/auth/google?redirect_uri=https://play.mantimon.com/login/callback
|
||||||
|
|
||||||
|
# After OAuth callback, refresh tokens
|
||||||
|
POST /api/auth/refresh
|
||||||
|
{"refresh_token": "..."}
|
||||||
|
|
||||||
|
# Logout
|
||||||
|
POST /api/auth/logout
|
||||||
|
{"refresh_token": "..."}
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
from app.api.deps import CurrentUser, DbSession
|
||||||
|
from app.config import settings
|
||||||
|
from app.db.redis import get_redis
|
||||||
|
from app.schemas.auth import RefreshTokenRequest, TokenResponse
|
||||||
|
from app.services.jwt_service import (
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
get_refresh_token_expiration,
|
||||||
|
get_token_expiration_seconds,
|
||||||
|
verify_refresh_token,
|
||||||
|
)
|
||||||
|
from app.services.oauth.discord import DiscordOAuthError, discord_oauth
|
||||||
|
from app.services.oauth.google import GoogleOAuthError, google_oauth
|
||||||
|
from app.services.token_store import token_store
|
||||||
|
from app.services.user_service import AccountLinkingError, user_service
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
# OAuth state TTL (5 minutes)
|
||||||
|
OAUTH_STATE_TTL = 300
|
||||||
|
|
||||||
|
|
||||||
|
async def _store_oauth_state(state: str, provider: str, redirect_uri: str) -> None:
|
||||||
|
"""Store OAuth state in Redis for CSRF validation."""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
key = f"oauth_state:{state}"
|
||||||
|
value = f"{provider}:{redirect_uri}"
|
||||||
|
await redis.setex(key, OAUTH_STATE_TTL, value)
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_oauth_state(state: str, provider: str) -> str | None:
|
||||||
|
"""Validate and consume OAuth state, returning redirect_uri if valid."""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
key = f"oauth_state:{state}"
|
||||||
|
value = await redis.get(key)
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Delete state (one-time use)
|
||||||
|
await redis.delete(key)
|
||||||
|
|
||||||
|
# Parse and validate
|
||||||
|
stored_provider, redirect_uri = value.split(":", 1)
|
||||||
|
if stored_provider != provider:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return redirect_uri
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_tokens_for_user(user_id) -> TokenResponse:
|
||||||
|
"""Create access and refresh tokens for a user."""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
if isinstance(user_id, str):
|
||||||
|
user_id = UUID(user_id)
|
||||||
|
|
||||||
|
access_token = create_access_token(user_id)
|
||||||
|
refresh_token, jti = create_refresh_token(user_id)
|
||||||
|
|
||||||
|
# Store refresh token in Redis for revocation tracking
|
||||||
|
expires_at = get_refresh_token_expiration()
|
||||||
|
await token_store.store_refresh_token(user_id, jti, expires_at)
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
expires_in=get_token_expiration_seconds(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Google OAuth
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/google")
|
||||||
|
async def google_auth_redirect(
|
||||||
|
redirect_uri: str = Query(..., description="URI to redirect to after OAuth completes"),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Start Google OAuth flow.
|
||||||
|
|
||||||
|
Redirects the user to Google's OAuth consent screen.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_uri: Where to redirect after successful authentication.
|
||||||
|
This is YOUR app's callback, not the OAuth callback.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redirect to Google OAuth authorization URL.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 501 if Google OAuth is not configured.
|
||||||
|
"""
|
||||||
|
if not google_oauth.is_configured():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Google OAuth is not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate state for CSRF protection
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
await _store_oauth_state(state, "google", redirect_uri)
|
||||||
|
|
||||||
|
# Build OAuth callback URL (our server endpoint)
|
||||||
|
# The redirect_uri param here is where Google sends the code
|
||||||
|
# Must be an absolute URL for OAuth providers
|
||||||
|
oauth_callback = f"{settings.base_url}/api/auth/google/callback"
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
auth_url = google_oauth.get_authorization_url(
|
||||||
|
redirect_uri=oauth_callback,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/google/callback")
|
||||||
|
async def google_auth_callback(
|
||||||
|
db: DbSession,
|
||||||
|
code: str = Query(..., description="Authorization code from Google"),
|
||||||
|
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||||
|
) -> TokenResponse:
|
||||||
|
"""Handle Google OAuth callback.
|
||||||
|
|
||||||
|
Exchanges the authorization code for tokens and creates/fetches the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from Google.
|
||||||
|
state: State parameter for CSRF validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JWT access and refresh tokens.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 400 if state is invalid or OAuth fails.
|
||||||
|
"""
|
||||||
|
# Validate state
|
||||||
|
redirect_uri = await _validate_oauth_state(state, "google")
|
||||||
|
if redirect_uri is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid or expired state parameter",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Exchange code for user info
|
||||||
|
oauth_callback = f"{settings.base_url}/api/auth/google/callback"
|
||||||
|
user_info = await google_oauth.get_user_info(code, oauth_callback)
|
||||||
|
|
||||||
|
# Get or create user
|
||||||
|
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||||
|
|
||||||
|
# Update last login
|
||||||
|
await user_service.update_last_login(db, user)
|
||||||
|
|
||||||
|
# Create tokens
|
||||||
|
return await _create_tokens_for_user(user.id)
|
||||||
|
|
||||||
|
except GoogleOAuthError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Google OAuth failed: {e}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Discord OAuth
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/discord")
|
||||||
|
async def discord_auth_redirect(
|
||||||
|
redirect_uri: str = Query(..., description="URI to redirect to after OAuth completes"),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Start Discord OAuth flow.
|
||||||
|
|
||||||
|
Redirects the user to Discord's OAuth consent screen.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_uri: Where to redirect after successful authentication.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redirect to Discord OAuth authorization URL.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 501 if Discord OAuth is not configured.
|
||||||
|
"""
|
||||||
|
if not discord_oauth.is_configured():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Discord OAuth is not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate state for CSRF protection
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
await _store_oauth_state(state, "discord", redirect_uri)
|
||||||
|
|
||||||
|
# Build OAuth callback URL (must be absolute for OAuth providers)
|
||||||
|
oauth_callback = f"{settings.base_url}/api/auth/discord/callback"
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
auth_url = discord_oauth.get_authorization_url(
|
||||||
|
redirect_uri=oauth_callback,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/discord/callback")
|
||||||
|
async def discord_auth_callback(
|
||||||
|
db: DbSession,
|
||||||
|
code: str = Query(..., description="Authorization code from Discord"),
|
||||||
|
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||||
|
) -> TokenResponse:
|
||||||
|
"""Handle Discord OAuth callback.
|
||||||
|
|
||||||
|
Exchanges the authorization code for tokens and creates/fetches the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from Discord.
|
||||||
|
state: State parameter for CSRF validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JWT access and refresh tokens.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 400 if state is invalid or OAuth fails.
|
||||||
|
"""
|
||||||
|
# Validate state
|
||||||
|
redirect_uri = await _validate_oauth_state(state, "discord")
|
||||||
|
if redirect_uri is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid or expired state parameter",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Exchange code for user info
|
||||||
|
oauth_callback = f"{settings.base_url}/api/auth/discord/callback"
|
||||||
|
user_info = await discord_oauth.get_user_info(code, oauth_callback)
|
||||||
|
|
||||||
|
# Get or create user
|
||||||
|
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||||
|
|
||||||
|
# Update last login
|
||||||
|
await user_service.update_last_login(db, user)
|
||||||
|
|
||||||
|
# Create tokens
|
||||||
|
return await _create_tokens_for_user(user.id)
|
||||||
|
|
||||||
|
except DiscordOAuthError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Discord OAuth failed: {e}",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Token Management
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
|
async def refresh_tokens(
|
||||||
|
db: DbSession,
|
||||||
|
request: RefreshTokenRequest,
|
||||||
|
) -> TokenResponse:
|
||||||
|
"""Refresh access token using refresh token.
|
||||||
|
|
||||||
|
Validates the refresh token and issues a new access token.
|
||||||
|
The refresh token itself is NOT rotated (same token can be used
|
||||||
|
until it expires or is revoked).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Contains the refresh token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New JWT access token (same refresh token).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 401 if refresh token is invalid or revoked.
|
||||||
|
"""
|
||||||
|
# Verify refresh token
|
||||||
|
result = verify_refresh_token(request.refresh_token)
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id, jti = result
|
||||||
|
|
||||||
|
# Check if token is revoked
|
||||||
|
is_valid = await token_store.is_token_valid(user_id, jti)
|
||||||
|
if not is_valid:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Refresh token has been revoked",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify user still exists
|
||||||
|
user = await user_service.get_by_id(db, user_id)
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new access token (keep same refresh token)
|
||||||
|
access_token = create_access_token(user_id)
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=request.refresh_token, # Return same refresh token
|
||||||
|
expires_in=get_token_expiration_seconds(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def logout(
|
||||||
|
request: RefreshTokenRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Logout by revoking the refresh token.
|
||||||
|
|
||||||
|
The access token will continue to work until it expires,
|
||||||
|
but the refresh token cannot be used to get new tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Contains the refresh token to revoke.
|
||||||
|
"""
|
||||||
|
# Verify refresh token to get user_id and jti
|
||||||
|
result = verify_refresh_token(request.refresh_token)
|
||||||
|
if result is None:
|
||||||
|
# Token is invalid, but that's fine - user is effectively logged out
|
||||||
|
return
|
||||||
|
|
||||||
|
user_id, jti = result
|
||||||
|
|
||||||
|
# Revoke the token
|
||||||
|
await token_store.revoke_token(user_id, jti)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout-all", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def logout_all(
|
||||||
|
user: CurrentUser,
|
||||||
|
) -> None:
|
||||||
|
"""Logout from all devices by revoking all refresh tokens.
|
||||||
|
|
||||||
|
Requires authentication (uses current access token).
|
||||||
|
All refresh tokens for the user will be revoked.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: Current authenticated user.
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
await token_store.revoke_all_user_tokens(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Account Linking
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def _store_link_state(state: str, provider: str, user_id: str, redirect_uri: str) -> None:
|
||||||
|
"""Store OAuth state for account linking (includes user_id)."""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
key = f"oauth_link_state:{state}"
|
||||||
|
value = f"{provider}:{user_id}:{redirect_uri}"
|
||||||
|
await redis.setex(key, OAUTH_STATE_TTL, value)
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_link_state(state: str, provider: str) -> tuple[str, str] | None:
|
||||||
|
"""Validate and consume link state, returning (user_id, redirect_uri) if valid."""
|
||||||
|
async with get_redis() as redis:
|
||||||
|
key = f"oauth_link_state:{state}"
|
||||||
|
value = await redis.get(key)
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Delete state (one-time use)
|
||||||
|
await redis.delete(key)
|
||||||
|
|
||||||
|
# Parse and validate
|
||||||
|
parts = value.split(":", 2)
|
||||||
|
if len(parts) != 3:
|
||||||
|
return None
|
||||||
|
|
||||||
|
stored_provider, user_id, redirect_uri = parts
|
||||||
|
if stored_provider != provider:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return user_id, redirect_uri
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/link/google")
|
||||||
|
async def google_link_redirect(
|
||||||
|
user: CurrentUser,
|
||||||
|
redirect_uri: str = Query(..., description="URI to redirect to after linking"),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Start Google OAuth flow for account linking.
|
||||||
|
|
||||||
|
Requires authentication. Links Google account to the current user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_uri: Where to redirect after linking completes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redirect to Google OAuth authorization URL.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 501 if Google OAuth is not configured.
|
||||||
|
HTTPException: 400 if Google is already the primary provider.
|
||||||
|
"""
|
||||||
|
if not google_oauth.is_configured():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Google OAuth is not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if Google is already their primary provider
|
||||||
|
if user.oauth_provider == "google":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Google is already your primary login provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate state for CSRF protection
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
await _store_link_state(state, "google", str(user.id), redirect_uri)
|
||||||
|
|
||||||
|
# Build OAuth callback URL for linking
|
||||||
|
oauth_callback = f"{settings.base_url}/api/auth/link/google/callback"
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
auth_url = google_oauth.get_authorization_url(
|
||||||
|
redirect_uri=oauth_callback,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/link/google/callback")
|
||||||
|
async def google_link_callback(
|
||||||
|
db: DbSession,
|
||||||
|
code: str = Query(..., description="Authorization code from Google"),
|
||||||
|
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Handle Google OAuth callback for account linking.
|
||||||
|
|
||||||
|
Exchanges the authorization code and links the Google account to the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from Google.
|
||||||
|
state: State parameter for CSRF validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redirect to the original redirect_uri with success/error query params.
|
||||||
|
"""
|
||||||
|
# Validate state
|
||||||
|
result = await _validate_link_state(state, "google")
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid or expired state parameter",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id_str, redirect_uri = result
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Exchange code for user info
|
||||||
|
oauth_callback = f"{settings.base_url}/api/auth/link/google/callback"
|
||||||
|
oauth_info = await google_oauth.get_user_info(code, oauth_callback)
|
||||||
|
|
||||||
|
# Get the user
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
user_id = UUID(user_id_str)
|
||||||
|
user = await user_service.get_by_id(db, user_id)
|
||||||
|
if user is None:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?error=user_not_found",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Link the account
|
||||||
|
await user_service.link_oauth_account(db, user, oauth_info)
|
||||||
|
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?linked=google",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
except GoogleOAuthError as e:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?error=oauth_failed&message={e}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
except AccountLinkingError as e:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?error=linking_failed&message={e}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/link/discord")
|
||||||
|
async def discord_link_redirect(
|
||||||
|
user: CurrentUser,
|
||||||
|
redirect_uri: str = Query(..., description="URI to redirect to after linking"),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Start Discord OAuth flow for account linking.
|
||||||
|
|
||||||
|
Requires authentication. Links Discord account to the current user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_uri: Where to redirect after linking completes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redirect to Discord OAuth authorization URL.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 501 if Discord OAuth is not configured.
|
||||||
|
HTTPException: 400 if Discord is already the primary provider.
|
||||||
|
"""
|
||||||
|
if not discord_oauth.is_configured():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Discord OAuth is not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if Discord is already their primary provider
|
||||||
|
if user.oauth_provider == "discord":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Discord is already your primary login provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate state for CSRF protection
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
await _store_link_state(state, "discord", str(user.id), redirect_uri)
|
||||||
|
|
||||||
|
# Build OAuth callback URL for linking
|
||||||
|
oauth_callback = f"{settings.base_url}/api/auth/link/discord/callback"
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
auth_url = discord_oauth.get_authorization_url(
|
||||||
|
redirect_uri=oauth_callback,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/link/discord/callback")
|
||||||
|
async def discord_link_callback(
|
||||||
|
db: DbSession,
|
||||||
|
code: str = Query(..., description="Authorization code from Discord"),
|
||||||
|
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Handle Discord OAuth callback for account linking.
|
||||||
|
|
||||||
|
Exchanges the authorization code and links the Discord account to the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from Discord.
|
||||||
|
state: State parameter for CSRF validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redirect to the original redirect_uri with success/error query params.
|
||||||
|
"""
|
||||||
|
# Validate state
|
||||||
|
result = await _validate_link_state(state, "discord")
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid or expired state parameter",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id_str, redirect_uri = result
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Exchange code for user info
|
||||||
|
oauth_callback = f"{settings.base_url}/api/auth/link/discord/callback"
|
||||||
|
oauth_info = await discord_oauth.get_user_info(code, oauth_callback)
|
||||||
|
|
||||||
|
# Get the user
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
user_id = UUID(user_id_str)
|
||||||
|
user = await user_service.get_by_id(db, user_id)
|
||||||
|
if user is None:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?error=user_not_found",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Link the account
|
||||||
|
await user_service.link_oauth_account(db, user, oauth_info)
|
||||||
|
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?linked=discord",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
except DiscordOAuthError as e:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?error=oauth_failed&message={e}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
except AccountLinkingError as e:
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?error=linking_failed&message={e}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
172
backend/app/api/deps.py
Normal file
172
backend/app/api/deps.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
"""FastAPI dependencies for Mantimon TCG API.
|
||||||
|
|
||||||
|
This module provides dependency injection functions for authentication
|
||||||
|
and database access in API endpoints.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.api.deps import get_current_user, get_db
|
||||||
|
|
||||||
|
@router.get("/me")
|
||||||
|
async def get_me(
|
||||||
|
user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
return user
|
||||||
|
|
||||||
|
Dependencies:
|
||||||
|
- get_db: Async database session
|
||||||
|
- get_current_user: Authenticated user from JWT (required)
|
||||||
|
- get_optional_user: Authenticated user or None
|
||||||
|
- get_current_premium_user: User with active premium
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import get_session_dependency
|
||||||
|
from app.db.models import User
|
||||||
|
from app.services.jwt_service import verify_access_token
|
||||||
|
from app.services.user_service import user_service
|
||||||
|
|
||||||
|
# OAuth2 scheme for extracting Bearer token from Authorization header
|
||||||
|
# tokenUrl is for OpenAPI docs - points to where tokens are obtained
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(
|
||||||
|
tokenUrl="/api/auth/token", # For OpenAPI docs
|
||||||
|
auto_error=True, # Raise 401 if no token
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth2_scheme_optional = OAuth2PasswordBearer(
|
||||||
|
tokenUrl="/api/auth/token",
|
||||||
|
auto_error=False, # Return None if no token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> AsyncSession:
|
||||||
|
"""Get async database session.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
AsyncSession for database operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@router.get("/items")
|
||||||
|
async def get_items(db: AsyncSession = Depends(get_db)):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
async with get_session_dependency() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: Annotated[str, Depends(oauth2_scheme)],
|
||||||
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> User:
|
||||||
|
"""Get the current authenticated user from JWT token.
|
||||||
|
|
||||||
|
Validates the access token and fetches the user from database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT access token from Authorization header.
|
||||||
|
db: Database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The authenticated User.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 401 if token is invalid or user not found.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@router.get("/me")
|
||||||
|
async def get_me(user: User = Depends(get_current_user)):
|
||||||
|
return {"email": user.email}
|
||||||
|
"""
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify token and extract user ID
|
||||||
|
user_id = verify_access_token(token)
|
||||||
|
if user_id is None:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
# Fetch user from database
|
||||||
|
user = await user_service.get_by_id(db, user_id)
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_optional_user(
|
||||||
|
token: Annotated[str | None, Depends(oauth2_scheme_optional)],
|
||||||
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> User | None:
|
||||||
|
"""Get the current user if authenticated, or None.
|
||||||
|
|
||||||
|
Useful for endpoints that work both with and without authentication,
|
||||||
|
but may provide additional features for authenticated users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT access token or None.
|
||||||
|
db: Database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The authenticated User, or None if not authenticated.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@router.get("/cards")
|
||||||
|
async def get_cards(user: User | None = Depends(get_optional_user)):
|
||||||
|
if user:
|
||||||
|
# Show user's collection
|
||||||
|
else:
|
||||||
|
# Show public cards
|
||||||
|
"""
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_id = verify_access_token(token)
|
||||||
|
if user_id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await user_service.get_by_id(db, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_premium_user(
|
||||||
|
user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> User:
|
||||||
|
"""Get the current user and verify they have premium.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: The authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The authenticated User with active premium.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 403 if user doesn't have active premium.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@router.post("/decks")
|
||||||
|
async def create_unlimited_decks(
|
||||||
|
user: User = Depends(get_current_premium_user)
|
||||||
|
):
|
||||||
|
# Only premium users can have unlimited decks
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
if not user.has_active_premium:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Premium subscription required",
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
# Type aliases for cleaner endpoint signatures
|
||||||
|
CurrentUser = Annotated[User, Depends(get_current_user)]
|
||||||
|
OptionalUser = Annotated[User | None, Depends(get_optional_user)]
|
||||||
|
PremiumUser = Annotated[User, Depends(get_current_premium_user)]
|
||||||
|
DbSession = Annotated[AsyncSession, Depends(get_db)]
|
||||||
168
backend/app/api/users.py
Normal file
168
backend/app/api/users.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
"""User API router for Mantimon TCG.
|
||||||
|
|
||||||
|
This module provides endpoints for user profile management:
|
||||||
|
- Get current user profile
|
||||||
|
- Update profile (display name, avatar)
|
||||||
|
- List linked OAuth accounts
|
||||||
|
- Session management
|
||||||
|
|
||||||
|
All endpoints require authentication.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Get current user
|
||||||
|
GET /api/users/me
|
||||||
|
Authorization: Bearer <access_token>
|
||||||
|
|
||||||
|
# Update profile
|
||||||
|
PATCH /api/users/me
|
||||||
|
{"display_name": "NewName"}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.api.deps import CurrentUser, DbSession
|
||||||
|
from app.schemas.user import UserResponse, UserUpdate
|
||||||
|
from app.services.token_store import token_store
|
||||||
|
from app.services.user_service import AccountLinkingError, user_service
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
|
|
||||||
|
class LinkedAccountResponse(BaseModel):
|
||||||
|
"""Response for a linked OAuth account."""
|
||||||
|
|
||||||
|
provider: str = Field(..., description="OAuth provider name")
|
||||||
|
email: str | None = Field(None, description="Email from this provider")
|
||||||
|
linked_at: str = Field(..., description="When account was linked (ISO format)")
|
||||||
|
|
||||||
|
|
||||||
|
class SessionsResponse(BaseModel):
|
||||||
|
"""Response for active sessions count."""
|
||||||
|
|
||||||
|
active_sessions: int = Field(..., description="Number of active sessions")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_current_user_profile(
|
||||||
|
user: CurrentUser,
|
||||||
|
) -> UserResponse:
|
||||||
|
"""Get the current user's profile.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User profile information.
|
||||||
|
"""
|
||||||
|
return UserResponse.model_validate(user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/me", response_model=UserResponse)
|
||||||
|
async def update_current_user_profile(
|
||||||
|
user: CurrentUser,
|
||||||
|
db: DbSession,
|
||||||
|
update_data: UserUpdate,
|
||||||
|
) -> UserResponse:
|
||||||
|
"""Update the current user's profile.
|
||||||
|
|
||||||
|
Only provided fields are updated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_data: Fields to update (display_name, avatar_url).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated user profile.
|
||||||
|
"""
|
||||||
|
updated_user = await user_service.update(db, user, update_data)
|
||||||
|
return UserResponse.model_validate(updated_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/linked-accounts", response_model=list[LinkedAccountResponse])
|
||||||
|
async def get_linked_accounts(
|
||||||
|
user: CurrentUser,
|
||||||
|
) -> list[LinkedAccountResponse]:
|
||||||
|
"""Get all OAuth accounts linked to the current user.
|
||||||
|
|
||||||
|
Returns the primary OAuth provider plus any additional linked accounts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of linked OAuth accounts.
|
||||||
|
"""
|
||||||
|
accounts = []
|
||||||
|
|
||||||
|
# Add primary OAuth account
|
||||||
|
accounts.append(
|
||||||
|
LinkedAccountResponse(
|
||||||
|
provider=user.oauth_provider,
|
||||||
|
email=user.email,
|
||||||
|
linked_at=user.created_at.isoformat(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add additional linked accounts
|
||||||
|
for linked in user.linked_accounts:
|
||||||
|
accounts.append(
|
||||||
|
LinkedAccountResponse(
|
||||||
|
provider=linked.provider,
|
||||||
|
email=linked.email,
|
||||||
|
linked_at=linked.linked_at.isoformat(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return accounts
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/sessions", response_model=SessionsResponse)
|
||||||
|
async def get_active_sessions(
|
||||||
|
user: CurrentUser,
|
||||||
|
) -> SessionsResponse:
|
||||||
|
"""Get the number of active sessions for the current user.
|
||||||
|
|
||||||
|
Each session corresponds to a valid refresh token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of active sessions.
|
||||||
|
"""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
count = await token_store.get_active_session_count(user_id)
|
||||||
|
|
||||||
|
return SessionsResponse(active_sessions=count)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def unlink_oauth_account(
|
||||||
|
user: CurrentUser,
|
||||||
|
db: DbSession,
|
||||||
|
provider: str,
|
||||||
|
) -> None:
|
||||||
|
"""Unlink an OAuth provider from the current user's account.
|
||||||
|
|
||||||
|
Cannot unlink the primary OAuth provider (the one used to create the account).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: OAuth provider name to unlink ('google' or 'discord').
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 400 if trying to unlink primary provider.
|
||||||
|
HTTPException: 404 if provider is not linked.
|
||||||
|
"""
|
||||||
|
provider = provider.lower()
|
||||||
|
|
||||||
|
if provider not in ("google", "discord"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unknown provider: {provider}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
unlinked = await user_service.unlink_oauth_account(db, user, provider)
|
||||||
|
if not unlinked:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"{provider.title()} is not linked to your account",
|
||||||
|
)
|
||||||
|
except AccountLinkingError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
) from None
|
||||||
@ -154,6 +154,12 @@ class Settings(BaseSettings):
|
|||||||
description="Allowed CORS origins",
|
description="Allowed CORS origins",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Base URL (for OAuth callbacks and external links)
|
||||||
|
base_url: str = Field(
|
||||||
|
default="http://localhost:8000",
|
||||||
|
description="Base URL of the API server (for OAuth callbacks)",
|
||||||
|
)
|
||||||
|
|
||||||
# Game Settings
|
# Game Settings
|
||||||
turn_timeout_seconds: int = Field(
|
turn_timeout_seconds: int = Field(
|
||||||
default=120,
|
default=120,
|
||||||
|
|||||||
@ -0,0 +1,69 @@
|
|||||||
|
"""add_oauth_linked_accounts
|
||||||
|
|
||||||
|
Revision ID: 5ce887128ab1
|
||||||
|
Revises: ab8a0039fe55
|
||||||
|
Create Date: 2026-01-27 16:42:12.335987
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "5ce887128ab1"
|
||||||
|
down_revision: str | Sequence[str] | None = "ab8a0039fe55"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"oauth_linked_accounts",
|
||||||
|
sa.Column("user_id", sa.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("provider", sa.String(length=20), nullable=False),
|
||||||
|
sa.Column("oauth_id", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("email", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("display_name", sa.String(length=100), nullable=True),
|
||||||
|
sa.Column("avatar_url", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"linked_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("id", sa.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_linked_accounts_provider_oauth_id",
|
||||||
|
"oauth_linked_accounts",
|
||||||
|
["provider", "oauth_id"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_linked_accounts_user_id"), "oauth_linked_accounts", ["user_id"], unique=False
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f("ix_oauth_linked_accounts_user_id"), table_name="oauth_linked_accounts")
|
||||||
|
op.drop_index("ix_oauth_linked_accounts_provider_oauth_id", table_name="oauth_linked_accounts")
|
||||||
|
op.drop_table("oauth_linked_accounts")
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -25,11 +25,14 @@ from app.db.models.campaign import CampaignProgress
|
|||||||
from app.db.models.collection import CardSource, Collection
|
from app.db.models.collection import CardSource, Collection
|
||||||
from app.db.models.deck import Deck
|
from app.db.models.deck import Deck
|
||||||
from app.db.models.game import ActiveGame, EndReason, GameHistory, GameType
|
from app.db.models.game import ActiveGame, EndReason, GameHistory, GameType
|
||||||
|
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||||
from app.db.models.user import User
|
from app.db.models.user import User
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# User
|
# User
|
||||||
"User",
|
"User",
|
||||||
|
# OAuth
|
||||||
|
"OAuthLinkedAccount",
|
||||||
# Collection
|
# Collection
|
||||||
"Collection",
|
"Collection",
|
||||||
"CardSource",
|
"CardSource",
|
||||||
|
|||||||
122
backend/app/db/models/oauth_account.py
Normal file
122
backend/app/db/models/oauth_account.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
"""OAuth linked account model for Mantimon TCG.
|
||||||
|
|
||||||
|
This module defines the OAuthLinkedAccount model for supporting multiple
|
||||||
|
OAuth providers per user (account linking).
|
||||||
|
|
||||||
|
A user can have multiple linked accounts (e.g., both Google and Discord),
|
||||||
|
allowing them to log in with either provider.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# User links Discord to their existing Google account
|
||||||
|
linked_account = OAuthLinkedAccount(
|
||||||
|
user_id=user.id,
|
||||||
|
provider="discord",
|
||||||
|
oauth_id="123456789",
|
||||||
|
email="player@example.com"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, ForeignKey, Index, String, func
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.db.base import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.db.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthLinkedAccount(Base):
|
||||||
|
"""Linked OAuth account for multi-provider authentication.
|
||||||
|
|
||||||
|
Allows users to link multiple OAuth providers to a single account,
|
||||||
|
enabling login via any linked provider.
|
||||||
|
|
||||||
|
The User model still has oauth_provider/oauth_id for the "primary"
|
||||||
|
provider (the one used to create the account). This table tracks
|
||||||
|
additional linked providers.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: Unique identifier (UUID).
|
||||||
|
user_id: Foreign key to the user who owns this linked account.
|
||||||
|
provider: OAuth provider name ('google', 'discord').
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
email: Email address from this OAuth provider (may differ from user's primary email).
|
||||||
|
display_name: Display name from this OAuth provider.
|
||||||
|
avatar_url: Avatar URL from this OAuth provider.
|
||||||
|
linked_at: When this account was linked.
|
||||||
|
|
||||||
|
Relationships:
|
||||||
|
user: The User who owns this linked account.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_linked_accounts"
|
||||||
|
|
||||||
|
# Foreign key to user
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False),
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
doc="User who owns this linked account",
|
||||||
|
)
|
||||||
|
|
||||||
|
# OAuth provider info
|
||||||
|
provider: Mapped[str] = mapped_column(
|
||||||
|
String(20),
|
||||||
|
nullable=False,
|
||||||
|
doc="OAuth provider name (google, discord)",
|
||||||
|
)
|
||||||
|
oauth_id: Mapped[str] = mapped_column(
|
||||||
|
String(255),
|
||||||
|
nullable=False,
|
||||||
|
doc="Unique ID from OAuth provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Additional info from provider
|
||||||
|
email: Mapped[str | None] = mapped_column(
|
||||||
|
String(255),
|
||||||
|
nullable=True,
|
||||||
|
doc="Email from this OAuth provider",
|
||||||
|
)
|
||||||
|
display_name: Mapped[str | None] = mapped_column(
|
||||||
|
String(100),
|
||||||
|
nullable=True,
|
||||||
|
doc="Display name from this OAuth provider",
|
||||||
|
)
|
||||||
|
avatar_url: Mapped[str | None] = mapped_column(
|
||||||
|
String(500),
|
||||||
|
nullable=True,
|
||||||
|
doc="Avatar URL from this OAuth provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
# When linked
|
||||||
|
linked_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
doc="When this account was linked",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationship back to user
|
||||||
|
user: Mapped["User"] = relationship(
|
||||||
|
"User",
|
||||||
|
back_populates="linked_accounts",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Indexes and constraints
|
||||||
|
__table_args__ = (
|
||||||
|
# Each OAuth provider+ID can only be linked to one user
|
||||||
|
Index(
|
||||||
|
"ix_oauth_linked_accounts_provider_oauth_id",
|
||||||
|
"provider",
|
||||||
|
"oauth_id",
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<OAuthLinkedAccount(user_id={self.user_id!r}, provider={self.provider!r})>"
|
||||||
@ -24,6 +24,7 @@ if TYPE_CHECKING:
|
|||||||
from app.db.models.campaign import CampaignProgress
|
from app.db.models.campaign import CampaignProgress
|
||||||
from app.db.models.collection import Collection
|
from app.db.models.collection import Collection
|
||||||
from app.db.models.deck import Deck
|
from app.db.models.deck import Deck
|
||||||
|
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
@ -45,6 +46,7 @@ class User(Base):
|
|||||||
decks: User's deck configurations.
|
decks: User's deck configurations.
|
||||||
collection: User's card collection.
|
collection: User's card collection.
|
||||||
campaign_progress: User's campaign state.
|
campaign_progress: User's campaign state.
|
||||||
|
linked_accounts: Additional linked OAuth providers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
@ -120,6 +122,12 @@ class User(Base):
|
|||||||
uselist=False,
|
uselist=False,
|
||||||
lazy="selectin",
|
lazy="selectin",
|
||||||
)
|
)
|
||||||
|
linked_accounts: Mapped[list["OAuthLinkedAccount"]] = relationship(
|
||||||
|
"OAuthLinkedAccount",
|
||||||
|
back_populates="user",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
lazy="selectin",
|
||||||
|
)
|
||||||
|
|
||||||
# Indexes
|
# Indexes
|
||||||
__table_args__ = (Index("ix_users_oauth", "oauth_provider", "oauth_id", unique=True),)
|
__table_args__ = (Index("ix_users_oauth", "oauth_provider", "oauth_id", unique=True),)
|
||||||
|
|||||||
@ -18,6 +18,8 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.api.auth import router as auth_router
|
||||||
|
from app.api.users import router as users_router
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.db import close_db, init_db
|
from app.db import close_db, init_db
|
||||||
from app.db.redis import close_redis, init_redis
|
from app.db.redis import close_redis, init_redis
|
||||||
@ -159,9 +161,11 @@ async def readiness_check() -> dict[str, str | int]:
|
|||||||
|
|
||||||
|
|
||||||
# === API Routers ===
|
# === API Routers ===
|
||||||
# TODO: Add routers in Phase 2
|
app.include_router(auth_router, prefix="/api")
|
||||||
# from app.api import auth, games, cards, decks, campaign
|
app.include_router(users_router, prefix="/api")
|
||||||
# app.include_router(auth.router, prefix="/api/auth", tags=["auth"])
|
|
||||||
|
# TODO: Add remaining routers in future phases
|
||||||
|
# from app.api import cards, decks, games, campaign
|
||||||
# app.include_router(cards.router, prefix="/api/cards", tags=["cards"])
|
# app.include_router(cards.router, prefix="/api/cards", tags=["cards"])
|
||||||
# app.include_router(decks.router, prefix="/api/decks", tags=["decks"])
|
# app.include_router(decks.router, prefix="/api/decks", tags=["decks"])
|
||||||
# app.include_router(games.router, prefix="/api/games", tags=["games"])
|
# app.include_router(games.router, prefix="/api/games", tags=["games"])
|
||||||
|
|||||||
32
backend/app/schemas/__init__.py
Normal file
32
backend/app/schemas/__init__.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
"""Pydantic schemas for Mantimon TCG API.
|
||||||
|
|
||||||
|
This package contains request/response models for all API endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.schemas.auth import (
|
||||||
|
OAuthState,
|
||||||
|
RefreshTokenRequest,
|
||||||
|
TokenPayload,
|
||||||
|
TokenResponse,
|
||||||
|
TokenType,
|
||||||
|
)
|
||||||
|
from app.schemas.user import (
|
||||||
|
OAuthUserInfo,
|
||||||
|
UserCreate,
|
||||||
|
UserResponse,
|
||||||
|
UserUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Auth schemas
|
||||||
|
"TokenType",
|
||||||
|
"TokenPayload",
|
||||||
|
"TokenResponse",
|
||||||
|
"RefreshTokenRequest",
|
||||||
|
"OAuthState",
|
||||||
|
# User schemas
|
||||||
|
"UserResponse",
|
||||||
|
"UserCreate",
|
||||||
|
"UserUpdate",
|
||||||
|
"OAuthUserInfo",
|
||||||
|
]
|
||||||
87
backend/app/schemas/auth.py
Normal file
87
backend/app/schemas/auth.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
"""Authentication schemas for Mantimon TCG.
|
||||||
|
|
||||||
|
This module defines Pydantic models for JWT tokens and authentication
|
||||||
|
responses used throughout the auth system.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
token_payload = TokenPayload(
|
||||||
|
sub="550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
exp=datetime.now(UTC) + timedelta(minutes=30),
|
||||||
|
iat=datetime.now(UTC),
|
||||||
|
type="access"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TokenType(str, Enum):
|
||||||
|
"""Type of JWT token."""
|
||||||
|
|
||||||
|
ACCESS = "access"
|
||||||
|
REFRESH = "refresh"
|
||||||
|
|
||||||
|
|
||||||
|
class TokenPayload(BaseModel):
|
||||||
|
"""JWT token payload structure.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sub: Subject - the user ID as a string (UUID format).
|
||||||
|
exp: Expiration timestamp.
|
||||||
|
iat: Issued-at timestamp.
|
||||||
|
type: Token type (access or refresh).
|
||||||
|
jti: JWT ID - unique identifier for refresh token tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sub: str = Field(..., description="User ID (UUID as string)")
|
||||||
|
exp: datetime = Field(..., description="Token expiration timestamp")
|
||||||
|
iat: datetime = Field(..., description="Token issued-at timestamp")
|
||||||
|
type: TokenType = Field(..., description="Token type (access/refresh)")
|
||||||
|
jti: str | None = Field(default=None, description="JWT ID for refresh token tracking")
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
"""Response containing JWT tokens after successful authentication.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
access_token: Short-lived JWT for API authentication.
|
||||||
|
refresh_token: Longer-lived JWT for obtaining new access tokens.
|
||||||
|
token_type: Always "bearer" for OAuth 2.0 compatibility.
|
||||||
|
expires_in: Access token lifetime in seconds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
access_token: str = Field(..., description="JWT access token")
|
||||||
|
refresh_token: str = Field(..., description="JWT refresh token")
|
||||||
|
token_type: str = Field(default="bearer", description="Token type (always bearer)")
|
||||||
|
expires_in: int = Field(..., description="Access token lifetime in seconds")
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenRequest(BaseModel):
|
||||||
|
"""Request to refresh an access token.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
refresh_token: The refresh token to exchange for a new access token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
refresh_token: str = Field(..., description="Refresh token to exchange")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthState(BaseModel):
|
||||||
|
"""OAuth state parameter for CSRF protection.
|
||||||
|
|
||||||
|
Stored in Redis with short TTL during OAuth flow.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
state: Random string for CSRF protection.
|
||||||
|
redirect_uri: Where to redirect after OAuth callback.
|
||||||
|
provider: OAuth provider name.
|
||||||
|
created_at: When the state was created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
state: str = Field(..., description="Random CSRF protection string")
|
||||||
|
redirect_uri: str = Field(..., description="Post-auth redirect URI")
|
||||||
|
provider: str = Field(..., description="OAuth provider (google, discord)")
|
||||||
|
created_at: datetime = Field(..., description="State creation timestamp")
|
||||||
114
backend/app/schemas/user.py
Normal file
114
backend/app/schemas/user.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
"""User schemas for Mantimon TCG.
|
||||||
|
|
||||||
|
This module defines Pydantic models for user-related API requests
|
||||||
|
and responses.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user_response = UserResponse(
|
||||||
|
id="550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
email="player@example.com",
|
||||||
|
display_name="Player1",
|
||||||
|
is_premium=False,
|
||||||
|
created_at=datetime.now(UTC)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
"""Public user information returned by API endpoints.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: User's unique identifier.
|
||||||
|
email: User's email address.
|
||||||
|
display_name: User's public display name.
|
||||||
|
avatar_url: URL to user's avatar image.
|
||||||
|
is_premium: Whether user has active premium subscription.
|
||||||
|
premium_until: When premium subscription expires (if premium).
|
||||||
|
created_at: When the account was created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: UUID = Field(..., description="User ID")
|
||||||
|
email: EmailStr = Field(..., description="User's email address")
|
||||||
|
display_name: str = Field(..., description="Public display name")
|
||||||
|
avatar_url: str | None = Field(default=None, description="Avatar image URL")
|
||||||
|
is_premium: bool = Field(default=False, description="Premium subscription status")
|
||||||
|
premium_until: datetime | None = Field(default=None, description="Premium expiration date")
|
||||||
|
created_at: datetime = Field(..., description="Account creation date")
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreate(BaseModel):
|
||||||
|
"""Internal schema for creating a user from OAuth data.
|
||||||
|
|
||||||
|
Not exposed via API - used internally by auth service.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
email: User's email from OAuth provider.
|
||||||
|
display_name: User's name from OAuth provider.
|
||||||
|
avatar_url: Avatar URL from OAuth provider.
|
||||||
|
oauth_provider: OAuth provider name (google, discord).
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
"""
|
||||||
|
|
||||||
|
email: EmailStr = Field(..., description="Email from OAuth provider")
|
||||||
|
display_name: str = Field(..., max_length=50, description="Display name")
|
||||||
|
avatar_url: str | None = Field(default=None, description="Avatar URL")
|
||||||
|
oauth_provider: str = Field(..., description="OAuth provider (google, discord)")
|
||||||
|
oauth_id: str = Field(..., description="Unique ID from OAuth provider")
|
||||||
|
|
||||||
|
|
||||||
|
class UserUpdate(BaseModel):
|
||||||
|
"""Schema for updating user profile.
|
||||||
|
|
||||||
|
All fields are optional - only provided fields are updated.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
display_name: New display name.
|
||||||
|
avatar_url: New avatar URL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
display_name: str | None = Field(
|
||||||
|
default=None, min_length=1, max_length=50, description="New display name"
|
||||||
|
)
|
||||||
|
avatar_url: str | None = Field(default=None, description="New avatar URL")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthUserInfo(BaseModel):
|
||||||
|
"""Normalized user information from OAuth providers.
|
||||||
|
|
||||||
|
This provides a consistent structure regardless of whether
|
||||||
|
the user authenticated via Google or Discord.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
provider: OAuth provider name.
|
||||||
|
oauth_id: Unique ID from the provider.
|
||||||
|
email: User's email address.
|
||||||
|
name: User's display name from provider.
|
||||||
|
avatar_url: Avatar URL from provider.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: str = Field(..., description="OAuth provider (google, discord)")
|
||||||
|
oauth_id: str = Field(..., description="Unique ID from provider")
|
||||||
|
email: EmailStr = Field(..., description="User's email")
|
||||||
|
name: str = Field(..., description="User's name from provider")
|
||||||
|
avatar_url: str | None = Field(default=None, description="Avatar URL from provider")
|
||||||
|
|
||||||
|
def to_user_create(self) -> UserCreate:
|
||||||
|
"""Convert OAuth info to UserCreate schema.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserCreate instance ready for user creation.
|
||||||
|
"""
|
||||||
|
return UserCreate(
|
||||||
|
email=self.email,
|
||||||
|
display_name=self.name[:50], # Enforce max length
|
||||||
|
avatar_url=self.avatar_url,
|
||||||
|
oauth_provider=self.provider,
|
||||||
|
oauth_id=self.oauth_id,
|
||||||
|
)
|
||||||
206
backend/app/services/jwt_service.py
Normal file
206
backend/app/services/jwt_service.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
"""JWT token service for Mantimon TCG.
|
||||||
|
|
||||||
|
This module provides functions for creating and verifying JWT tokens
|
||||||
|
used in the authentication system.
|
||||||
|
|
||||||
|
Token Types:
|
||||||
|
- Access tokens: Short-lived (30 min default), used for API authentication
|
||||||
|
- Refresh tokens: Longer-lived (7 days default), used to obtain new access tokens
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from app.services.jwt_service import create_access_token, verify_token
|
||||||
|
|
||||||
|
# Create tokens
|
||||||
|
access_token = create_access_token(user_id)
|
||||||
|
refresh_token, jti = create_refresh_token(user_id)
|
||||||
|
|
||||||
|
# Verify token
|
||||||
|
user_id = verify_token(access_token)
|
||||||
|
if user_id:
|
||||||
|
print(f"Valid token for user {user_id}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.schemas.auth import TokenPayload, TokenType
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(user_id: uuid.UUID) -> str:
|
||||||
|
"""Create a short-lived JWT access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID to encode in the token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encoded JWT access token string.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
token = create_access_token(user.id)
|
||||||
|
# Use token in Authorization header: Bearer {token}
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expires = now + timedelta(minutes=settings.jwt_expire_minutes)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"sub": str(user_id),
|
||||||
|
"exp": expires,
|
||||||
|
"iat": now,
|
||||||
|
"type": TokenType.ACCESS.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithm=settings.jwt_algorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(user_id: uuid.UUID) -> tuple[str, str]:
|
||||||
|
"""Create a longer-lived JWT refresh token.
|
||||||
|
|
||||||
|
Refresh tokens include a JTI (JWT ID) for tracking and revocation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID to encode in the token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (encoded JWT refresh token, JTI string).
|
||||||
|
The JTI should be stored in Redis for revocation tracking.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
token, jti = create_refresh_token(user.id)
|
||||||
|
await token_store.store_refresh_token(user.id, jti, expires_at)
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expires = now + timedelta(days=settings.jwt_refresh_expire_days)
|
||||||
|
jti = str(uuid.uuid4())
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"sub": str(user_id),
|
||||||
|
"exp": expires,
|
||||||
|
"iat": now,
|
||||||
|
"type": TokenType.REFRESH.value,
|
||||||
|
"jti": jti,
|
||||||
|
}
|
||||||
|
|
||||||
|
token = jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithm=settings.jwt_algorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
return token, jti
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(token: str) -> TokenPayload | None:
|
||||||
|
"""Decode and validate a JWT token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The JWT token string to decode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenPayload if valid, None if invalid or expired.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
payload = decode_token(token)
|
||||||
|
if payload:
|
||||||
|
user_id = UUID(payload.sub)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload_dict = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithms=[settings.jwt_algorithm],
|
||||||
|
)
|
||||||
|
return TokenPayload(**payload_dict)
|
||||||
|
except JWTError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def verify_access_token(token: str) -> uuid.UUID | None:
|
||||||
|
"""Verify an access token and extract the user ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The JWT access token to verify.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User UUID if valid access token, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user_id = verify_access_token(token)
|
||||||
|
if user_id:
|
||||||
|
user = await user_service.get_user_by_id(db, user_id)
|
||||||
|
"""
|
||||||
|
payload = decode_token(token)
|
||||||
|
if payload is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if payload.type != TokenType.ACCESS:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return uuid.UUID(payload.sub)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def verify_refresh_token(token: str) -> tuple[uuid.UUID, str] | None:
|
||||||
|
"""Verify a refresh token and extract user ID and JTI.
|
||||||
|
|
||||||
|
The JTI should be checked against the token store to ensure
|
||||||
|
the token hasn't been revoked.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The JWT refresh token to verify.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (user UUID, JTI) if valid refresh token, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
result = verify_refresh_token(token)
|
||||||
|
if result:
|
||||||
|
user_id, jti = result
|
||||||
|
if await token_store.is_token_valid(user_id, jti):
|
||||||
|
# Issue new access token
|
||||||
|
"""
|
||||||
|
payload = decode_token(token)
|
||||||
|
if payload is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if payload.type != TokenType.REFRESH:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if payload.jti is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_id = uuid.UUID(payload.sub)
|
||||||
|
return user_id, payload.jti
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_expiration_seconds() -> int:
|
||||||
|
"""Get the access token expiration time in seconds.
|
||||||
|
|
||||||
|
Useful for the expires_in field in token responses.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Access token lifetime in seconds.
|
||||||
|
"""
|
||||||
|
return settings.jwt_expire_minutes * 60
|
||||||
|
|
||||||
|
|
||||||
|
def get_refresh_token_expiration() -> datetime:
|
||||||
|
"""Get the expiration datetime for a new refresh token.
|
||||||
|
|
||||||
|
Useful for setting TTL when storing in Redis.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Datetime when a refresh token created now would expire.
|
||||||
|
"""
|
||||||
|
return datetime.now(UTC) + timedelta(days=settings.jwt_refresh_expire_days)
|
||||||
25
backend/app/services/oauth/__init__.py
Normal file
25
backend/app/services/oauth/__init__.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
"""OAuth provider services for Mantimon TCG.
|
||||||
|
|
||||||
|
This package contains OAuth integration services for supported providers.
|
||||||
|
|
||||||
|
Providers:
|
||||||
|
- Google: OAuth 2.0 with Google accounts
|
||||||
|
- Discord: OAuth 2.0 with Discord accounts
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from app.services.oauth import google_oauth, discord_oauth
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
auth_url = google_oauth.get_authorization_url(redirect_uri, state)
|
||||||
|
|
||||||
|
# Exchange code for user info
|
||||||
|
user_info = await google_oauth.get_user_info(code, redirect_uri)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.services.oauth.discord import discord_oauth
|
||||||
|
from app.services.oauth.google import google_oauth
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"google_oauth",
|
||||||
|
"discord_oauth",
|
||||||
|
]
|
||||||
242
backend/app/services/oauth/discord.py
Normal file
242
backend/app/services/oauth/discord.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
"""Discord OAuth service for Mantimon TCG.
|
||||||
|
|
||||||
|
This module handles Discord OAuth 2.0 authentication flow:
|
||||||
|
1. Generate authorization URL for user redirect
|
||||||
|
2. Exchange authorization code for tokens
|
||||||
|
3. Fetch user information from Discord
|
||||||
|
|
||||||
|
Discord OAuth Endpoints:
|
||||||
|
- Authorization: https://discord.com/api/oauth2/authorize
|
||||||
|
- Token: https://discord.com/api/oauth2/token
|
||||||
|
- User Info: https://discord.com/api/users/@me
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from app.services.oauth.discord import discord_oauth
|
||||||
|
|
||||||
|
# Step 1: Redirect user to Discord
|
||||||
|
auth_url = discord_oauth.get_authorization_url(
|
||||||
|
redirect_uri="https://play.mantimon.com/api/auth/discord/callback",
|
||||||
|
state="random-csrf-token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: Handle callback and get user info
|
||||||
|
user_info = await discord_oauth.get_user_info(code, redirect_uri)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.schemas.user import OAuthUserInfo
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordOAuthError(Exception):
|
||||||
|
"""Exception raised for Discord OAuth errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordOAuth:
|
||||||
|
"""Discord OAuth 2.0 service.
|
||||||
|
|
||||||
|
Handles the OAuth flow for authenticating users with Discord accounts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
AUTHORIZATION_URL = "https://discord.com/api/oauth2/authorize"
|
||||||
|
TOKEN_URL = "https://discord.com/api/oauth2/token"
|
||||||
|
USER_INFO_URL = "https://discord.com/api/users/@me"
|
||||||
|
CDN_URL = "https://cdn.discordapp.com"
|
||||||
|
|
||||||
|
# Scopes we request from Discord
|
||||||
|
SCOPES = ["identify", "email"]
|
||||||
|
|
||||||
|
def get_authorization_url(self, redirect_uri: str, state: str) -> str:
|
||||||
|
"""Generate the Discord OAuth authorization URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_uri: Where Discord should redirect after authorization.
|
||||||
|
state: Random string for CSRF protection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Full authorization URL to redirect user to.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DiscordOAuthError: If Discord OAuth is not configured.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
url = discord_oauth.get_authorization_url(
|
||||||
|
redirect_uri="https://play.mantimon.com/api/auth/discord/callback",
|
||||||
|
state="abc123"
|
||||||
|
)
|
||||||
|
# Redirect user to url
|
||||||
|
"""
|
||||||
|
if not settings.discord_client_id:
|
||||||
|
raise DiscordOAuthError("Discord OAuth is not configured")
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"client_id": settings.discord_client_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": " ".join(self.SCOPES),
|
||||||
|
"state": state,
|
||||||
|
"prompt": "consent", # Always show consent screen
|
||||||
|
}
|
||||||
|
|
||||||
|
return f"{self.AUTHORIZATION_URL}?{urlencode(params)}"
|
||||||
|
|
||||||
|
async def exchange_code_for_tokens(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
) -> dict:
|
||||||
|
"""Exchange authorization code for access tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from Discord callback.
|
||||||
|
redirect_uri: Same redirect_uri used in authorization request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token response containing access_token, refresh_token, etc.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DiscordOAuthError: If token exchange fails.
|
||||||
|
"""
|
||||||
|
if not settings.discord_client_id or not settings.discord_client_secret:
|
||||||
|
raise DiscordOAuthError("Discord OAuth is not configured")
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"client_id": settings.discord_client_id,
|
||||||
|
"client_secret": settings.discord_client_secret.get_secret_value(),
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.TOKEN_URL,
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_data = response.json() if response.content else {}
|
||||||
|
error_msg = error_data.get("error_description", response.text)
|
||||||
|
raise DiscordOAuthError(f"Token exchange failed: {error_msg}")
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def fetch_user_info(self, access_token: str) -> dict:
|
||||||
|
"""Fetch user information from Discord.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: Valid Discord access token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User info dict with id, username, email, avatar, etc.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DiscordOAuthError: If fetching user info fails.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
self.USER_INFO_URL,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise DiscordOAuthError(f"Failed to fetch user info: {response.text}")
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def _build_avatar_url(self, user_id: str, avatar_hash: str | None) -> str | None:
|
||||||
|
"""Build Discord avatar URL from user ID and avatar hash.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Discord user ID.
|
||||||
|
avatar_hash: Avatar hash from Discord API (can be None).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Full CDN URL for avatar, or None if no avatar.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Discord avatar format: https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.png
|
||||||
|
If avatar_hash starts with 'a_', it's animated (gif).
|
||||||
|
"""
|
||||||
|
if not avatar_hash:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Animated avatars start with 'a_'
|
||||||
|
extension = "gif" if avatar_hash.startswith("a_") else "png"
|
||||||
|
return f"{self.CDN_URL}/avatars/{user_id}/{avatar_hash}.{extension}"
|
||||||
|
|
||||||
|
async def get_user_info(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
) -> OAuthUserInfo:
|
||||||
|
"""Complete OAuth flow: exchange code and fetch user info.
|
||||||
|
|
||||||
|
This is the main method to call after receiving the OAuth callback.
|
||||||
|
It exchanges the authorization code for tokens, then fetches user info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from Discord callback.
|
||||||
|
redirect_uri: Same redirect_uri used in authorization request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized OAuthUserInfo ready for user creation/lookup.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DiscordOAuthError: If any step of the OAuth flow fails.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# In your callback handler:
|
||||||
|
user_info = await discord_oauth.get_user_info(code, redirect_uri)
|
||||||
|
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||||
|
"""
|
||||||
|
# Exchange code for tokens
|
||||||
|
tokens = await self.exchange_code_for_tokens(code, redirect_uri)
|
||||||
|
access_token = tokens.get("access_token")
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
raise DiscordOAuthError("No access token in response")
|
||||||
|
|
||||||
|
# Fetch user info
|
||||||
|
user_data = await self.fetch_user_info(access_token)
|
||||||
|
|
||||||
|
# Discord requires email scope, but email can still be None if not verified
|
||||||
|
email = user_data.get("email")
|
||||||
|
if not email:
|
||||||
|
raise DiscordOAuthError("Discord account does not have a verified email")
|
||||||
|
|
||||||
|
# Build display name: prefer global_name, then username
|
||||||
|
display_name = user_data.get("global_name") or user_data["username"]
|
||||||
|
|
||||||
|
# Build avatar URL
|
||||||
|
avatar_url = self._build_avatar_url(
|
||||||
|
user_data["id"],
|
||||||
|
user_data.get("avatar"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize to our schema
|
||||||
|
return OAuthUserInfo(
|
||||||
|
provider="discord",
|
||||||
|
oauth_id=user_data["id"],
|
||||||
|
email=email,
|
||||||
|
name=display_name,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""Check if Discord OAuth is properly configured.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if client ID and secret are set.
|
||||||
|
"""
|
||||||
|
return bool(settings.discord_client_id and settings.discord_client_secret)
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance
|
||||||
|
discord_oauth = DiscordOAuth()
|
||||||
207
backend/app/services/oauth/google.py
Normal file
207
backend/app/services/oauth/google.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
"""Google OAuth service for Mantimon TCG.
|
||||||
|
|
||||||
|
This module handles Google OAuth 2.0 authentication flow:
|
||||||
|
1. Generate authorization URL for user redirect
|
||||||
|
2. Exchange authorization code for tokens
|
||||||
|
3. Fetch user information from Google
|
||||||
|
|
||||||
|
Google OAuth Endpoints:
|
||||||
|
- Authorization: https://accounts.google.com/o/oauth2/v2/auth
|
||||||
|
- Token: https://oauth2.googleapis.com/token
|
||||||
|
- User Info: https://www.googleapis.com/oauth2/v2/userinfo
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from app.services.oauth.google import google_oauth
|
||||||
|
|
||||||
|
# Step 1: Redirect user to Google
|
||||||
|
auth_url = google_oauth.get_authorization_url(
|
||||||
|
redirect_uri="https://play.mantimon.com/api/auth/google/callback",
|
||||||
|
state="random-csrf-token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: Handle callback and get user info
|
||||||
|
user_info = await google_oauth.get_user_info(code, redirect_uri)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.schemas.user import OAuthUserInfo
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleOAuthError(Exception):
|
||||||
|
"""Exception raised for Google OAuth errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleOAuth:
|
||||||
|
"""Google OAuth 2.0 service.
|
||||||
|
|
||||||
|
Handles the OAuth flow for authenticating users with Google accounts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
AUTHORIZATION_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||||
|
USER_INFO_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||||
|
|
||||||
|
# Scopes we request from Google
|
||||||
|
SCOPES = ["openid", "email", "profile"]
|
||||||
|
|
||||||
|
def get_authorization_url(self, redirect_uri: str, state: str) -> str:
|
||||||
|
"""Generate the Google OAuth authorization URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_uri: Where Google should redirect after authorization.
|
||||||
|
state: Random string for CSRF protection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Full authorization URL to redirect user to.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GoogleOAuthError: If Google OAuth is not configured.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
url = google_oauth.get_authorization_url(
|
||||||
|
redirect_uri="https://play.mantimon.com/api/auth/google/callback",
|
||||||
|
state="abc123"
|
||||||
|
)
|
||||||
|
# Redirect user to url
|
||||||
|
"""
|
||||||
|
if not settings.google_client_id:
|
||||||
|
raise GoogleOAuthError("Google OAuth is not configured")
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"client_id": settings.google_client_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": " ".join(self.SCOPES),
|
||||||
|
"state": state,
|
||||||
|
"access_type": "offline", # Get refresh token
|
||||||
|
"prompt": "select_account", # Always show account picker
|
||||||
|
}
|
||||||
|
|
||||||
|
return f"{self.AUTHORIZATION_URL}?{urlencode(params)}"
|
||||||
|
|
||||||
|
async def exchange_code_for_tokens(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
) -> dict:
|
||||||
|
"""Exchange authorization code for access tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from Google callback.
|
||||||
|
redirect_uri: Same redirect_uri used in authorization request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token response containing access_token, id_token, etc.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GoogleOAuthError: If token exchange fails.
|
||||||
|
"""
|
||||||
|
if not settings.google_client_id or not settings.google_client_secret:
|
||||||
|
raise GoogleOAuthError("Google OAuth is not configured")
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"client_id": settings.google_client_id,
|
||||||
|
"client_secret": settings.google_client_secret.get_secret_value(),
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.TOKEN_URL,
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_data = response.json() if response.content else {}
|
||||||
|
error_msg = error_data.get("error_description", response.text)
|
||||||
|
raise GoogleOAuthError(f"Token exchange failed: {error_msg}")
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def fetch_user_info(self, access_token: str) -> dict:
|
||||||
|
"""Fetch user information from Google.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: Valid Google access token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User info dict with id, email, name, picture, etc.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GoogleOAuthError: If fetching user info fails.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
self.USER_INFO_URL,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise GoogleOAuthError(f"Failed to fetch user info: {response.text}")
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def get_user_info(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
) -> OAuthUserInfo:
|
||||||
|
"""Complete OAuth flow: exchange code and fetch user info.
|
||||||
|
|
||||||
|
This is the main method to call after receiving the OAuth callback.
|
||||||
|
It exchanges the authorization code for tokens, then fetches user info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Authorization code from Google callback.
|
||||||
|
redirect_uri: Same redirect_uri used in authorization request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized OAuthUserInfo ready for user creation/lookup.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GoogleOAuthError: If any step of the OAuth flow fails.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# In your callback handler:
|
||||||
|
user_info = await google_oauth.get_user_info(code, redirect_uri)
|
||||||
|
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||||
|
"""
|
||||||
|
# Exchange code for tokens
|
||||||
|
tokens = await self.exchange_code_for_tokens(code, redirect_uri)
|
||||||
|
access_token = tokens.get("access_token")
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
raise GoogleOAuthError("No access token in response")
|
||||||
|
|
||||||
|
# Fetch user info
|
||||||
|
user_data = await self.fetch_user_info(access_token)
|
||||||
|
|
||||||
|
# Normalize to our schema
|
||||||
|
return OAuthUserInfo(
|
||||||
|
provider="google",
|
||||||
|
oauth_id=user_data["id"],
|
||||||
|
email=user_data["email"],
|
||||||
|
name=user_data.get("name", user_data["email"].split("@")[0]),
|
||||||
|
avatar_url=user_data.get("picture"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""Check if Google OAuth is properly configured.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if client ID and secret are set.
|
||||||
|
"""
|
||||||
|
return bool(settings.google_client_id and settings.google_client_secret)
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance
|
||||||
|
google_oauth = GoogleOAuth()
|
||||||
195
backend/app/services/token_store.py
Normal file
195
backend/app/services/token_store.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
"""Refresh token storage for Mantimon TCG.
|
||||||
|
|
||||||
|
This module provides Redis-based storage for refresh token tracking
|
||||||
|
and revocation. Each refresh token's JTI is stored in Redis with
|
||||||
|
a TTL matching the token's expiration.
|
||||||
|
|
||||||
|
Key Pattern:
|
||||||
|
refresh_token:{user_id}:{jti} -> "1" (exists = valid)
|
||||||
|
|
||||||
|
Revocation:
|
||||||
|
- Single token: Delete the specific key
|
||||||
|
- All user tokens: Delete all keys matching refresh_token:{user_id}:*
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from app.services.token_store import token_store
|
||||||
|
|
||||||
|
# Store a new refresh token
|
||||||
|
await token_store.store_refresh_token(user_id, jti, expires_at)
|
||||||
|
|
||||||
|
# Check if token is valid (not revoked)
|
||||||
|
if await token_store.is_token_valid(user_id, jti):
|
||||||
|
# Issue new access token
|
||||||
|
|
||||||
|
# Revoke on logout
|
||||||
|
await token_store.revoke_token(user_id, jti)
|
||||||
|
|
||||||
|
# Logout from all devices
|
||||||
|
await token_store.revoke_all_user_tokens(user_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from app.db.redis import get_redis
|
||||||
|
|
||||||
|
|
||||||
|
class TokenStore:
|
||||||
|
"""Redis-based refresh token storage for revocation support.
|
||||||
|
|
||||||
|
Tracks valid refresh tokens by storing their JTIs in Redis.
|
||||||
|
Tokens can be revoked individually or all at once per user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
KEY_PREFIX = "refresh_token"
|
||||||
|
|
||||||
|
def _make_key(self, user_id: UUID, jti: str) -> str:
|
||||||
|
"""Create Redis key for a refresh token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
jti: The token's unique identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redis key string.
|
||||||
|
"""
|
||||||
|
return f"{self.KEY_PREFIX}:{user_id}:{jti}"
|
||||||
|
|
||||||
|
def _make_user_pattern(self, user_id: UUID) -> str:
|
||||||
|
"""Create Redis key pattern for all user's tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redis key pattern for SCAN/KEYS.
|
||||||
|
"""
|
||||||
|
return f"{self.KEY_PREFIX}:{user_id}:*"
|
||||||
|
|
||||||
|
async def store_refresh_token(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
jti: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
) -> None:
|
||||||
|
"""Store a refresh token's JTI in Redis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
jti: The token's unique identifier (from JWT).
|
||||||
|
expires_at: When the token expires (for TTL calculation).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(days=7)
|
||||||
|
await token_store.store_refresh_token(user_id, jti, expires_at)
|
||||||
|
"""
|
||||||
|
key = self._make_key(user_id, jti)
|
||||||
|
|
||||||
|
# Calculate TTL in seconds
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
ttl_seconds = int((expires_at - now).total_seconds())
|
||||||
|
|
||||||
|
if ttl_seconds <= 0:
|
||||||
|
# Token already expired, don't store
|
||||||
|
return
|
||||||
|
|
||||||
|
async with get_redis() as redis:
|
||||||
|
await redis.setex(key, ttl_seconds, "1")
|
||||||
|
|
||||||
|
async def is_token_valid(self, user_id: UUID, jti: str) -> bool:
|
||||||
|
"""Check if a refresh token is valid (not revoked).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
jti: The token's unique identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if token exists in store (valid), False if revoked or expired.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
if await token_store.is_token_valid(user_id, jti):
|
||||||
|
# Token is valid, issue new access token
|
||||||
|
else:
|
||||||
|
# Token was revoked, require re-authentication
|
||||||
|
"""
|
||||||
|
key = self._make_key(user_id, jti)
|
||||||
|
|
||||||
|
async with get_redis() as redis:
|
||||||
|
result = await redis.exists(key)
|
||||||
|
return result > 0
|
||||||
|
|
||||||
|
async def revoke_token(self, user_id: UUID, jti: str) -> bool:
|
||||||
|
"""Revoke a specific refresh token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
jti: The token's unique identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if token was revoked, False if it didn't exist.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# On logout
|
||||||
|
await token_store.revoke_token(user_id, jti)
|
||||||
|
"""
|
||||||
|
key = self._make_key(user_id, jti)
|
||||||
|
|
||||||
|
async with get_redis() as redis:
|
||||||
|
result = await redis.delete(key)
|
||||||
|
return result > 0
|
||||||
|
|
||||||
|
async def revoke_all_user_tokens(self, user_id: UUID) -> int:
|
||||||
|
"""Revoke all refresh tokens for a user.
|
||||||
|
|
||||||
|
Useful for "logout from all devices" or security incidents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens revoked.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Logout from all devices
|
||||||
|
count = await token_store.revoke_all_user_tokens(user_id)
|
||||||
|
print(f"Revoked {count} sessions")
|
||||||
|
"""
|
||||||
|
pattern = self._make_user_pattern(user_id)
|
||||||
|
|
||||||
|
async with get_redis() as redis:
|
||||||
|
# Use SCAN to find all matching keys (safer than KEYS for large datasets)
|
||||||
|
keys_to_delete = []
|
||||||
|
async for key in redis.scan_iter(match=pattern):
|
||||||
|
keys_to_delete.append(key)
|
||||||
|
|
||||||
|
if not keys_to_delete:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Delete all found keys
|
||||||
|
result = await redis.delete(*keys_to_delete)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_active_session_count(self, user_id: UUID) -> int:
|
||||||
|
"""Get the number of active sessions (valid refresh tokens) for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of active sessions.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
count = await token_store.get_active_session_count(user_id)
|
||||||
|
print(f"User has {count} active sessions")
|
||||||
|
"""
|
||||||
|
pattern = self._make_user_pattern(user_id)
|
||||||
|
|
||||||
|
async with get_redis() as redis:
|
||||||
|
count = 0
|
||||||
|
async for _ in redis.scan_iter(match=pattern):
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
# Global token store instance
|
||||||
|
token_store = TokenStore()
|
||||||
439
backend/app/services/user_service.py
Normal file
439
backend/app/services/user_service.py
Normal file
@ -0,0 +1,439 @@
|
|||||||
|
"""User service for Mantimon TCG.
|
||||||
|
|
||||||
|
This module provides async CRUD operations for user accounts,
|
||||||
|
including OAuth-based user creation and premium status management.
|
||||||
|
|
||||||
|
All database operations use async SQLAlchemy sessions.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from app.services.user_service import user_service
|
||||||
|
|
||||||
|
# Get user by ID
|
||||||
|
user = await user_service.get_by_id(db, user_id)
|
||||||
|
|
||||||
|
# Create from OAuth
|
||||||
|
user = await user_service.create_from_oauth(db, oauth_info)
|
||||||
|
|
||||||
|
# Update premium status
|
||||||
|
user = await user_service.update_premium(db, user_id, premium_until)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||||
|
from app.db.models.user import User
|
||||||
|
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
||||||
|
|
||||||
|
|
||||||
|
class AccountLinkingError(Exception):
|
||||||
|
"""Error during account linking operation."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UserService:
|
||||||
|
"""Service for user account operations.
|
||||||
|
|
||||||
|
Provides async methods for user CRUD, OAuth-based creation,
|
||||||
|
and premium subscription management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_by_id(self, db: AsyncSession, user_id: UUID) -> User | None:
|
||||||
|
"""Get a user by their ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User if found, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user = await user_service.get_by_id(db, user_id)
|
||||||
|
if user:
|
||||||
|
print(f"Found user: {user.display_name}")
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
||||||
|
"""Get a user by their email address.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
email: The user's email address.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User if found, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user = await user_service.get_by_email(db, "player@example.com")
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_oauth(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
) -> User | None:
|
||||||
|
"""Get a user by their OAuth provider and ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
provider: OAuth provider name (google, discord).
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User if found, None otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user = await user_service.get_by_oauth(db, "google", "123456789")
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(User).where(
|
||||||
|
User.oauth_provider == provider,
|
||||||
|
User.oauth_id == oauth_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, user_data: UserCreate) -> User:
|
||||||
|
"""Create a new user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
user_data: User creation data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created User instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user_data = UserCreate(
|
||||||
|
email="player@example.com",
|
||||||
|
display_name="Player1",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="123456789"
|
||||||
|
)
|
||||||
|
user = await user_service.create(db, user_data)
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email=user_data.email,
|
||||||
|
display_name=user_data.display_name,
|
||||||
|
avatar_url=user_data.avatar_url,
|
||||||
|
oauth_provider=user_data.oauth_provider,
|
||||||
|
oauth_id=user_data.oauth_id,
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def create_from_oauth(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
oauth_info: OAuthUserInfo,
|
||||||
|
) -> User:
|
||||||
|
"""Create a new user from OAuth provider info.
|
||||||
|
|
||||||
|
Convenience method that converts OAuthUserInfo to UserCreate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
oauth_info: Normalized OAuth user information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created User instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
oauth_info = OAuthUserInfo(
|
||||||
|
provider="google",
|
||||||
|
oauth_id="123456789",
|
||||||
|
email="player@example.com",
|
||||||
|
name="Player One",
|
||||||
|
avatar_url="https://..."
|
||||||
|
)
|
||||||
|
user = await user_service.create_from_oauth(db, oauth_info)
|
||||||
|
"""
|
||||||
|
user_data = oauth_info.to_user_create()
|
||||||
|
return await self.create(db, user_data)
|
||||||
|
|
||||||
|
async def get_or_create_from_oauth(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
oauth_info: OAuthUserInfo,
|
||||||
|
) -> tuple[User, bool]:
|
||||||
|
"""Get existing user or create new one from OAuth info.
|
||||||
|
|
||||||
|
First checks for existing user by OAuth provider+ID, then by email
|
||||||
|
(for account linking), and finally creates a new user if not found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
oauth_info: Normalized OAuth user information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (User, created) where created is True if new user.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user, created = await user_service.get_or_create_from_oauth(db, oauth_info)
|
||||||
|
if created:
|
||||||
|
print("Welcome, new user!")
|
||||||
|
else:
|
||||||
|
print("Welcome back!")
|
||||||
|
"""
|
||||||
|
# First, check by OAuth provider + ID (exact match)
|
||||||
|
user = await self.get_by_oauth(db, oauth_info.provider, oauth_info.oauth_id)
|
||||||
|
if user:
|
||||||
|
return user, False
|
||||||
|
|
||||||
|
# Check by email for potential account linking
|
||||||
|
# If user exists with same email but different OAuth, update their OAuth
|
||||||
|
user = await self.get_by_email(db, oauth_info.email)
|
||||||
|
if user:
|
||||||
|
# Update OAuth credentials for existing user
|
||||||
|
# This links the new OAuth provider to the existing account
|
||||||
|
user.oauth_provider = oauth_info.provider
|
||||||
|
user.oauth_id = oauth_info.oauth_id
|
||||||
|
# Optionally update avatar if not set
|
||||||
|
if not user.avatar_url and oauth_info.avatar_url:
|
||||||
|
user.avatar_url = oauth_info.avatar_url
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user, False
|
||||||
|
|
||||||
|
# Create new user
|
||||||
|
user = await self.create_from_oauth(db, oauth_info)
|
||||||
|
return user, True
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user: User,
|
||||||
|
update_data: UserUpdate,
|
||||||
|
) -> User:
|
||||||
|
"""Update user profile fields.
|
||||||
|
|
||||||
|
Only updates fields that are provided (not None).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
user: The user to update.
|
||||||
|
update_data: Fields to update.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated User instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
update_data = UserUpdate(display_name="New Name")
|
||||||
|
user = await user_service.update(db, user, update_data)
|
||||||
|
"""
|
||||||
|
if update_data.display_name is not None:
|
||||||
|
user.display_name = update_data.display_name
|
||||||
|
if update_data.avatar_url is not None:
|
||||||
|
user.avatar_url = update_data.avatar_url
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def update_last_login(self, db: AsyncSession, user: User) -> User:
|
||||||
|
"""Update the user's last login timestamp.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
user: The user to update.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated User instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
user = await user_service.update_last_login(db, user)
|
||||||
|
"""
|
||||||
|
user.last_login = datetime.now(UTC)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def update_premium(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user: User,
|
||||||
|
premium_until: datetime | None,
|
||||||
|
) -> User:
|
||||||
|
"""Update user's premium subscription status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
user: The user to update.
|
||||||
|
premium_until: When premium expires, or None to remove premium.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated User instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Grant 30 days of premium
|
||||||
|
expires = datetime.now(UTC) + timedelta(days=30)
|
||||||
|
user = await user_service.update_premium(db, user, expires)
|
||||||
|
|
||||||
|
# Remove premium
|
||||||
|
user = await user_service.update_premium(db, user, None)
|
||||||
|
"""
|
||||||
|
if premium_until is not None:
|
||||||
|
user.is_premium = True
|
||||||
|
user.premium_until = premium_until
|
||||||
|
else:
|
||||||
|
user.is_premium = False
|
||||||
|
user.premium_until = None
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def delete(self, db: AsyncSession, user: User) -> None:
|
||||||
|
"""Delete a user account.
|
||||||
|
|
||||||
|
This will cascade delete all related data (decks, collection, etc.)
|
||||||
|
based on the model relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
user: The user to delete.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
await user_service.delete(db, user)
|
||||||
|
"""
|
||||||
|
await db.delete(user)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def get_linked_account(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
) -> OAuthLinkedAccount | None:
|
||||||
|
"""Get a linked account by provider and OAuth ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
provider: OAuth provider name (google, discord).
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuthLinkedAccount if found, None otherwise.
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthLinkedAccount).where(
|
||||||
|
OAuthLinkedAccount.provider == provider,
|
||||||
|
OAuthLinkedAccount.oauth_id == oauth_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def link_oauth_account(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user: User,
|
||||||
|
oauth_info: OAuthUserInfo,
|
||||||
|
) -> OAuthLinkedAccount:
|
||||||
|
"""Link an additional OAuth provider to a user account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
user: The user to link the account to.
|
||||||
|
oauth_info: OAuth information from the provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created OAuthLinkedAccount.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AccountLinkingError: If provider is already linked to this or another user.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
linked = await user_service.link_oauth_account(db, user, discord_info)
|
||||||
|
"""
|
||||||
|
# Check if this provider+oauth_id is already linked to any user
|
||||||
|
existing = await self.get_linked_account(db, oauth_info.provider, oauth_info.oauth_id)
|
||||||
|
if existing:
|
||||||
|
if str(existing.user_id) == str(user.id):
|
||||||
|
raise AccountLinkingError(
|
||||||
|
f"{oauth_info.provider.title()} account is already linked to your account"
|
||||||
|
)
|
||||||
|
raise AccountLinkingError(
|
||||||
|
f"This {oauth_info.provider.title()} account is already linked to another user"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if this is the user's primary OAuth provider
|
||||||
|
if user.oauth_provider == oauth_info.provider:
|
||||||
|
raise AccountLinkingError(
|
||||||
|
f"{oauth_info.provider.title()} is your primary login provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user already has this provider linked
|
||||||
|
for linked in user.linked_accounts:
|
||||||
|
if linked.provider == oauth_info.provider:
|
||||||
|
raise AccountLinkingError(
|
||||||
|
f"You already have a {oauth_info.provider.title()} account linked"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the linked account
|
||||||
|
linked_account = OAuthLinkedAccount(
|
||||||
|
user_id=str(user.id),
|
||||||
|
provider=oauth_info.provider,
|
||||||
|
oauth_id=oauth_info.oauth_id,
|
||||||
|
email=oauth_info.email,
|
||||||
|
display_name=oauth_info.name,
|
||||||
|
avatar_url=oauth_info.avatar_url,
|
||||||
|
)
|
||||||
|
db.add(linked_account)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(linked_account)
|
||||||
|
return linked_account
|
||||||
|
|
||||||
|
async def unlink_oauth_account(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user: User,
|
||||||
|
provider: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Unlink an OAuth provider from a user account.
|
||||||
|
|
||||||
|
Cannot unlink the primary OAuth provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Async database session.
|
||||||
|
user: The user to unlink from.
|
||||||
|
provider: OAuth provider name to unlink.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if unlinked, False if provider wasn't linked.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AccountLinkingError: If trying to unlink the primary provider.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
success = await user_service.unlink_oauth_account(db, user, "discord")
|
||||||
|
"""
|
||||||
|
# Cannot unlink primary provider
|
||||||
|
if user.oauth_provider == provider:
|
||||||
|
raise AccountLinkingError(
|
||||||
|
f"Cannot unlink {provider.title()} - it is your primary login provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find and delete the linked account
|
||||||
|
for linked in user.linked_accounts:
|
||||||
|
if linked.provider == provider:
|
||||||
|
await db.delete(linked)
|
||||||
|
await db.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Global service instance
|
||||||
|
user_service = UserService()
|
||||||
@ -43,6 +43,18 @@ services:
|
|||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
adminer:
|
||||||
|
image: adminer:latest
|
||||||
|
restart: unless-stopped
|
||||||
|
container_name: mantimon-adminer
|
||||||
|
ports:
|
||||||
|
- "8090:8080"
|
||||||
|
environment:
|
||||||
|
- ADMINER_DEFAULT_SERVER=mantimon-postgres
|
||||||
|
- TZ=America/Chicago
|
||||||
|
depends_on:
|
||||||
|
- postgres
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
|||||||
478
backend/project_plans/PHASE_2_AUTH.json
Normal file
478
backend/project_plans/PHASE_2_AUTH.json
Normal file
@ -0,0 +1,478 @@
|
|||||||
|
{
|
||||||
|
"meta": {
|
||||||
|
"version": "1.0.0",
|
||||||
|
"created": "2026-01-27",
|
||||||
|
"lastUpdated": "2026-01-27",
|
||||||
|
"planType": "phase",
|
||||||
|
"phaseId": "PHASE_2",
|
||||||
|
"phaseName": "Authentication",
|
||||||
|
"description": "OAuth login (Google, Discord), JWT session management, user management, premium tier tracking",
|
||||||
|
"totalEstimatedHours": 24,
|
||||||
|
"totalTasks": 15,
|
||||||
|
"completedTasks": 15,
|
||||||
|
"status": "complete",
|
||||||
|
"masterPlan": "../PROJECT_PLAN_MASTER.json"
|
||||||
|
},
|
||||||
|
|
||||||
|
"goals": [
|
||||||
|
"Implement OAuth 2.0 authentication with Google and Discord providers",
|
||||||
|
"Create JWT-based session management with access/refresh token pattern",
|
||||||
|
"Build user management service with create, read, update operations",
|
||||||
|
"Implement FastAPI dependencies for protected endpoints",
|
||||||
|
"Support account linking (multiple OAuth providers per user)",
|
||||||
|
"Track premium subscription status with expiration dates"
|
||||||
|
],
|
||||||
|
|
||||||
|
"architectureNotes": {
|
||||||
|
"tokenStrategy": {
|
||||||
|
"accessToken": "Short-lived JWT (30 min default), contains user_id",
|
||||||
|
"refreshToken": "Longer-lived JWT (7 days default), stored in Redis for revocation",
|
||||||
|
"storage": "Refresh tokens tracked in Redis for logout/revocation support"
|
||||||
|
},
|
||||||
|
"oauthFlow": {
|
||||||
|
"pattern": "Authorization Code Flow (PKCE deferred)",
|
||||||
|
"callback": "Backend receives code, exchanges for tokens, creates/updates user",
|
||||||
|
"security": "Never store OAuth provider tokens, only OAuth ID"
|
||||||
|
},
|
||||||
|
"accountLinking": {
|
||||||
|
"strategy": "Email-based matching + explicit linking via OAuth flow",
|
||||||
|
"flow": "If user exists with same email, add OAuth provider to existing account. Users can also explicitly link additional providers via /auth/link/{provider}"
|
||||||
|
},
|
||||||
|
"existingInfrastructure": {
|
||||||
|
"config": "JWT, OAuth, and base_url settings in app/config.py Settings class",
|
||||||
|
"dependencies": "python-jose, passlib, bcrypt, httpx already installed",
|
||||||
|
"userModel": "User model with OAuth fields in app/db/models/user.py"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
"directoryStructure": {
|
||||||
|
"schemas": "backend/app/schemas/",
|
||||||
|
"services": "backend/app/services/",
|
||||||
|
"api": "backend/app/api/",
|
||||||
|
"tests": "backend/tests/api/, backend/tests/services/"
|
||||||
|
},
|
||||||
|
|
||||||
|
"tasks": [
|
||||||
|
{
|
||||||
|
"id": "AUTH-001",
|
||||||
|
"name": "Create Pydantic schemas for auth",
|
||||||
|
"description": "Define request/response models for authentication flows",
|
||||||
|
"category": "critical",
|
||||||
|
"priority": 1,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": [],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/schemas/__init__.py", "status": "complete"},
|
||||||
|
{"path": "app/schemas/auth.py", "status": "complete"},
|
||||||
|
{"path": "app/schemas/user.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"TokenPayload: sub (user_id), exp, iat, type (access/refresh)",
|
||||||
|
"TokenResponse: access_token, refresh_token, token_type, expires_in",
|
||||||
|
"UserResponse: id, email, display_name, avatar_url, is_premium, premium_until",
|
||||||
|
"UserCreate: internal model for user creation from OAuth",
|
||||||
|
"OAuthUserInfo: normalized structure for OAuth provider data"
|
||||||
|
],
|
||||||
|
"estimatedHours": 1.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-002",
|
||||||
|
"name": "Create JWT utilities service",
|
||||||
|
"description": "Functions for creating and verifying JWT tokens",
|
||||||
|
"category": "critical",
|
||||||
|
"priority": 2,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-001"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/services/jwt_service.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"create_access_token(user_id: UUID) -> str - Uses settings.jwt_expire_minutes",
|
||||||
|
"create_refresh_token(user_id: UUID) -> tuple[str, str] - Returns token and jti",
|
||||||
|
"verify_access_token(token: str) -> UUID | None - Returns user_id or None",
|
||||||
|
"verify_refresh_token(token: str) -> tuple[UUID, str] | None - Returns (user_id, jti) or None",
|
||||||
|
"Uses python-jose with HS256 algorithm",
|
||||||
|
"All timing uses datetime.now(UTC) per project standards"
|
||||||
|
],
|
||||||
|
"estimatedHours": 1.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-003",
|
||||||
|
"name": "Create refresh token Redis storage",
|
||||||
|
"description": "Redis-based storage for refresh token tracking and revocation",
|
||||||
|
"category": "high",
|
||||||
|
"priority": 3,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-002"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/services/token_store.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"Key format: refresh_token:{user_id}:{jti} -> '1' with TTL",
|
||||||
|
"store_refresh_token(user_id, jti, expires_at) - Store with TTL",
|
||||||
|
"is_token_valid(user_id, jti) -> bool - Check if not revoked",
|
||||||
|
"revoke_token(user_id, jti) - Delete specific token",
|
||||||
|
"revoke_all_user_tokens(user_id) - Logout from all devices",
|
||||||
|
"get_active_session_count(user_id) - Count valid tokens",
|
||||||
|
"Uses existing Redis connection from app/db/redis.py"
|
||||||
|
],
|
||||||
|
"estimatedHours": 1.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-004",
|
||||||
|
"name": "Create UserService",
|
||||||
|
"description": "Service layer for user CRUD operations",
|
||||||
|
"category": "critical",
|
||||||
|
"priority": 4,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-001"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/services/user_service.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"get_by_id(db, user_id: UUID) -> User | None",
|
||||||
|
"get_by_email(db, email: str) -> User | None",
|
||||||
|
"get_by_oauth(db, provider: str, oauth_id: str) -> User | None",
|
||||||
|
"create(db, user_data: UserCreate) -> User",
|
||||||
|
"create_from_oauth(db, oauth_info: OAuthUserInfo) -> User",
|
||||||
|
"get_or_create_from_oauth(db, oauth_info: OAuthUserInfo) -> tuple[User, bool]",
|
||||||
|
"update(db, user: User, update_data: UserUpdate) -> User",
|
||||||
|
"update_last_login(db, user: User) -> User",
|
||||||
|
"update_premium(db, user: User, premium_until: datetime | None) -> User",
|
||||||
|
"link_oauth_account(db, user: User, oauth_info: OAuthUserInfo) -> OAuthLinkedAccount",
|
||||||
|
"unlink_oauth_account(db, user: User, provider: str) -> bool",
|
||||||
|
"All operations are async using SQLAlchemy async session"
|
||||||
|
],
|
||||||
|
"estimatedHours": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-005",
|
||||||
|
"name": "Create OAuthLinkedAccount model",
|
||||||
|
"description": "Database model for multiple OAuth providers per user (account linking)",
|
||||||
|
"category": "high",
|
||||||
|
"priority": 5,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": [],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/db/models/oauth_account.py", "status": "complete"},
|
||||||
|
{"path": "app/db/migrations/versions/5ce887128ab1_add_oauth_linked_accounts.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"Fields: id, user_id (FK), provider, oauth_id, email, display_name, avatar_url, linked_at",
|
||||||
|
"Unique constraint on (provider, oauth_id)",
|
||||||
|
"User.oauth_provider/oauth_id kept as 'primary' provider",
|
||||||
|
"Relationship: User.linked_accounts -> list[OAuthLinkedAccount]"
|
||||||
|
],
|
||||||
|
"estimatedHours": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-006",
|
||||||
|
"name": "Create Google OAuth service",
|
||||||
|
"description": "Handle Google OAuth authorization code flow",
|
||||||
|
"category": "critical",
|
||||||
|
"priority": 6,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-004"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/services/oauth/google.py", "status": "complete"},
|
||||||
|
{"path": "app/services/oauth/__init__.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"get_authorization_url(redirect_uri, state) -> str",
|
||||||
|
"get_user_info(code, redirect_uri) -> OAuthUserInfo",
|
||||||
|
"is_configured() -> bool",
|
||||||
|
"Uses httpx for async HTTP requests",
|
||||||
|
"Google OAuth endpoints: accounts.google.com/o/oauth2/v2/auth, oauth2.googleapis.com/token",
|
||||||
|
"User info endpoint: www.googleapis.com/oauth2/v2/userinfo"
|
||||||
|
],
|
||||||
|
"estimatedHours": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-007",
|
||||||
|
"name": "Create Discord OAuth service",
|
||||||
|
"description": "Handle Discord OAuth authorization code flow",
|
||||||
|
"category": "critical",
|
||||||
|
"priority": 7,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-004"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/services/oauth/discord.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"get_authorization_url(redirect_uri, state) -> str",
|
||||||
|
"get_user_info(code, redirect_uri) -> OAuthUserInfo",
|
||||||
|
"is_configured() -> bool",
|
||||||
|
"Uses httpx for async HTTP requests",
|
||||||
|
"Discord OAuth endpoints: discord.com/oauth2/authorize, discord.com/api/oauth2/token",
|
||||||
|
"User info endpoint: discord.com/api/users/@me",
|
||||||
|
"Avatar URL construction from user ID and avatar hash"
|
||||||
|
],
|
||||||
|
"estimatedHours": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-008",
|
||||||
|
"name": "Create FastAPI auth dependencies",
|
||||||
|
"description": "Dependency injection for protected endpoints",
|
||||||
|
"category": "critical",
|
||||||
|
"priority": 8,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-002", "AUTH-003", "AUTH-004"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/api/__init__.py", "status": "complete"},
|
||||||
|
{"path": "app/api/deps.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"OAuth2PasswordBearer scheme for token extraction",
|
||||||
|
"get_current_user(token, db) -> User - Validates token, fetches user",
|
||||||
|
"CurrentUser type alias with Annotated for dependency injection",
|
||||||
|
"DbSession type alias for database dependency",
|
||||||
|
"Proper error responses: 401 Unauthorized"
|
||||||
|
],
|
||||||
|
"estimatedHours": 1.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-009",
|
||||||
|
"name": "Create auth API router",
|
||||||
|
"description": "REST endpoints for OAuth login, token refresh, logout",
|
||||||
|
"category": "critical",
|
||||||
|
"priority": 9,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-006", "AUTH-007", "AUTH-008"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/api/auth.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"GET /auth/google - Redirects to Google OAuth consent screen",
|
||||||
|
"GET /auth/google/callback - Handles OAuth callback, returns tokens",
|
||||||
|
"GET /auth/discord - Redirects to Discord OAuth consent screen",
|
||||||
|
"GET /auth/discord/callback - Handles OAuth callback, returns tokens",
|
||||||
|
"GET /auth/link/google - Start account linking for Google (requires auth)",
|
||||||
|
"GET /auth/link/google/callback - Handle account linking callback",
|
||||||
|
"GET /auth/link/discord - Start account linking for Discord (requires auth)",
|
||||||
|
"GET /auth/link/discord/callback - Handle account linking callback",
|
||||||
|
"POST /auth/refresh - Exchange refresh token for new access token",
|
||||||
|
"POST /auth/logout - Revoke refresh token",
|
||||||
|
"POST /auth/logout-all - Revoke all refresh tokens (requires auth)",
|
||||||
|
"State parameter stored in Redis with short TTL for CSRF protection",
|
||||||
|
"Uses settings.base_url for absolute OAuth callback URLs"
|
||||||
|
],
|
||||||
|
"estimatedHours": 3
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-010",
|
||||||
|
"name": "Create user API router",
|
||||||
|
"description": "REST endpoints for user profile and account management",
|
||||||
|
"category": "high",
|
||||||
|
"priority": 10,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-008", "AUTH-004"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/api/users.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"GET /users/me - Get current user profile",
|
||||||
|
"PATCH /users/me - Update display_name, avatar_url",
|
||||||
|
"GET /users/me/linked-accounts - List linked OAuth providers",
|
||||||
|
"DELETE /users/me/link/{provider} - Unlink OAuth provider",
|
||||||
|
"GET /users/me/sessions - Get active session count",
|
||||||
|
"All endpoints require authentication via CurrentUser dependency"
|
||||||
|
],
|
||||||
|
"estimatedHours": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-011",
|
||||||
|
"name": "Integrate routers in main.py",
|
||||||
|
"description": "Mount auth and user routers",
|
||||||
|
"category": "high",
|
||||||
|
"priority": 11,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-009", "AUTH-010"],
|
||||||
|
"files": [
|
||||||
|
{"path": "app/main.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"Include auth router: prefix='/api'",
|
||||||
|
"Include users router: prefix='/api'"
|
||||||
|
],
|
||||||
|
"estimatedHours": 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-012",
|
||||||
|
"name": "Create JWT service tests",
|
||||||
|
"description": "Unit tests for token creation and verification",
|
||||||
|
"category": "high",
|
||||||
|
"priority": 12,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-002"],
|
||||||
|
"files": [
|
||||||
|
{"path": "tests/services/test_jwt_service.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"Test create_access_token returns valid JWT",
|
||||||
|
"Test create_refresh_token returns valid JWT with jti",
|
||||||
|
"Test verify_access_token extracts correct user_id",
|
||||||
|
"Test verify_refresh_token returns user_id and jti",
|
||||||
|
"Test expired tokens return None",
|
||||||
|
"Test invalid signatures return None",
|
||||||
|
"20 tests covering all token operations"
|
||||||
|
],
|
||||||
|
"estimatedHours": 1.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-013",
|
||||||
|
"name": "Create UserService tests",
|
||||||
|
"description": "Integration tests for user CRUD operations",
|
||||||
|
"category": "high",
|
||||||
|
"priority": 13,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-004"],
|
||||||
|
"files": [
|
||||||
|
{"path": "tests/services/test_user_service.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"Test get_by_id returns user or None",
|
||||||
|
"Test get_by_email returns user or None",
|
||||||
|
"Test get_by_oauth finds by provider+oauth_id",
|
||||||
|
"Test create creates user with correct fields",
|
||||||
|
"Test create_from_oauth creates from OAuthUserInfo",
|
||||||
|
"Test get_or_create_from_oauth handles all scenarios",
|
||||||
|
"Test update updates profile fields",
|
||||||
|
"Test update_last_login updates timestamp",
|
||||||
|
"Test update_premium manages subscription status",
|
||||||
|
"Test link_oauth_account links new providers",
|
||||||
|
"Test unlink_oauth_account removes linked providers",
|
||||||
|
"29 tests using real Postgres via testcontainers"
|
||||||
|
],
|
||||||
|
"estimatedHours": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-014",
|
||||||
|
"name": "Create OAuth service tests",
|
||||||
|
"description": "Unit tests for OAuth flows with mocked HTTP",
|
||||||
|
"category": "high",
|
||||||
|
"priority": 14,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-006", "AUTH-007"],
|
||||||
|
"files": [
|
||||||
|
{"path": "tests/services/oauth/test_google.py", "status": "complete"},
|
||||||
|
{"path": "tests/services/oauth/test_discord.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"Mock httpx responses for token exchange",
|
||||||
|
"Mock httpx responses for user info",
|
||||||
|
"Test authorization URL construction",
|
||||||
|
"Test error handling for invalid codes",
|
||||||
|
"Test OAuthUserInfo normalization",
|
||||||
|
"10 tests for Google, 14 tests for Discord",
|
||||||
|
"Uses respx for httpx mocking"
|
||||||
|
],
|
||||||
|
"estimatedHours": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "AUTH-015",
|
||||||
|
"name": "Create auth API endpoint tests",
|
||||||
|
"description": "Integration tests for auth endpoints",
|
||||||
|
"category": "high",
|
||||||
|
"priority": 15,
|
||||||
|
"completed": true,
|
||||||
|
"dependencies": ["AUTH-009", "AUTH-010"],
|
||||||
|
"files": [
|
||||||
|
{"path": "tests/api/__init__.py", "status": "complete"},
|
||||||
|
{"path": "tests/api/conftest.py", "status": "complete"},
|
||||||
|
{"path": "tests/api/test_auth.py", "status": "complete"},
|
||||||
|
{"path": "tests/api/test_users.py", "status": "complete"}
|
||||||
|
],
|
||||||
|
"details": [
|
||||||
|
"Test OAuth redirect returns correct URL",
|
||||||
|
"Test refresh endpoint returns new access token",
|
||||||
|
"Test logout revokes refresh token",
|
||||||
|
"Test /users/me returns current user",
|
||||||
|
"Test /users/me update works",
|
||||||
|
"Test /users/me/linked-accounts returns accounts",
|
||||||
|
"Test /users/me/sessions returns count",
|
||||||
|
"Test DELETE /users/me/link/{provider} unlinks account",
|
||||||
|
"10 tests for auth, 15 tests for users",
|
||||||
|
"Uses TestClient with dependency overrides and fakeredis"
|
||||||
|
],
|
||||||
|
"estimatedHours": 3
|
||||||
|
}
|
||||||
|
],
|
||||||
|
|
||||||
|
"testingStrategy": {
|
||||||
|
"approach": "Unit tests for services, integration tests for API endpoints",
|
||||||
|
"mocking": "httpx responses mocked with respx for OAuth providers, fakeredis for token store",
|
||||||
|
"database": "Real Postgres via testcontainers for service tests",
|
||||||
|
"coverage": "1072 total tests, 98 tests for auth system"
|
||||||
|
},
|
||||||
|
|
||||||
|
"acceptanceCriteria": [
|
||||||
|
{"criterion": "User can login with Google OAuth and receive JWT tokens", "met": true},
|
||||||
|
{"criterion": "User can login with Discord OAuth and receive JWT tokens", "met": true},
|
||||||
|
{"criterion": "Access tokens expire after configured time", "met": true},
|
||||||
|
{"criterion": "Refresh tokens can be used to get new access tokens", "met": true},
|
||||||
|
{"criterion": "Logout revokes refresh token (cannot be reused)", "met": true},
|
||||||
|
{"criterion": "Protected endpoints return 401 without valid token", "met": true},
|
||||||
|
{"criterion": "User can link multiple OAuth providers to one account", "met": true},
|
||||||
|
{"criterion": "Premium status is tracked with expiration date", "met": true},
|
||||||
|
{"criterion": "All tests pass with high coverage", "met": true}
|
||||||
|
],
|
||||||
|
|
||||||
|
"securityConsiderations": [
|
||||||
|
"OAuth state parameter validated to prevent CSRF attacks",
|
||||||
|
"OAuth provider tokens never stored (only OAuth ID)",
|
||||||
|
"JWT secret key loaded from environment, never hardcoded",
|
||||||
|
"Refresh tokens stored in Redis with TTL for revocation support",
|
||||||
|
"Access tokens short-lived (30 min) to limit exposure",
|
||||||
|
"OAuth callbacks use absolute URLs (base_url config setting)",
|
||||||
|
"HTTPS required in production for all auth endpoints"
|
||||||
|
],
|
||||||
|
|
||||||
|
"deferredItems": [
|
||||||
|
{
|
||||||
|
"item": "PKCE for OAuth",
|
||||||
|
"reason": "Not strictly required for server-side OAuth flow",
|
||||||
|
"priority": "low"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"item": "Rate limiting on auth endpoints",
|
||||||
|
"reason": "Can be added as infrastructure concern later",
|
||||||
|
"priority": "medium"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"item": "Refresh token rotation",
|
||||||
|
"reason": "Current implementation is secure; rotation adds complexity",
|
||||||
|
"priority": "low"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
|
||||||
|
"dependencies": {
|
||||||
|
"existing": [
|
||||||
|
"python-jose>=3.5.0 (already installed)",
|
||||||
|
"passlib>=1.7.4 (already installed)",
|
||||||
|
"bcrypt>=5.0.0 (already installed)"
|
||||||
|
],
|
||||||
|
"added": [
|
||||||
|
"email-validator (for Pydantic EmailStr)",
|
||||||
|
"fakeredis (dev - Redis mocking in tests)",
|
||||||
|
"respx (dev - httpx mocking in tests)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
|
||||||
|
"phase1Prerequisites": {
|
||||||
|
"met": [
|
||||||
|
"JWT configuration in Settings class",
|
||||||
|
"OAuth configuration in Settings class",
|
||||||
|
"User model with OAuth fields",
|
||||||
|
"python-jose dependency installed",
|
||||||
|
"Database session management",
|
||||||
|
"Redis connection utilities"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
|
||||||
|
"completionNotes": {
|
||||||
|
"totalNewTests": 98,
|
||||||
|
"totalTestsAfter": 1072,
|
||||||
|
"commit": "996c43f - Implement Phase 2: Authentication system",
|
||||||
|
"additionalCommitNeeded": "Fix OAuth absolute URLs and add account linking endpoints"
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -8,7 +8,9 @@ dependencies = [
|
|||||||
"alembic>=1.18.1",
|
"alembic>=1.18.1",
|
||||||
"asyncpg>=0.31.0",
|
"asyncpg>=0.31.0",
|
||||||
"bcrypt>=5.0.0",
|
"bcrypt>=5.0.0",
|
||||||
|
"email-validator>=2.3.0",
|
||||||
"fastapi>=0.128.0",
|
"fastapi>=0.128.0",
|
||||||
|
"httpx>=0.28.1",
|
||||||
"passlib>=1.7.4",
|
"passlib>=1.7.4",
|
||||||
"psycopg2-binary>=2.9.11",
|
"psycopg2-binary>=2.9.11",
|
||||||
"pydantic>=2.12.5",
|
"pydantic>=2.12.5",
|
||||||
@ -24,12 +26,14 @@ dependencies = [
|
|||||||
dev = [
|
dev = [
|
||||||
"beautifulsoup4>=4.12.0",
|
"beautifulsoup4>=4.12.0",
|
||||||
"black>=26.1.0",
|
"black>=26.1.0",
|
||||||
|
"fakeredis>=2.33.0",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"mypy>=1.19.1",
|
"mypy>=1.19.1",
|
||||||
"pytest>=9.0.2",
|
"pytest>=9.0.2",
|
||||||
"pytest-asyncio>=1.3.0",
|
"pytest-asyncio>=1.3.0",
|
||||||
"pytest-cov>=7.0.0",
|
"pytest-cov>=7.0.0",
|
||||||
"requests>=2.31.0",
|
"requests>=2.31.0",
|
||||||
|
"respx>=0.22.0",
|
||||||
"ruff>=0.14.14",
|
"ruff>=0.14.14",
|
||||||
"testcontainers[postgres,redis]>=4.0.0",
|
"testcontainers[postgres,redis]>=4.0.0",
|
||||||
]
|
]
|
||||||
|
|||||||
1
backend/tests/api/__init__.py
Normal file
1
backend/tests/api/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""API endpoint tests."""
|
||||||
133
backend/tests/api/conftest.py
Normal file
133
backend/tests/api/conftest.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
"""Test fixtures for API endpoint tests.
|
||||||
|
|
||||||
|
Provides fixtures for testing FastAPI endpoints with mocked dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import fakeredis.aioredis
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api import deps as api_deps
|
||||||
|
from app.api.auth import router as auth_router
|
||||||
|
from app.api.users import router as users_router
|
||||||
|
from app.db.models import User
|
||||||
|
from app.services.jwt_service import create_access_token, create_refresh_token
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_redis():
|
||||||
|
"""Provide a fake Redis instance for testing."""
|
||||||
|
return fakeredis.aioredis.FakeRedis(decode_responses=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_get_redis(fake_redis):
|
||||||
|
"""Mock the get_redis context manager to use fake Redis."""
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _mock_get_redis():
|
||||||
|
yield fake_redis
|
||||||
|
|
||||||
|
return _mock_get_redis
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_user():
|
||||||
|
"""Create a test user object.
|
||||||
|
|
||||||
|
Returns a User model instance that can be used in tests.
|
||||||
|
The user is not persisted to database.
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email="test@example.com",
|
||||||
|
display_name="Test User",
|
||||||
|
avatar_url="https://example.com/avatar.jpg",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-123",
|
||||||
|
is_premium=False,
|
||||||
|
premium_until=None,
|
||||||
|
)
|
||||||
|
# Manually set the ID since we're not using database
|
||||||
|
user.id = str(uuid4())
|
||||||
|
user.created_at = datetime.now(UTC)
|
||||||
|
user.updated_at = datetime.now(UTC)
|
||||||
|
user.last_login = None
|
||||||
|
user.linked_accounts = []
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def premium_user(test_user):
|
||||||
|
"""Create a premium test user."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
test_user.is_premium = True
|
||||||
|
test_user.premium_until = datetime.now(UTC) + timedelta(days=30)
|
||||||
|
return test_user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def access_token(test_user):
|
||||||
|
"""Create a valid access token for the test user."""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
|
||||||
|
return create_access_token(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def refresh_token_data(test_user):
|
||||||
|
"""Create a valid refresh token and JTI for the test user."""
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
|
||||||
|
token, jti = create_refresh_token(user_id)
|
||||||
|
return {"token": token, "jti": jti, "user_id": user_id}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db_session():
|
||||||
|
"""Create a mock database session."""
|
||||||
|
return MagicMock(spec=AsyncSession)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(mock_get_redis, mock_db_session):
|
||||||
|
"""Create a test FastAPI app with mocked Redis and DB.
|
||||||
|
|
||||||
|
This creates a minimal app with just the auth and users routers,
|
||||||
|
with Redis and database mocked.
|
||||||
|
"""
|
||||||
|
# Create test app (no lifespan since we're mocking everything)
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(auth_router, prefix="/api")
|
||||||
|
test_app.include_router(users_router, prefix="/api")
|
||||||
|
|
||||||
|
# Override get_db dependency to return mock session
|
||||||
|
async def override_get_db():
|
||||||
|
yield mock_db_session
|
||||||
|
|
||||||
|
test_app.dependency_overrides[api_deps.get_db] = override_get_db
|
||||||
|
|
||||||
|
# Patch get_redis globally for this app
|
||||||
|
with (
|
||||||
|
patch("app.api.auth.get_redis", mock_get_redis),
|
||||||
|
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||||
|
):
|
||||||
|
yield test_app
|
||||||
|
|
||||||
|
# Clean up overrides
|
||||||
|
test_app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app):
|
||||||
|
"""Create a test client for the app."""
|
||||||
|
return TestClient(app)
|
||||||
260
backend/tests/api/test_auth.py
Normal file
260
backend/tests/api/test_auth.py
Normal file
@ -0,0 +1,260 @@
|
|||||||
|
"""Tests for auth API endpoints.
|
||||||
|
|
||||||
|
Tests the authentication endpoints including OAuth redirects,
|
||||||
|
token refresh, and logout.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleAuthRedirect:
|
||||||
|
"""Tests for GET /api/auth/google endpoint."""
|
||||||
|
|
||||||
|
def test_returns_501_when_not_configured(self, client: TestClient):
|
||||||
|
"""Test that endpoint returns 501 when Google OAuth is not configured.
|
||||||
|
|
||||||
|
Without client credentials, OAuth flow cannot proceed.
|
||||||
|
"""
|
||||||
|
with patch("app.api.auth.google_oauth") as mock_oauth:
|
||||||
|
mock_oauth.is_configured.return_value = False
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/auth/google",
|
||||||
|
params={"redirect_uri": "http://localhost/callback"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_501_NOT_IMPLEMENTED
|
||||||
|
assert "not configured" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscordAuthRedirect:
|
||||||
|
"""Tests for GET /api/auth/discord endpoint."""
|
||||||
|
|
||||||
|
def test_returns_501_when_not_configured(self, client: TestClient):
|
||||||
|
"""Test that endpoint returns 501 when Discord OAuth is not configured."""
|
||||||
|
with patch("app.api.auth.discord_oauth") as mock_oauth:
|
||||||
|
mock_oauth.is_configured.return_value = False
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/auth/discord",
|
||||||
|
params={"redirect_uri": "http://localhost/callback"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_501_NOT_IMPLEMENTED
|
||||||
|
assert "not configured" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefreshTokens:
|
||||||
|
"""Tests for POST /api/auth/refresh endpoint."""
|
||||||
|
|
||||||
|
def test_returns_new_access_token(
|
||||||
|
self, client: TestClient, test_user, refresh_token_data, mock_get_redis
|
||||||
|
):
|
||||||
|
"""Test that refresh endpoint returns new access token for valid refresh token.
|
||||||
|
|
||||||
|
A valid, non-revoked refresh token should yield a new access token.
|
||||||
|
"""
|
||||||
|
# Store the refresh token in fake Redis
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def setup_token():
|
||||||
|
async with mock_get_redis() as redis:
|
||||||
|
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
|
||||||
|
await redis.setex(key, 86400, "1")
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(setup_token())
|
||||||
|
|
||||||
|
# Mock user service to return our test user
|
||||||
|
with patch("app.api.auth.user_service") as mock_user_service:
|
||||||
|
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.api.auth.get_redis", mock_get_redis),
|
||||||
|
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||||
|
):
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/refresh",
|
||||||
|
json={"refresh_token": refresh_token_data["token"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert data["refresh_token"] == refresh_token_data["token"]
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
assert "expires_in" in data
|
||||||
|
|
||||||
|
def test_returns_401_for_invalid_token(self, client: TestClient):
|
||||||
|
"""Test that refresh endpoint returns 401 for invalid refresh token."""
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/refresh",
|
||||||
|
json={"refresh_token": "invalid.token.here"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_returns_401_for_revoked_token(
|
||||||
|
self, client: TestClient, refresh_token_data, mock_get_redis
|
||||||
|
):
|
||||||
|
"""Test that refresh endpoint returns 401 for revoked token.
|
||||||
|
|
||||||
|
A refresh token not in Redis (revoked/expired) should be rejected.
|
||||||
|
"""
|
||||||
|
# Don't store the token in Redis - simulating revocation
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.api.auth.get_redis", mock_get_redis),
|
||||||
|
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||||
|
):
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/refresh",
|
||||||
|
json={"refresh_token": refresh_token_data["token"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert "revoked" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_returns_401_for_deleted_user(
|
||||||
|
self, client: TestClient, refresh_token_data, mock_get_redis
|
||||||
|
):
|
||||||
|
"""Test that refresh endpoint returns 401 if user no longer exists."""
|
||||||
|
# Store the token
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def setup_token():
|
||||||
|
async with mock_get_redis() as redis:
|
||||||
|
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
|
||||||
|
await redis.setex(key, 86400, "1")
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(setup_token())
|
||||||
|
|
||||||
|
# Mock user service to return None (user deleted)
|
||||||
|
with patch("app.api.auth.user_service") as mock_user_service:
|
||||||
|
mock_user_service.get_by_id = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.api.auth.get_redis", mock_get_redis),
|
||||||
|
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||||
|
):
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/refresh",
|
||||||
|
json={"refresh_token": refresh_token_data["token"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert "User not found" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogout:
|
||||||
|
"""Tests for POST /api/auth/logout endpoint."""
|
||||||
|
|
||||||
|
def test_revokes_token(self, client: TestClient, refresh_token_data, mock_get_redis):
|
||||||
|
"""Test that logout revokes the refresh token.
|
||||||
|
|
||||||
|
After logout, the token should no longer be in Redis.
|
||||||
|
"""
|
||||||
|
# Store the token first
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def setup_and_check():
|
||||||
|
async with mock_get_redis() as redis:
|
||||||
|
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
|
||||||
|
await redis.setex(key, 86400, "1")
|
||||||
|
return key
|
||||||
|
|
||||||
|
key = asyncio.get_event_loop().run_until_complete(setup_and_check())
|
||||||
|
|
||||||
|
# Logout
|
||||||
|
with (
|
||||||
|
patch("app.api.auth.get_redis", mock_get_redis),
|
||||||
|
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||||
|
):
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/logout",
|
||||||
|
json={"refresh_token": refresh_token_data["token"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
# Verify token is gone
|
||||||
|
async def verify_deleted():
|
||||||
|
async with mock_get_redis() as redis:
|
||||||
|
return await redis.exists(key)
|
||||||
|
|
||||||
|
exists = asyncio.get_event_loop().run_until_complete(verify_deleted())
|
||||||
|
assert exists == 0
|
||||||
|
|
||||||
|
def test_succeeds_for_invalid_token(self, client: TestClient):
|
||||||
|
"""Test that logout succeeds even for invalid tokens.
|
||||||
|
|
||||||
|
Invalid tokens are effectively "already logged out", so no error.
|
||||||
|
"""
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/logout",
|
||||||
|
json={"refresh_token": "invalid.token.here"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogoutAll:
|
||||||
|
"""Tests for POST /api/auth/logout-all endpoint."""
|
||||||
|
|
||||||
|
def test_requires_authentication(self, client: TestClient):
|
||||||
|
"""Test that logout-all requires a valid access token.
|
||||||
|
|
||||||
|
Without authentication, endpoint should return 401.
|
||||||
|
"""
|
||||||
|
response = client.post("/api/auth/logout-all")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_revokes_all_tokens(self, client: TestClient, test_user, access_token, mock_get_redis):
|
||||||
|
"""Test that logout-all revokes all refresh tokens for user.
|
||||||
|
|
||||||
|
Should delete all tokens matching the user's ID pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
|
||||||
|
|
||||||
|
# Store multiple tokens
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def setup_tokens():
|
||||||
|
async with mock_get_redis() as redis:
|
||||||
|
await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
|
||||||
|
await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")
|
||||||
|
await redis.setex(f"refresh_token:{user_id}:jti-3", 86400, "1")
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
||||||
|
|
||||||
|
# Mock dependencies
|
||||||
|
with patch("app.api.deps.user_service") as mock_user_service:
|
||||||
|
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.api.auth.get_redis", mock_get_redis),
|
||||||
|
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||||
|
):
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/logout-all",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
# Verify all tokens are gone
|
||||||
|
async def count_remaining():
|
||||||
|
async with mock_get_redis() as redis:
|
||||||
|
count = 0
|
||||||
|
async for _ in redis.scan_iter(match=f"refresh_token:{user_id}:*"):
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
remaining = asyncio.get_event_loop().run_until_complete(count_remaining())
|
||||||
|
assert remaining == 0
|
||||||
259
backend/tests/api/test_users.py
Normal file
259
backend/tests/api/test_users.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
"""Tests for users API endpoints.
|
||||||
|
|
||||||
|
Tests the user profile management endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.services.user_service import AccountLinkingError
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentUser:
|
||||||
|
"""Tests for GET /api/users/me endpoint."""
|
||||||
|
|
||||||
|
def test_returns_user_profile(self, client: TestClient, test_user, access_token):
|
||||||
|
"""Test that endpoint returns user profile for authenticated user.
|
||||||
|
|
||||||
|
Should return the user's profile information.
|
||||||
|
"""
|
||||||
|
with patch("app.api.deps.user_service") as mock_user_service:
|
||||||
|
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/users/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["email"] == test_user.email
|
||||||
|
assert data["display_name"] == test_user.display_name
|
||||||
|
assert data["avatar_url"] == test_user.avatar_url
|
||||||
|
assert data["is_premium"] == test_user.is_premium
|
||||||
|
|
||||||
|
def test_requires_authentication(self, client: TestClient):
|
||||||
|
"""Test that endpoint returns 401 without authentication."""
|
||||||
|
response = client.get("/api/users/me")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_returns_401_for_invalid_token(self, client: TestClient):
|
||||||
|
"""Test that endpoint returns 401 for invalid access token."""
|
||||||
|
response = client.get(
|
||||||
|
"/api/users/me",
|
||||||
|
headers={"Authorization": "Bearer invalid.token.here"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateCurrentUser:
|
||||||
|
"""Tests for PATCH /api/users/me endpoint."""
|
||||||
|
|
||||||
|
def test_updates_display_name(self, client: TestClient, test_user, access_token):
|
||||||
|
"""Test that endpoint updates display_name when provided."""
|
||||||
|
updated_user = test_user
|
||||||
|
updated_user.display_name = "New Name"
|
||||||
|
|
||||||
|
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||||
|
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
with patch("app.api.users.user_service") as mock_user_service:
|
||||||
|
mock_user_service.update = AsyncMock(return_value=updated_user)
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
"/api/users/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json={"display_name": "New Name"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["display_name"] == "New Name"
|
||||||
|
|
||||||
|
def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
|
||||||
|
"""Test that endpoint updates avatar_url when provided."""
|
||||||
|
updated_user = test_user
|
||||||
|
updated_user.avatar_url = "https://new-avatar.com/img.jpg"
|
||||||
|
|
||||||
|
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||||
|
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
with patch("app.api.users.user_service") as mock_user_service:
|
||||||
|
mock_user_service.update = AsyncMock(return_value=updated_user)
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
"/api/users/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json={"avatar_url": "https://new-avatar.com/img.jpg"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
|
||||||
|
|
||||||
|
def test_requires_authentication(self, client: TestClient):
|
||||||
|
"""Test that endpoint returns 401 without authentication."""
|
||||||
|
response = client.patch(
|
||||||
|
"/api/users/me",
|
||||||
|
json={"display_name": "New Name"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLinkedAccounts:
|
||||||
|
"""Tests for GET /api/users/me/linked-accounts endpoint."""
|
||||||
|
|
||||||
|
def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
|
||||||
|
"""Test that endpoint returns list of linked OAuth accounts.
|
||||||
|
|
||||||
|
Should include the primary provider and any linked accounts.
|
||||||
|
"""
|
||||||
|
with patch("app.api.deps.user_service") as mock_user_service:
|
||||||
|
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/users/me/linked-accounts",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) >= 1 # At least primary account
|
||||||
|
assert data[0]["provider"] == test_user.oauth_provider
|
||||||
|
|
||||||
|
def test_requires_authentication(self, client: TestClient):
|
||||||
|
"""Test that endpoint returns 401 without authentication."""
|
||||||
|
response = client.get("/api/users/me/linked-accounts")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetActiveSessions:
|
||||||
|
"""Tests for GET /api/users/me/sessions endpoint."""
|
||||||
|
|
||||||
|
def test_returns_session_count(
|
||||||
|
self, client: TestClient, test_user, access_token, mock_get_redis
|
||||||
|
):
|
||||||
|
"""Test that endpoint returns count of active sessions.
|
||||||
|
|
||||||
|
Should return the number of valid refresh tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
|
||||||
|
|
||||||
|
# Store some tokens
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def setup_tokens():
|
||||||
|
async with mock_get_redis() as redis:
|
||||||
|
await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
|
||||||
|
await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
||||||
|
|
||||||
|
with patch("app.api.deps.user_service") as mock_user_service:
|
||||||
|
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
with patch("app.services.token_store.get_redis", mock_get_redis):
|
||||||
|
response = client.get(
|
||||||
|
"/api/users/me/sessions",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert "active_sessions" in data
|
||||||
|
assert data["active_sessions"] == 2
|
||||||
|
|
||||||
|
def test_requires_authentication(self, client: TestClient):
|
||||||
|
"""Test that endpoint returns 401 without authentication."""
|
||||||
|
response = client.get("/api/users/me/sessions")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnlinkOAuthAccount:
|
||||||
|
"""Tests for DELETE /api/users/me/link/{provider} endpoint."""
|
||||||
|
|
||||||
|
def test_unlinks_provider_successfully(self, client: TestClient, test_user, access_token):
|
||||||
|
"""Test that endpoint successfully unlinks a provider.
|
||||||
|
|
||||||
|
Should return 204 when provider is unlinked.
|
||||||
|
"""
|
||||||
|
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||||
|
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
with patch("app.api.users.user_service") as mock_user_service:
|
||||||
|
mock_user_service.unlink_oauth_account = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
response = client.delete(
|
||||||
|
"/api/users/me/link/discord",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
def test_returns_404_if_not_linked(self, client: TestClient, test_user, access_token):
|
||||||
|
"""Test that endpoint returns 404 if provider isn't linked.
|
||||||
|
|
||||||
|
Should return 404 when trying to unlink a provider that isn't linked.
|
||||||
|
"""
|
||||||
|
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||||
|
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
with patch("app.api.users.user_service") as mock_user_service:
|
||||||
|
mock_user_service.unlink_oauth_account = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
response = client.delete(
|
||||||
|
"/api/users/me/link/discord",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "not linked" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
def test_returns_400_for_primary_provider(self, client: TestClient, test_user, access_token):
|
||||||
|
"""Test that endpoint returns 400 when trying to unlink primary provider.
|
||||||
|
|
||||||
|
Cannot unlink the provider used to create the account.
|
||||||
|
"""
|
||||||
|
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||||
|
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
with patch("app.api.users.user_service") as mock_user_service:
|
||||||
|
mock_user_service.unlink_oauth_account = AsyncMock(
|
||||||
|
side_effect=AccountLinkingError(
|
||||||
|
"Cannot unlink Google - it is your primary login provider"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.delete(
|
||||||
|
"/api/users/me/link/google",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert "primary" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
def test_returns_400_for_unknown_provider(self, client: TestClient, test_user, access_token):
|
||||||
|
"""Test that endpoint returns 400 for unknown provider.
|
||||||
|
|
||||||
|
Only 'google' and 'discord' are valid providers.
|
||||||
|
"""
|
||||||
|
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||||
|
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||||
|
|
||||||
|
response = client.delete(
|
||||||
|
"/api/users/me/link/twitter",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert "unknown provider" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
def test_requires_authentication(self, client: TestClient):
|
||||||
|
"""Test that endpoint returns 401 without authentication."""
|
||||||
|
response = client.delete("/api/users/me/link/discord")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
@ -70,6 +70,7 @@ TABLES_TO_TRUNCATE = [
|
|||||||
"campaign_progress",
|
"campaign_progress",
|
||||||
"collections",
|
"collections",
|
||||||
"decks",
|
"decks",
|
||||||
|
"oauth_linked_accounts",
|
||||||
"users",
|
"users",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
1
backend/tests/services/oauth/__init__.py
Normal file
1
backend/tests/services/oauth/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""OAuth service tests."""
|
||||||
321
backend/tests/services/oauth/test_discord.py
Normal file
321
backend/tests/services/oauth/test_discord.py
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
"""Tests for Discord OAuth service.
|
||||||
|
|
||||||
|
Tests the Discord OAuth flow with mocked HTTP responses using respx.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
from httpx import Response
|
||||||
|
|
||||||
|
from app.services.oauth.discord import DiscordOAuth, DiscordOAuthError
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAuthorizationUrl:
|
||||||
|
"""Tests for get_authorization_url method."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self):
|
||||||
|
"""Test that get_authorization_url raises when Discord OAuth is not configured.
|
||||||
|
|
||||||
|
Without client ID, the method should raise DiscordOAuthError.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = None
|
||||||
|
|
||||||
|
with pytest.raises(DiscordOAuthError, match="not configured"):
|
||||||
|
oauth.get_authorization_url("http://localhost/callback", "state123")
|
||||||
|
|
||||||
|
def test_returns_valid_url_when_configured(self):
|
||||||
|
"""Test that get_authorization_url returns properly formatted URL.
|
||||||
|
|
||||||
|
The URL should include client_id, redirect_uri, state, and scopes.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = "test-client-id"
|
||||||
|
mock_settings.discord_client_secret = "test-secret"
|
||||||
|
|
||||||
|
url = oauth.get_authorization_url("http://localhost/callback", "state123")
|
||||||
|
|
||||||
|
assert "discord.com/api/oauth2/authorize" in url
|
||||||
|
assert "client_id=test-client-id" in url
|
||||||
|
assert "redirect_uri=http" in url
|
||||||
|
assert "state=state123" in url
|
||||||
|
assert "scope=" in url
|
||||||
|
assert "response_type=code" in url
|
||||||
|
|
||||||
|
|
||||||
|
class TestExchangeCodeForTokens:
|
||||||
|
"""Tests for exchange_code_for_tokens method."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_returns_tokens_on_success(self):
|
||||||
|
"""Test that exchange_code_for_tokens returns tokens on success.
|
||||||
|
|
||||||
|
Mocks Discord's token endpoint to return valid tokens.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"access_token": "test-access-token",
|
||||||
|
"refresh_token": "test-refresh-token",
|
||||||
|
"expires_in": 604800,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = "test-client-id"
|
||||||
|
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
tokens = await oauth.exchange_code_for_tokens("auth-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
assert tokens["access_token"] == "test-access-token"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_raises_on_error_response(self):
|
||||||
|
"""Test that exchange_code_for_tokens raises on error from Discord.
|
||||||
|
|
||||||
|
If Discord returns an error, DiscordOAuthError should be raised.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
400,
|
||||||
|
json={
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": "Invalid code",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = "test-client-id"
|
||||||
|
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
with pytest.raises(DiscordOAuthError, match="Token exchange failed"):
|
||||||
|
await oauth.exchange_code_for_tokens("invalid-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchUserInfo:
|
||||||
|
"""Tests for fetch_user_info method."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_returns_user_info_on_success(self):
|
||||||
|
"""Test that fetch_user_info returns user data from Discord.
|
||||||
|
|
||||||
|
Mocks Discord's users/@me endpoint.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
respx.get("https://discord.com/api/users/@me").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"id": "discord-user-123",
|
||||||
|
"username": "testuser",
|
||||||
|
"global_name": "Test User",
|
||||||
|
"email": "user@discord.com",
|
||||||
|
"avatar": "abc123",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
user_info = await oauth.fetch_user_info("test-access-token")
|
||||||
|
|
||||||
|
assert user_info["id"] == "discord-user-123"
|
||||||
|
assert user_info["email"] == "user@discord.com"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_raises_on_error_response(self):
|
||||||
|
"""Test that fetch_user_info raises on error from Discord."""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
respx.get("https://discord.com/api/users/@me").mock(
|
||||||
|
return_value=Response(401, json={"message": "401: Unauthorized"})
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(DiscordOAuthError, match="Failed to fetch user info"):
|
||||||
|
await oauth.fetch_user_info("invalid-token")
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildAvatarUrl:
|
||||||
|
"""Tests for _build_avatar_url method."""
|
||||||
|
|
||||||
|
def test_returns_none_for_no_avatar(self):
|
||||||
|
"""Test that _build_avatar_url returns None when avatar is None."""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
result = oauth._build_avatar_url("123456", None)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_builds_png_url_for_static_avatar(self):
|
||||||
|
"""Test that _build_avatar_url builds PNG URL for static avatars."""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
result = oauth._build_avatar_url("123456", "abcdef123")
|
||||||
|
|
||||||
|
assert result == "https://cdn.discordapp.com/avatars/123456/abcdef123.png"
|
||||||
|
|
||||||
|
def test_builds_gif_url_for_animated_avatar(self):
|
||||||
|
"""Test that _build_avatar_url builds GIF URL for animated avatars.
|
||||||
|
|
||||||
|
Animated avatars start with 'a_' prefix.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
result = oauth._build_avatar_url("123456", "a_animated123")
|
||||||
|
|
||||||
|
assert result == "https://cdn.discordapp.com/avatars/123456/a_animated123.gif"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUserInfo:
|
||||||
|
"""Tests for get_user_info method (full flow)."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_returns_oauth_user_info_on_success(self):
|
||||||
|
"""Test that get_user_info completes full OAuth flow.
|
||||||
|
|
||||||
|
This tests the combined token exchange + user info fetch.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
# Mock token exchange
|
||||||
|
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"access_token": "test-access-token",
|
||||||
|
"refresh_token": "test-refresh",
|
||||||
|
"expires_in": 604800,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock user info
|
||||||
|
respx.get("https://discord.com/api/users/@me").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"id": "discord-user-456",
|
||||||
|
"username": "fullflowuser",
|
||||||
|
"global_name": "Full Flow User",
|
||||||
|
"email": "fullflow@discord.com",
|
||||||
|
"avatar": "avatar123",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = "test-client-id"
|
||||||
|
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
assert result.provider == "discord"
|
||||||
|
assert result.oauth_id == "discord-user-456"
|
||||||
|
assert result.email == "fullflow@discord.com"
|
||||||
|
assert result.name == "Full Flow User"
|
||||||
|
assert "avatar123.png" in result.avatar_url
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_uses_username_when_no_global_name(self):
|
||||||
|
"""Test that get_user_info falls back to username for display name.
|
||||||
|
|
||||||
|
Discord users may not have global_name set.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={"access_token": "test-access-token"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
respx.get("https://discord.com/api/users/@me").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"id": "discord-user-789",
|
||||||
|
"username": "legacyuser",
|
||||||
|
"global_name": None, # No global name
|
||||||
|
"email": "legacy@discord.com",
|
||||||
|
"avatar": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = "test-client-id"
|
||||||
|
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
assert result.name == "legacyuser"
|
||||||
|
assert result.avatar_url is None
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_raises_when_no_email(self):
|
||||||
|
"""Test that get_user_info raises when Discord user has no email.
|
||||||
|
|
||||||
|
Email is required for account creation.
|
||||||
|
"""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={"access_token": "test-access-token"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
respx.get("https://discord.com/api/users/@me").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"id": "discord-user-noemail",
|
||||||
|
"username": "noemailuser",
|
||||||
|
"email": None, # No verified email
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = "test-client-id"
|
||||||
|
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
with pytest.raises(DiscordOAuthError, match="verified email"):
|
||||||
|
await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsConfigured:
|
||||||
|
"""Tests for is_configured method."""
|
||||||
|
|
||||||
|
def test_returns_false_when_not_configured(self):
|
||||||
|
"""Test that is_configured returns False without credentials."""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = None
|
||||||
|
mock_settings.discord_client_secret = None
|
||||||
|
|
||||||
|
assert oauth.is_configured() is False
|
||||||
|
|
||||||
|
def test_returns_true_when_configured(self):
|
||||||
|
"""Test that is_configured returns True with credentials."""
|
||||||
|
oauth = DiscordOAuth()
|
||||||
|
|
||||||
|
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||||
|
mock_settings.discord_client_id = "client-id"
|
||||||
|
mock_settings.discord_client_secret = "client-secret"
|
||||||
|
|
||||||
|
assert oauth.is_configured() is True
|
||||||
241
backend/tests/services/oauth/test_google.py
Normal file
241
backend/tests/services/oauth/test_google.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
"""Tests for Google OAuth service.
|
||||||
|
|
||||||
|
Tests the Google OAuth flow with mocked HTTP responses using respx.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
from httpx import Response
|
||||||
|
|
||||||
|
from app.services.oauth.google import GoogleOAuth, GoogleOAuthError
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAuthorizationUrl:
|
||||||
|
"""Tests for get_authorization_url method."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self):
|
||||||
|
"""Test that get_authorization_url raises when Google OAuth is not configured.
|
||||||
|
|
||||||
|
Without client ID, the method should raise GoogleOAuthError.
|
||||||
|
"""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||||
|
mock_settings.google_client_id = None
|
||||||
|
|
||||||
|
with pytest.raises(GoogleOAuthError, match="not configured"):
|
||||||
|
oauth.get_authorization_url("http://localhost/callback", "state123")
|
||||||
|
|
||||||
|
def test_returns_valid_url_when_configured(self):
|
||||||
|
"""Test that get_authorization_url returns properly formatted URL.
|
||||||
|
|
||||||
|
The URL should include client_id, redirect_uri, state, and scopes.
|
||||||
|
"""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||||
|
mock_settings.google_client_id = "test-client-id"
|
||||||
|
mock_settings.google_client_secret = "test-secret"
|
||||||
|
|
||||||
|
url = oauth.get_authorization_url("http://localhost/callback", "state123")
|
||||||
|
|
||||||
|
assert "accounts.google.com/o/oauth2/v2/auth" in url
|
||||||
|
assert "client_id=test-client-id" in url
|
||||||
|
assert "redirect_uri=http" in url
|
||||||
|
assert "state=state123" in url
|
||||||
|
assert "scope=" in url
|
||||||
|
assert "response_type=code" in url
|
||||||
|
|
||||||
|
|
||||||
|
class TestExchangeCodeForTokens:
|
||||||
|
"""Tests for exchange_code_for_tokens method."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_returns_tokens_on_success(self):
|
||||||
|
"""Test that exchange_code_for_tokens returns tokens on success.
|
||||||
|
|
||||||
|
Mocks Google's token endpoint to return valid tokens.
|
||||||
|
"""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
respx.post("https://oauth2.googleapis.com/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"access_token": "test-access-token",
|
||||||
|
"id_token": "test-id-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||||
|
mock_settings.google_client_id = "test-client-id"
|
||||||
|
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
tokens = await oauth.exchange_code_for_tokens("auth-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
assert tokens["access_token"] == "test-access-token"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_raises_on_error_response(self):
|
||||||
|
"""Test that exchange_code_for_tokens raises on error from Google.
|
||||||
|
|
||||||
|
If Google returns an error, GoogleOAuthError should be raised.
|
||||||
|
"""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
respx.post("https://oauth2.googleapis.com/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
400,
|
||||||
|
json={
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": "Code has expired",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||||
|
mock_settings.google_client_id = "test-client-id"
|
||||||
|
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
with pytest.raises(GoogleOAuthError, match="Token exchange failed"):
|
||||||
|
await oauth.exchange_code_for_tokens("expired-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchUserInfo:
|
||||||
|
"""Tests for fetch_user_info method."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_returns_user_info_on_success(self):
|
||||||
|
"""Test that fetch_user_info returns user data from Google.
|
||||||
|
|
||||||
|
Mocks Google's userinfo endpoint.
|
||||||
|
"""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"id": "google-user-123",
|
||||||
|
"email": "user@gmail.com",
|
||||||
|
"name": "Test User",
|
||||||
|
"picture": "https://google.com/avatar.jpg",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
user_info = await oauth.fetch_user_info("test-access-token")
|
||||||
|
|
||||||
|
assert user_info["id"] == "google-user-123"
|
||||||
|
assert user_info["email"] == "user@gmail.com"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_raises_on_error_response(self):
|
||||||
|
"""Test that fetch_user_info raises on error from Google."""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
|
||||||
|
return_value=Response(401, json={"error": "Invalid token"})
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(GoogleOAuthError, match="Failed to fetch user info"):
|
||||||
|
await oauth.fetch_user_info("invalid-token")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUserInfo:
|
||||||
|
"""Tests for get_user_info method (full flow)."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_returns_oauth_user_info_on_success(self):
|
||||||
|
"""Test that get_user_info completes full OAuth flow.
|
||||||
|
|
||||||
|
This tests the combined token exchange + user info fetch.
|
||||||
|
"""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
# Mock token exchange
|
||||||
|
respx.post("https://oauth2.googleapis.com/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"access_token": "test-access-token",
|
||||||
|
"id_token": "test-id-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock user info
|
||||||
|
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"id": "google-user-456",
|
||||||
|
"email": "fullflow@gmail.com",
|
||||||
|
"name": "Full Flow User",
|
||||||
|
"picture": "https://google.com/fullflow.jpg",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||||
|
mock_settings.google_client_id = "test-client-id"
|
||||||
|
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
assert result.provider == "google"
|
||||||
|
assert result.oauth_id == "google-user-456"
|
||||||
|
assert result.email == "fullflow@gmail.com"
|
||||||
|
assert result.name == "Full Flow User"
|
||||||
|
assert result.avatar_url == "https://google.com/fullflow.jpg"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_raises_when_no_access_token(self):
|
||||||
|
"""Test that get_user_info raises when token response lacks access_token."""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
respx.post("https://oauth2.googleapis.com/token").mock(
|
||||||
|
return_value=Response(
|
||||||
|
200,
|
||||||
|
json={"id_token": "only-id-token"}, # No access_token
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||||
|
mock_settings.google_client_id = "test-client-id"
|
||||||
|
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
|
||||||
|
|
||||||
|
with pytest.raises(GoogleOAuthError, match="No access token"):
|
||||||
|
await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsConfigured:
|
||||||
|
"""Tests for is_configured method."""
|
||||||
|
|
||||||
|
def test_returns_false_when_not_configured(self):
|
||||||
|
"""Test that is_configured returns False without credentials."""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||||
|
mock_settings.google_client_id = None
|
||||||
|
mock_settings.google_client_secret = None
|
||||||
|
|
||||||
|
assert oauth.is_configured() is False
|
||||||
|
|
||||||
|
def test_returns_true_when_configured(self):
|
||||||
|
"""Test that is_configured returns True with credentials."""
|
||||||
|
oauth = GoogleOAuth()
|
||||||
|
|
||||||
|
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||||
|
mock_settings.google_client_id = "client-id"
|
||||||
|
mock_settings.google_client_secret = "client-secret"
|
||||||
|
|
||||||
|
assert oauth.is_configured() is True
|
||||||
370
backend/tests/services/test_jwt_service.py
Normal file
370
backend/tests/services/test_jwt_service.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
"""Tests for JWT service.
|
||||||
|
|
||||||
|
Tests the JWT token creation and verification functions used for
|
||||||
|
authentication throughout the application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.schemas.auth import TokenType
|
||||||
|
from app.services.jwt_service import (
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_token,
|
||||||
|
get_refresh_token_expiration,
|
||||||
|
get_token_expiration_seconds,
|
||||||
|
verify_access_token,
|
||||||
|
verify_refresh_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateAccessToken:
|
||||||
|
"""Tests for create_access_token function."""
|
||||||
|
|
||||||
|
def test_creates_valid_jwt(self):
|
||||||
|
"""Test that create_access_token returns a valid JWT string.
|
||||||
|
|
||||||
|
The returned token should be decodable and contain the expected
|
||||||
|
claims including subject, expiration, and token type.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
token = create_access_token(user_id)
|
||||||
|
|
||||||
|
# Should be a valid JWT (three dot-separated parts)
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert token.count(".") == 2
|
||||||
|
|
||||||
|
# Should be decodable
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithms=[settings.jwt_algorithm],
|
||||||
|
)
|
||||||
|
assert payload["sub"] == str(user_id)
|
||||||
|
assert payload["type"] == TokenType.ACCESS.value
|
||||||
|
|
||||||
|
def test_sets_correct_expiration(self):
|
||||||
|
"""Test that access token expiration matches configured setting.
|
||||||
|
|
||||||
|
The token should expire approximately jwt_expire_minutes from now.
|
||||||
|
JWT timestamps have second precision, so we allow 1 second tolerance.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
before = datetime.now(UTC)
|
||||||
|
token = create_access_token(user_id)
|
||||||
|
after = datetime.now(UTC)
|
||||||
|
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithms=[settings.jwt_algorithm],
|
||||||
|
)
|
||||||
|
|
||||||
|
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
|
||||||
|
expected_min = (
|
||||||
|
before + timedelta(minutes=settings.jwt_expire_minutes) - timedelta(seconds=1)
|
||||||
|
)
|
||||||
|
expected_max = after + timedelta(minutes=settings.jwt_expire_minutes) + timedelta(seconds=1)
|
||||||
|
|
||||||
|
assert expected_min <= exp <= expected_max
|
||||||
|
|
||||||
|
def test_includes_issued_at(self):
|
||||||
|
"""Test that access token includes iat (issued at) claim.
|
||||||
|
|
||||||
|
The iat claim should be set to approximately the current time.
|
||||||
|
JWT timestamps have second precision, so we allow 1 second tolerance.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
before = datetime.now(UTC)
|
||||||
|
token = create_access_token(user_id)
|
||||||
|
after = datetime.now(UTC)
|
||||||
|
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithms=[settings.jwt_algorithm],
|
||||||
|
)
|
||||||
|
|
||||||
|
iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
|
||||||
|
assert before - timedelta(seconds=1) <= iat <= after + timedelta(seconds=1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateRefreshToken:
|
||||||
|
"""Tests for create_refresh_token function."""
|
||||||
|
|
||||||
|
def test_creates_valid_jwt_with_jti(self):
|
||||||
|
"""Test that create_refresh_token returns a valid JWT and JTI.
|
||||||
|
|
||||||
|
The function should return a tuple of (token, jti) where the
|
||||||
|
token contains the JTI for revocation tracking.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
token, jti = create_refresh_token(user_id)
|
||||||
|
|
||||||
|
# Should return token and jti
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert isinstance(jti, str)
|
||||||
|
assert token.count(".") == 2
|
||||||
|
|
||||||
|
# JTI should be a valid UUID
|
||||||
|
uuid.UUID(jti) # Will raise if invalid
|
||||||
|
|
||||||
|
# Token should contain the JTI
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithms=[settings.jwt_algorithm],
|
||||||
|
)
|
||||||
|
assert payload["jti"] == jti
|
||||||
|
assert payload["type"] == TokenType.REFRESH.value
|
||||||
|
|
||||||
|
def test_sets_correct_expiration(self):
|
||||||
|
"""Test that refresh token expiration matches configured setting.
|
||||||
|
|
||||||
|
The token should expire approximately jwt_refresh_expire_days from now.
|
||||||
|
JWT timestamps have second precision, so we allow 1 second tolerance.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
before = datetime.now(UTC)
|
||||||
|
token, _ = create_refresh_token(user_id)
|
||||||
|
after = datetime.now(UTC)
|
||||||
|
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithms=[settings.jwt_algorithm],
|
||||||
|
)
|
||||||
|
|
||||||
|
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
|
||||||
|
expected_min = (
|
||||||
|
before + timedelta(days=settings.jwt_refresh_expire_days) - timedelta(seconds=1)
|
||||||
|
)
|
||||||
|
expected_max = (
|
||||||
|
after + timedelta(days=settings.jwt_refresh_expire_days) + timedelta(seconds=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert expected_min <= exp <= expected_max
|
||||||
|
|
||||||
|
def test_generates_unique_jti(self):
|
||||||
|
"""Test that each refresh token gets a unique JTI.
|
||||||
|
|
||||||
|
Multiple calls should generate different JTIs to ensure
|
||||||
|
each token can be individually revoked.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
_, jti1 = create_refresh_token(user_id)
|
||||||
|
_, jti2 = create_refresh_token(user_id)
|
||||||
|
|
||||||
|
assert jti1 != jti2
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecodeToken:
|
||||||
|
"""Tests for decode_token function."""
|
||||||
|
|
||||||
|
def test_decodes_valid_token(self):
|
||||||
|
"""Test that decode_token returns TokenPayload for valid tokens.
|
||||||
|
|
||||||
|
A valid token should be decoded into a TokenPayload with
|
||||||
|
all expected fields populated.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
token = create_access_token(user_id)
|
||||||
|
|
||||||
|
payload = decode_token(token)
|
||||||
|
|
||||||
|
assert payload is not None
|
||||||
|
assert payload.sub == str(user_id)
|
||||||
|
assert payload.type == TokenType.ACCESS
|
||||||
|
assert payload.exp is not None
|
||||||
|
assert payload.iat is not None
|
||||||
|
|
||||||
|
def test_returns_none_for_invalid_token(self):
|
||||||
|
"""Test that decode_token returns None for malformed tokens.
|
||||||
|
|
||||||
|
Invalid JWT strings should not raise exceptions but return None.
|
||||||
|
"""
|
||||||
|
result = decode_token("invalid.token.here")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_returns_none_for_wrong_signature(self):
|
||||||
|
"""Test that decode_token returns None for tokens with wrong signature.
|
||||||
|
|
||||||
|
Tokens signed with a different key should be rejected.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
# Create token with different secret
|
||||||
|
payload = {
|
||||||
|
"sub": str(user_id),
|
||||||
|
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
"iat": datetime.now(UTC),
|
||||||
|
"type": "access",
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||||
|
|
||||||
|
result = decode_token(token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_returns_none_for_expired_token(self):
|
||||||
|
"""Test that decode_token returns None for expired tokens.
|
||||||
|
|
||||||
|
Tokens past their expiration should be rejected.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
payload = {
|
||||||
|
"sub": str(user_id),
|
||||||
|
"exp": datetime.now(UTC) - timedelta(hours=1), # Already expired
|
||||||
|
"iat": datetime.now(UTC) - timedelta(hours=2),
|
||||||
|
"type": "access",
|
||||||
|
}
|
||||||
|
token = jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithm=settings.jwt_algorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = decode_token(token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerifyAccessToken:
|
||||||
|
"""Tests for verify_access_token function."""
|
||||||
|
|
||||||
|
def test_returns_user_id_for_valid_access_token(self):
|
||||||
|
"""Test that verify_access_token returns user ID for valid tokens.
|
||||||
|
|
||||||
|
A valid access token should return the UUID of the user.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
token = create_access_token(user_id)
|
||||||
|
|
||||||
|
result = verify_access_token(token)
|
||||||
|
|
||||||
|
assert result == user_id
|
||||||
|
|
||||||
|
def test_returns_none_for_refresh_token(self):
|
||||||
|
"""Test that verify_access_token rejects refresh tokens.
|
||||||
|
|
||||||
|
Even valid refresh tokens should be rejected when verifying
|
||||||
|
as access tokens to prevent token type confusion.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
token, _ = create_refresh_token(user_id)
|
||||||
|
|
||||||
|
result = verify_access_token(token)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_returns_none_for_invalid_token(self):
|
||||||
|
"""Test that verify_access_token returns None for invalid tokens."""
|
||||||
|
result = verify_access_token("invalid.token.here")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_returns_none_for_invalid_uuid_subject(self):
|
||||||
|
"""Test that verify_access_token returns None for non-UUID subject.
|
||||||
|
|
||||||
|
If the subject claim is not a valid UUID, the token should be rejected.
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"sub": "not-a-uuid",
|
||||||
|
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
"iat": datetime.now(UTC),
|
||||||
|
"type": "access",
|
||||||
|
}
|
||||||
|
token = jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithm=settings.jwt_algorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = verify_access_token(token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerifyRefreshToken:
|
||||||
|
"""Tests for verify_refresh_token function."""
|
||||||
|
|
||||||
|
def test_returns_user_id_and_jti_for_valid_refresh_token(self):
|
||||||
|
"""Test that verify_refresh_token returns user ID and JTI.
|
||||||
|
|
||||||
|
A valid refresh token should return both values needed for
|
||||||
|
revocation checking.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
token, jti = create_refresh_token(user_id)
|
||||||
|
|
||||||
|
result = verify_refresh_token(token)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
result_user_id, result_jti = result
|
||||||
|
assert result_user_id == user_id
|
||||||
|
assert result_jti == jti
|
||||||
|
|
||||||
|
def test_returns_none_for_access_token(self):
|
||||||
|
"""Test that verify_refresh_token rejects access tokens.
|
||||||
|
|
||||||
|
Even valid access tokens should be rejected when verifying
|
||||||
|
as refresh tokens.
|
||||||
|
"""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
token = create_access_token(user_id)
|
||||||
|
|
||||||
|
result = verify_refresh_token(token)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_returns_none_for_token_without_jti(self):
|
||||||
|
"""Test that verify_refresh_token rejects tokens missing JTI.
|
||||||
|
|
||||||
|
Refresh tokens must have a JTI for revocation tracking.
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"sub": str(uuid.uuid4()),
|
||||||
|
"exp": datetime.now(UTC) + timedelta(days=7),
|
||||||
|
"iat": datetime.now(UTC),
|
||||||
|
"type": "refresh",
|
||||||
|
# No jti
|
||||||
|
}
|
||||||
|
token = jwt.encode(
|
||||||
|
payload,
|
||||||
|
settings.secret_key.get_secret_value(),
|
||||||
|
algorithm=settings.jwt_algorithm,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = verify_refresh_token(token)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_returns_none_for_invalid_token(self):
|
||||||
|
"""Test that verify_refresh_token returns None for invalid tokens."""
|
||||||
|
result = verify_refresh_token("invalid.token.here")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestHelperFunctions:
|
||||||
|
"""Tests for helper functions."""
|
||||||
|
|
||||||
|
def test_get_token_expiration_seconds(self):
|
||||||
|
"""Test that get_token_expiration_seconds returns correct value.
|
||||||
|
|
||||||
|
Should return jwt_expire_minutes converted to seconds.
|
||||||
|
"""
|
||||||
|
result = get_token_expiration_seconds()
|
||||||
|
assert result == settings.jwt_expire_minutes * 60
|
||||||
|
|
||||||
|
def test_get_refresh_token_expiration(self):
|
||||||
|
"""Test that get_refresh_token_expiration returns future datetime.
|
||||||
|
|
||||||
|
Should return a datetime approximately jwt_refresh_expire_days
|
||||||
|
in the future.
|
||||||
|
"""
|
||||||
|
before = datetime.now(UTC)
|
||||||
|
result = get_refresh_token_expiration()
|
||||||
|
after = datetime.now(UTC)
|
||||||
|
|
||||||
|
expected_min = before + timedelta(days=settings.jwt_refresh_expire_days)
|
||||||
|
expected_max = after + timedelta(days=settings.jwt_refresh_expire_days)
|
||||||
|
|
||||||
|
assert expected_min <= result <= expected_max
|
||||||
664
backend/tests/services/test_user_service.py
Normal file
664
backend/tests/services/test_user_service.py
Normal file
@ -0,0 +1,664 @@
|
|||||||
|
"""Tests for UserService.
|
||||||
|
|
||||||
|
Tests the user service CRUD operations and OAuth-based user creation.
|
||||||
|
Uses real Postgres via the db_session fixture from conftest.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.db.models import User
|
||||||
|
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||||
|
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
||||||
|
from app.services.user_service import AccountLinkingError, user_service
|
||||||
|
|
||||||
|
# Import db_session fixture from db conftest
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetById:
|
||||||
|
"""Tests for get_by_id method."""
|
||||||
|
|
||||||
|
async def test_returns_user_when_found(self, db_session):
|
||||||
|
"""Test that get_by_id returns user when it exists.
|
||||||
|
|
||||||
|
Creates a user and verifies it can be retrieved by ID.
|
||||||
|
"""
|
||||||
|
# Create user directly
|
||||||
|
user = User(
|
||||||
|
email="test@example.com",
|
||||||
|
display_name="Test User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="123456",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Retrieve by ID
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
result = await user_service.get_by_id(db_session, user_id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.email == "test@example.com"
|
||||||
|
|
||||||
|
async def test_returns_none_when_not_found(self, db_session):
|
||||||
|
"""Test that get_by_id returns None for nonexistent users."""
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
result = await user_service.get_by_id(db_session, uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetByEmail:
|
||||||
|
"""Tests for get_by_email method."""
|
||||||
|
|
||||||
|
async def test_returns_user_when_found(self, db_session):
|
||||||
|
"""Test that get_by_email returns user when it exists."""
|
||||||
|
user = User(
|
||||||
|
email="findme@example.com",
|
||||||
|
display_name="Find Me",
|
||||||
|
oauth_provider="discord",
|
||||||
|
oauth_id="discord123",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await user_service.get_by_email(db_session, "findme@example.com")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.display_name == "Find Me"
|
||||||
|
|
||||||
|
async def test_returns_none_when_not_found(self, db_session):
|
||||||
|
"""Test that get_by_email returns None for nonexistent emails."""
|
||||||
|
result = await user_service.get_by_email(db_session, "nobody@example.com")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetByOAuth:
|
||||||
|
"""Tests for get_by_oauth method."""
|
||||||
|
|
||||||
|
async def test_returns_user_when_found(self, db_session):
|
||||||
|
"""Test that get_by_oauth returns user for matching provider+id."""
|
||||||
|
user = User(
|
||||||
|
email="oauth@example.com",
|
||||||
|
display_name="OAuth User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-unique-id",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await user_service.get_by_oauth(db_session, "google", "google-unique-id")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.email == "oauth@example.com"
|
||||||
|
|
||||||
|
async def test_returns_none_for_wrong_provider(self, db_session):
|
||||||
|
"""Test that get_by_oauth returns None if provider doesn't match."""
|
||||||
|
user = User(
|
||||||
|
email="oauth2@example.com",
|
||||||
|
display_name="OAuth User 2",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-id-2",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Same ID, different provider
|
||||||
|
result = await user_service.get_by_oauth(db_session, "discord", "google-id-2")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_returns_none_when_not_found(self, db_session):
|
||||||
|
"""Test that get_by_oauth returns None for nonexistent OAuth."""
|
||||||
|
result = await user_service.get_by_oauth(db_session, "google", "nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreate:
|
||||||
|
"""Tests for create method."""
|
||||||
|
|
||||||
|
async def test_creates_user_with_all_fields(self, db_session):
|
||||||
|
"""Test that create properly persists all user fields."""
|
||||||
|
user_data = UserCreate(
|
||||||
|
email="new@example.com",
|
||||||
|
display_name="New User",
|
||||||
|
avatar_url="https://example.com/avatar.jpg",
|
||||||
|
oauth_provider="discord",
|
||||||
|
oauth_id="discord-new-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await user_service.create(db_session, user_data)
|
||||||
|
|
||||||
|
assert result.id is not None
|
||||||
|
assert result.email == "new@example.com"
|
||||||
|
assert result.display_name == "New User"
|
||||||
|
assert result.avatar_url == "https://example.com/avatar.jpg"
|
||||||
|
assert result.oauth_provider == "discord"
|
||||||
|
assert result.oauth_id == "discord-new-id"
|
||||||
|
assert result.is_premium is False
|
||||||
|
assert result.premium_until is None
|
||||||
|
|
||||||
|
async def test_creates_user_without_avatar(self, db_session):
|
||||||
|
"""Test that create works without optional avatar_url."""
|
||||||
|
user_data = UserCreate(
|
||||||
|
email="noavatar@example.com",
|
||||||
|
display_name="No Avatar",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-no-avatar",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await user_service.create(db_session, user_data)
|
||||||
|
|
||||||
|
assert result.avatar_url is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateFromOAuth:
|
||||||
|
"""Tests for create_from_oauth method."""
|
||||||
|
|
||||||
|
async def test_creates_user_from_oauth_info(self, db_session):
|
||||||
|
"""Test that create_from_oauth converts OAuthUserInfo to User."""
|
||||||
|
oauth_info = OAuthUserInfo(
|
||||||
|
provider="google",
|
||||||
|
oauth_id="google-oauth-123",
|
||||||
|
email="oauthcreate@example.com",
|
||||||
|
name="OAuth Created User",
|
||||||
|
avatar_url="https://google.com/avatar.jpg",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await user_service.create_from_oauth(db_session, oauth_info)
|
||||||
|
|
||||||
|
assert result.email == "oauthcreate@example.com"
|
||||||
|
assert result.display_name == "OAuth Created User"
|
||||||
|
assert result.oauth_provider == "google"
|
||||||
|
assert result.oauth_id == "google-oauth-123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetOrCreateFromOAuth:
|
||||||
|
"""Tests for get_or_create_from_oauth method."""
|
||||||
|
|
||||||
|
async def test_returns_existing_user_by_oauth(self, db_session):
|
||||||
|
"""Test that existing user is returned when OAuth matches.
|
||||||
|
|
||||||
|
Verifies the method returns (user, False) for existing users.
|
||||||
|
"""
|
||||||
|
# Create existing user
|
||||||
|
existing = User(
|
||||||
|
email="existing@example.com",
|
||||||
|
display_name="Existing",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="existing-oauth-id",
|
||||||
|
)
|
||||||
|
db_session.add(existing)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Try to get or create with same OAuth
|
||||||
|
oauth_info = OAuthUserInfo(
|
||||||
|
provider="google",
|
||||||
|
oauth_id="existing-oauth-id",
|
||||||
|
email="existing@example.com",
|
||||||
|
name="Existing",
|
||||||
|
)
|
||||||
|
|
||||||
|
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||||
|
|
||||||
|
assert created is False
|
||||||
|
assert result.id == existing.id
|
||||||
|
|
||||||
|
async def test_links_existing_user_by_email(self, db_session):
|
||||||
|
"""Test that OAuth is linked when email matches existing user.
|
||||||
|
|
||||||
|
If a user exists with the same email but different OAuth,
|
||||||
|
the new OAuth should be linked to the existing account.
|
||||||
|
"""
|
||||||
|
# Create user with Google
|
||||||
|
existing = User(
|
||||||
|
email="link@example.com",
|
||||||
|
display_name="Link Me",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-link-id",
|
||||||
|
)
|
||||||
|
db_session.add(existing)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Login with Discord (same email)
|
||||||
|
oauth_info = OAuthUserInfo(
|
||||||
|
provider="discord",
|
||||||
|
oauth_id="discord-link-id",
|
||||||
|
email="link@example.com",
|
||||||
|
name="Link Me",
|
||||||
|
avatar_url="https://discord.com/avatar.jpg",
|
||||||
|
)
|
||||||
|
|
||||||
|
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||||
|
|
||||||
|
assert created is False
|
||||||
|
assert result.id == existing.id
|
||||||
|
# OAuth should be updated to Discord
|
||||||
|
assert result.oauth_provider == "discord"
|
||||||
|
assert result.oauth_id == "discord-link-id"
|
||||||
|
|
||||||
|
async def test_creates_new_user_when_not_found(self, db_session):
|
||||||
|
"""Test that new user is created when no match exists.
|
||||||
|
|
||||||
|
Verifies the method returns (user, True) for new users.
|
||||||
|
"""
|
||||||
|
oauth_info = OAuthUserInfo(
|
||||||
|
provider="discord",
|
||||||
|
oauth_id="brand-new-id",
|
||||||
|
email="brandnew@example.com",
|
||||||
|
name="Brand New",
|
||||||
|
)
|
||||||
|
|
||||||
|
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||||
|
|
||||||
|
assert created is True
|
||||||
|
assert result.email == "brandnew@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdate:
|
||||||
|
"""Tests for update method."""
|
||||||
|
|
||||||
|
async def test_updates_display_name(self, db_session):
|
||||||
|
"""Test that update changes display_name when provided."""
|
||||||
|
user = User(
|
||||||
|
email="update@example.com",
|
||||||
|
display_name="Old Name",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="update-id",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
update_data = UserUpdate(display_name="New Name")
|
||||||
|
result = await user_service.update(db_session, user, update_data)
|
||||||
|
|
||||||
|
assert result.display_name == "New Name"
|
||||||
|
|
||||||
|
async def test_updates_avatar_url(self, db_session):
|
||||||
|
"""Test that update changes avatar_url when provided."""
|
||||||
|
user = User(
|
||||||
|
email="avatar@example.com",
|
||||||
|
display_name="Avatar User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="avatar-id",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
|
||||||
|
result = await user_service.update(db_session, user, update_data)
|
||||||
|
|
||||||
|
assert result.avatar_url == "https://new-avatar.com/img.jpg"
|
||||||
|
|
||||||
|
async def test_ignores_none_values(self, db_session):
|
||||||
|
"""Test that update doesn't change fields set to None.
|
||||||
|
|
||||||
|
Only explicitly provided fields should be updated.
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email="keep@example.com",
|
||||||
|
display_name="Keep Me",
|
||||||
|
avatar_url="https://keep.com/avatar.jpg",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="keep-id",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Update only display_name, leave avatar alone
|
||||||
|
update_data = UserUpdate(display_name="Changed")
|
||||||
|
result = await user_service.update(db_session, user, update_data)
|
||||||
|
|
||||||
|
assert result.display_name == "Changed"
|
||||||
|
assert result.avatar_url == "https://keep.com/avatar.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateLastLogin:
|
||||||
|
"""Tests for update_last_login method."""
|
||||||
|
|
||||||
|
async def test_updates_last_login_timestamp(self, db_session):
|
||||||
|
"""Test that update_last_login sets current timestamp."""
|
||||||
|
user = User(
|
||||||
|
email="login@example.com",
|
||||||
|
display_name="Login User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="login-id",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
assert user.last_login is None
|
||||||
|
|
||||||
|
before = datetime.now(UTC)
|
||||||
|
result = await user_service.update_last_login(db_session, user)
|
||||||
|
after = datetime.now(UTC)
|
||||||
|
|
||||||
|
assert result.last_login is not None
|
||||||
|
# Allow 1 second tolerance
|
||||||
|
assert before - timedelta(seconds=1) <= result.last_login <= after + timedelta(seconds=1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdatePremium:
|
||||||
|
"""Tests for update_premium method."""
|
||||||
|
|
||||||
|
async def test_grants_premium(self, db_session):
|
||||||
|
"""Test that update_premium sets premium status and expiration."""
|
||||||
|
user = User(
|
||||||
|
email="premium@example.com",
|
||||||
|
display_name="Premium User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="premium-id",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
assert user.is_premium is False
|
||||||
|
|
||||||
|
expires = datetime.now(UTC) + timedelta(days=30)
|
||||||
|
result = await user_service.update_premium(db_session, user, expires)
|
||||||
|
|
||||||
|
assert result.is_premium is True
|
||||||
|
assert result.premium_until == expires
|
||||||
|
|
||||||
|
async def test_removes_premium(self, db_session):
|
||||||
|
"""Test that update_premium with None removes premium status."""
|
||||||
|
user = User(
|
||||||
|
email="unpremium@example.com",
|
||||||
|
display_name="Unpremium User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="unpremium-id",
|
||||||
|
is_premium=True,
|
||||||
|
premium_until=datetime.now(UTC) + timedelta(days=30),
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await user_service.update_premium(db_session, user, None)
|
||||||
|
|
||||||
|
assert result.is_premium is False
|
||||||
|
assert result.premium_until is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestDelete:
|
||||||
|
"""Tests for delete method."""
|
||||||
|
|
||||||
|
async def test_deletes_user(self, db_session):
|
||||||
|
"""Test that delete removes user from database."""
|
||||||
|
user = User(
|
||||||
|
email="delete@example.com",
|
||||||
|
display_name="Delete Me",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="delete-id",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
user_id = user.id
|
||||||
|
await user_service.delete(db_session, user)
|
||||||
|
|
||||||
|
# Verify user is gone
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
result = await user_service.get_by_id(
|
||||||
|
db_session, UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLinkedAccount:
|
||||||
|
"""Tests for get_linked_account method."""
|
||||||
|
|
||||||
|
async def test_returns_linked_account_when_found(self, db_session):
|
||||||
|
"""Test that get_linked_account returns account when it exists.
|
||||||
|
|
||||||
|
Creates a user with a linked account and verifies it can be retrieved.
|
||||||
|
"""
|
||||||
|
# Create user
|
||||||
|
user = User(
|
||||||
|
email="primary@example.com",
|
||||||
|
display_name="Primary User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-primary",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Create linked account
|
||||||
|
linked = OAuthLinkedAccount(
|
||||||
|
user_id=user.id,
|
||||||
|
provider="discord",
|
||||||
|
oauth_id="discord-linked-123",
|
||||||
|
email="linked@example.com",
|
||||||
|
)
|
||||||
|
db_session.add(linked)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Retrieve linked account
|
||||||
|
result = await user_service.get_linked_account(db_session, "discord", "discord-linked-123")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.provider == "discord"
|
||||||
|
assert result.oauth_id == "discord-linked-123"
|
||||||
|
|
||||||
|
async def test_returns_none_when_not_found(self, db_session):
|
||||||
|
"""Test that get_linked_account returns None for nonexistent accounts."""
|
||||||
|
result = await user_service.get_linked_account(db_session, "discord", "nonexistent-id")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestLinkOAuthAccount:
|
||||||
|
"""Tests for link_oauth_account method."""
|
||||||
|
|
||||||
|
async def test_links_new_provider(self, db_session):
|
||||||
|
"""Test that link_oauth_account successfully links a new provider.
|
||||||
|
|
||||||
|
Creates a Google user and links Discord to them.
|
||||||
|
"""
|
||||||
|
# Create user with Google
|
||||||
|
user = User(
|
||||||
|
email="google-user@example.com",
|
||||||
|
display_name="Google User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-123",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Link Discord
|
||||||
|
discord_info = OAuthUserInfo(
|
||||||
|
provider="discord",
|
||||||
|
oauth_id="discord-456",
|
||||||
|
email="discord@example.com",
|
||||||
|
name="Discord Name",
|
||||||
|
avatar_url="https://discord.com/avatar.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await user_service.link_oauth_account(db_session, user, discord_info)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.provider == "discord"
|
||||||
|
assert result.oauth_id == "discord-456"
|
||||||
|
assert result.email == "discord@example.com"
|
||||||
|
assert result.display_name == "Discord Name"
|
||||||
|
assert str(result.user_id) == str(user.id)
|
||||||
|
|
||||||
|
async def test_raises_error_if_already_linked_to_same_user(self, db_session):
|
||||||
|
"""Test that linking same provider twice raises error.
|
||||||
|
|
||||||
|
A user cannot have the same provider linked multiple times.
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email="double-link@example.com",
|
||||||
|
display_name="Double Link",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-double",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Link Discord first time
|
||||||
|
discord_info = OAuthUserInfo(
|
||||||
|
provider="discord",
|
||||||
|
oauth_id="discord-first",
|
||||||
|
email="first@discord.com",
|
||||||
|
name="First",
|
||||||
|
)
|
||||||
|
await user_service.link_oauth_account(db_session, user, discord_info)
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Try to link same Discord account again
|
||||||
|
with pytest.raises(AccountLinkingError) as exc_info:
|
||||||
|
await user_service.link_oauth_account(db_session, user, discord_info)
|
||||||
|
|
||||||
|
assert "already linked to your account" in str(exc_info.value)
|
||||||
|
|
||||||
|
async def test_raises_error_if_linked_to_another_user(self, db_session):
|
||||||
|
"""Test that linking account already linked to another user raises error.
|
||||||
|
|
||||||
|
The same OAuth provider+ID cannot be linked to multiple users.
|
||||||
|
"""
|
||||||
|
# Create first user and link Discord
|
||||||
|
user1 = User(
|
||||||
|
email="user1@example.com",
|
||||||
|
display_name="User 1",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-user1",
|
||||||
|
)
|
||||||
|
db_session.add(user1)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user1)
|
||||||
|
|
||||||
|
discord_info = OAuthUserInfo(
|
||||||
|
provider="discord",
|
||||||
|
oauth_id="shared-discord",
|
||||||
|
email="shared@discord.com",
|
||||||
|
name="Shared",
|
||||||
|
)
|
||||||
|
await user_service.link_oauth_account(db_session, user1, discord_info)
|
||||||
|
|
||||||
|
# Create second user
|
||||||
|
user2 = User(
|
||||||
|
email="user2@example.com",
|
||||||
|
display_name="User 2",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-user2",
|
||||||
|
)
|
||||||
|
db_session.add(user2)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user2)
|
||||||
|
|
||||||
|
# Try to link same Discord account to second user
|
||||||
|
with pytest.raises(AccountLinkingError) as exc_info:
|
||||||
|
await user_service.link_oauth_account(db_session, user2, discord_info)
|
||||||
|
|
||||||
|
assert "already linked to another user" in str(exc_info.value)
|
||||||
|
|
||||||
|
async def test_raises_error_if_linking_primary_provider(self, db_session):
|
||||||
|
"""Test that linking the same provider as primary raises error.
|
||||||
|
|
||||||
|
User cannot link Google if they already signed up with Google.
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email="google-primary@example.com",
|
||||||
|
display_name="Google Primary",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-primary-id",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Try to link another Google account
|
||||||
|
google_info = OAuthUserInfo(
|
||||||
|
provider="google",
|
||||||
|
oauth_id="google-different-id",
|
||||||
|
email="different@gmail.com",
|
||||||
|
name="Different",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AccountLinkingError) as exc_info:
|
||||||
|
await user_service.link_oauth_account(db_session, user, google_info)
|
||||||
|
|
||||||
|
assert "primary login provider" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnlinkOAuthAccount:
|
||||||
|
"""Tests for unlink_oauth_account method."""
|
||||||
|
|
||||||
|
async def test_unlinks_linked_account(self, db_session):
|
||||||
|
"""Test that unlink_oauth_account removes a linked account.
|
||||||
|
|
||||||
|
Links Discord then unlinks it successfully.
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email="unlink@example.com",
|
||||||
|
display_name="Unlink User",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-unlink",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Link Discord
|
||||||
|
discord_info = OAuthUserInfo(
|
||||||
|
provider="discord",
|
||||||
|
oauth_id="discord-unlink",
|
||||||
|
email="discord@unlink.com",
|
||||||
|
name="Discord Unlink",
|
||||||
|
)
|
||||||
|
await user_service.link_oauth_account(db_session, user, discord_info)
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
# Verify linked
|
||||||
|
assert len(user.linked_accounts) == 1
|
||||||
|
|
||||||
|
# Unlink
|
||||||
|
result = await user_service.unlink_oauth_account(db_session, user, "discord")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# Verify unlinked
|
||||||
|
linked = await user_service.get_linked_account(db_session, "discord", "discord-unlink")
|
||||||
|
assert linked is None
|
||||||
|
|
||||||
|
async def test_returns_false_if_not_linked(self, db_session):
|
||||||
|
"""Test that unlink returns False if provider isn't linked."""
|
||||||
|
user = User(
|
||||||
|
email="not-linked@example.com",
|
||||||
|
display_name="Not Linked",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-notlinked",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
result = await user_service.unlink_oauth_account(db_session, user, "discord")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
async def test_raises_error_if_unlinking_primary(self, db_session):
|
||||||
|
"""Test that unlinking primary provider raises error.
|
||||||
|
|
||||||
|
User cannot unlink their primary OAuth provider.
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email="primary-unlink@example.com",
|
||||||
|
display_name="Primary Unlink",
|
||||||
|
oauth_provider="google",
|
||||||
|
oauth_id="google-primary-unlink",
|
||||||
|
)
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
with pytest.raises(AccountLinkingError) as exc_info:
|
||||||
|
await user_service.unlink_oauth_account(db_session, user, "google")
|
||||||
|
|
||||||
|
assert "primary login provider" in str(exc_info.value)
|
||||||
64
backend/uv.lock
generated
64
backend/uv.lock
generated
@ -368,6 +368,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/cc/48/d9f421cb8da5afaa1a64570d9989e00fb7955e6acddc5a12979f7666ef60/coverage-7.13.1-py3-none-any.whl", hash = "sha256:2016745cb3ba554469d02819d78958b571792bb68e31302610e898f80dd3a573", size = 210722, upload-time = "2025-12-28T15:42:54.901Z" },
|
{ url = "https://files.pythonhosted.org/packages/cc/48/d9f421cb8da5afaa1a64570d9989e00fb7955e6acddc5a12979f7666ef60/coverage-7.13.1-py3-none-any.whl", hash = "sha256:2016745cb3ba554469d02819d78958b571792bb68e31302610e898f80dd3a573", size = 210722, upload-time = "2025-12-28T15:42:54.901Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dnspython"
|
||||||
|
version = "2.8.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "docker"
|
name = "docker"
|
||||||
version = "7.1.0"
|
version = "7.1.0"
|
||||||
@ -394,6 +403,32 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
|
{ url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "email-validator"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "dnspython" },
|
||||||
|
{ name = "idna" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fakeredis"
|
||||||
|
version = "2.33.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "redis" },
|
||||||
|
{ name = "sortedcontainers" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/5f/f9/57464119936414d60697fcbd32f38909bb5688b616ae13de6e98384433e0/fakeredis-2.33.0.tar.gz", hash = "sha256:d7bc9a69d21df108a6451bbffee23b3eba432c21a654afc7ff2d295428ec5770", size = 175187, upload-time = "2025-12-16T19:45:52.269Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6e/78/a850fed8aeef96d4a99043c90b818b2ed5419cd5b24a4049fd7cfb9f1471/fakeredis-2.33.0-py3-none-any.whl", hash = "sha256:de535f3f9ccde1c56672ab2fdd6a8efbc4f2619fc2f1acc87b8737177d71c965", size = 119605, upload-time = "2025-12-16T19:45:51.08Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.128.0"
|
version = "0.128.0"
|
||||||
@ -579,7 +614,9 @@ dependencies = [
|
|||||||
{ name = "alembic" },
|
{ name = "alembic" },
|
||||||
{ name = "asyncpg" },
|
{ name = "asyncpg" },
|
||||||
{ name = "bcrypt" },
|
{ name = "bcrypt" },
|
||||||
|
{ name = "email-validator" },
|
||||||
{ name = "fastapi" },
|
{ name = "fastapi" },
|
||||||
|
{ name = "httpx" },
|
||||||
{ name = "passlib" },
|
{ name = "passlib" },
|
||||||
{ name = "psycopg2-binary" },
|
{ name = "psycopg2-binary" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
@ -595,12 +632,14 @@ dependencies = [
|
|||||||
dev = [
|
dev = [
|
||||||
{ name = "beautifulsoup4" },
|
{ name = "beautifulsoup4" },
|
||||||
{ name = "black" },
|
{ name = "black" },
|
||||||
|
{ name = "fakeredis" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
{ name = "mypy" },
|
{ name = "mypy" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
{ name = "pytest-cov" },
|
{ name = "pytest-cov" },
|
||||||
{ name = "requests" },
|
{ name = "requests" },
|
||||||
|
{ name = "respx" },
|
||||||
{ name = "ruff" },
|
{ name = "ruff" },
|
||||||
{ name = "testcontainers", extra = ["redis"] },
|
{ name = "testcontainers", extra = ["redis"] },
|
||||||
]
|
]
|
||||||
@ -610,7 +649,9 @@ requires-dist = [
|
|||||||
{ name = "alembic", specifier = ">=1.18.1" },
|
{ name = "alembic", specifier = ">=1.18.1" },
|
||||||
{ name = "asyncpg", specifier = ">=0.31.0" },
|
{ name = "asyncpg", specifier = ">=0.31.0" },
|
||||||
{ name = "bcrypt", specifier = ">=5.0.0" },
|
{ name = "bcrypt", specifier = ">=5.0.0" },
|
||||||
|
{ name = "email-validator", specifier = ">=2.3.0" },
|
||||||
{ name = "fastapi", specifier = ">=0.128.0" },
|
{ name = "fastapi", specifier = ">=0.128.0" },
|
||||||
|
{ name = "httpx", specifier = ">=0.28.1" },
|
||||||
{ name = "passlib", specifier = ">=1.7.4" },
|
{ name = "passlib", specifier = ">=1.7.4" },
|
||||||
{ name = "psycopg2-binary", specifier = ">=2.9.11" },
|
{ name = "psycopg2-binary", specifier = ">=2.9.11" },
|
||||||
{ name = "pydantic", specifier = ">=2.12.5" },
|
{ name = "pydantic", specifier = ">=2.12.5" },
|
||||||
@ -626,12 +667,14 @@ requires-dist = [
|
|||||||
dev = [
|
dev = [
|
||||||
{ name = "beautifulsoup4", specifier = ">=4.12.0" },
|
{ name = "beautifulsoup4", specifier = ">=4.12.0" },
|
||||||
{ name = "black", specifier = ">=26.1.0" },
|
{ name = "black", specifier = ">=26.1.0" },
|
||||||
|
{ name = "fakeredis", specifier = ">=2.33.0" },
|
||||||
{ name = "httpx", specifier = ">=0.28.1" },
|
{ name = "httpx", specifier = ">=0.28.1" },
|
||||||
{ name = "mypy", specifier = ">=1.19.1" },
|
{ name = "mypy", specifier = ">=1.19.1" },
|
||||||
{ name = "pytest", specifier = ">=9.0.2" },
|
{ name = "pytest", specifier = ">=9.0.2" },
|
||||||
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
|
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
|
||||||
{ name = "pytest-cov", specifier = ">=7.0.0" },
|
{ name = "pytest-cov", specifier = ">=7.0.0" },
|
||||||
{ name = "requests", specifier = ">=2.31.0" },
|
{ name = "requests", specifier = ">=2.31.0" },
|
||||||
|
{ name = "respx", specifier = ">=0.22.0" },
|
||||||
{ name = "ruff", specifier = ">=0.14.14" },
|
{ name = "ruff", specifier = ">=0.14.14" },
|
||||||
{ name = "testcontainers", extras = ["postgres", "redis"], specifier = ">=4.0.0" },
|
{ name = "testcontainers", extras = ["postgres", "redis"], specifier = ">=4.0.0" },
|
||||||
]
|
]
|
||||||
@ -1105,6 +1148,18 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "respx"
|
||||||
|
version = "0.22.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "httpx" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/f4/7c/96bd0bc759cf009675ad1ee1f96535edcb11e9666b985717eb8c87192a95/respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91", size = 28439, upload-time = "2024-12-19T22:33:59.374Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8e/67/afbb0978d5399bc9ea200f1d4489a23c9a1dad4eee6376242b8182389c79/respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0", size = 25127, upload-time = "2024-12-19T22:33:57.837Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rsa"
|
name = "rsa"
|
||||||
version = "4.9.1"
|
version = "4.9.1"
|
||||||
@ -1164,6 +1219,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
|
{ url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sortedcontainers"
|
||||||
|
version = "2.4.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "soupsieve"
|
name = "soupsieve"
|
||||||
version = "2.8.3"
|
version = "2.8.3"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user