Implement Phase 2: Authentication system
Complete OAuth-based authentication with JWT session management:
Core Services:
- JWT service for access/refresh token creation and verification
- Token store with Redis-backed refresh token revocation
- User service for CRUD operations and OAuth-based creation
- Google and Discord OAuth services with full flow support
API Endpoints:
- GET /api/auth/{google,discord} - Start OAuth flows
- GET /api/auth/{google,discord}/callback - Handle OAuth callbacks
- POST /api/auth/refresh - Exchange refresh token for new access token
- POST /api/auth/logout - Revoke single refresh token
- POST /api/auth/logout-all - Revoke all user sessions
- GET/PATCH /api/users/me - User profile management
- GET /api/users/me/linked-accounts - List OAuth providers
- GET /api/users/me/sessions - Count active sessions
Infrastructure:
- Pydantic schemas for auth/user request/response models
- FastAPI dependencies (get_current_user, get_current_premium_user)
- OAuthLinkedAccount model for multi-provider support
- Alembic migration for oauth_linked_accounts table
Dependencies added: email-validator, fakeredis (dev), respx (dev)
84 new tests, 1058 total passing
This commit is contained in:
parent
4ddc9b8c30
commit
996c43fbd9
14
backend/app/api/__init__.py
Normal file
14
backend/app/api/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""API routers and dependencies for Mantimon TCG.
|
||||
|
||||
This package contains FastAPI routers and common dependencies
|
||||
for the REST API.
|
||||
|
||||
Routers:
|
||||
- auth: OAuth login, token refresh, logout
|
||||
- users: User profile management
|
||||
|
||||
Dependencies:
|
||||
- get_current_user: Extract and validate user from JWT
|
||||
- get_current_active_user: Ensure user exists
|
||||
- get_current_premium_user: Require premium subscription
|
||||
"""
|
||||
391
backend/app/api/auth.py
Normal file
391
backend/app/api/auth.py
Normal file
@ -0,0 +1,391 @@
|
||||
"""Authentication API router for Mantimon TCG.
|
||||
|
||||
This module provides endpoints for OAuth authentication:
|
||||
- OAuth login redirects (Google, Discord)
|
||||
- OAuth callbacks (token exchange)
|
||||
- Token refresh
|
||||
- Logout
|
||||
|
||||
OAuth Flow:
|
||||
1. Client calls GET /auth/{provider} to get redirect URL
|
||||
2. Client redirects user to OAuth provider
|
||||
3. Provider redirects back to GET /auth/{provider}/callback
|
||||
4. Server exchanges code for tokens, creates/fetches user
|
||||
5. Server returns JWT access + refresh tokens
|
||||
|
||||
Example:
|
||||
# Start OAuth flow
|
||||
GET /api/auth/google?redirect_uri=https://play.mantimon.com/login/callback
|
||||
|
||||
# After OAuth callback, refresh tokens
|
||||
POST /api/auth/refresh
|
||||
{"refresh_token": "..."}
|
||||
|
||||
# Logout
|
||||
POST /api/auth/logout
|
||||
{"refresh_token": "..."}
|
||||
"""
|
||||
|
||||
import secrets
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from app.api.deps import CurrentUser, DbSession
|
||||
from app.db.redis import get_redis
|
||||
from app.schemas.auth import RefreshTokenRequest, TokenResponse
|
||||
from app.services.jwt_service import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
get_refresh_token_expiration,
|
||||
get_token_expiration_seconds,
|
||||
verify_refresh_token,
|
||||
)
|
||||
from app.services.oauth.discord import DiscordOAuthError, discord_oauth
|
||||
from app.services.oauth.google import GoogleOAuthError, google_oauth
|
||||
from app.services.token_store import token_store
|
||||
from app.services.user_service import user_service
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# OAuth state TTL (5 minutes)
|
||||
OAUTH_STATE_TTL = 300
|
||||
|
||||
|
||||
async def _store_oauth_state(state: str, provider: str, redirect_uri: str) -> None:
|
||||
"""Store OAuth state in Redis for CSRF validation."""
|
||||
async with get_redis() as redis:
|
||||
key = f"oauth_state:{state}"
|
||||
value = f"{provider}:{redirect_uri}"
|
||||
await redis.setex(key, OAUTH_STATE_TTL, value)
|
||||
|
||||
|
||||
async def _validate_oauth_state(state: str, provider: str) -> str | None:
|
||||
"""Validate and consume OAuth state, returning redirect_uri if valid."""
|
||||
async with get_redis() as redis:
|
||||
key = f"oauth_state:{state}"
|
||||
value = await redis.get(key)
|
||||
if not value:
|
||||
return None
|
||||
|
||||
# Delete state (one-time use)
|
||||
await redis.delete(key)
|
||||
|
||||
# Parse and validate
|
||||
stored_provider, redirect_uri = value.split(":", 1)
|
||||
if stored_provider != provider:
|
||||
return None
|
||||
|
||||
return redirect_uri
|
||||
|
||||
|
||||
async def _create_tokens_for_user(user_id) -> TokenResponse:
|
||||
"""Create access and refresh tokens for a user."""
|
||||
from uuid import UUID
|
||||
|
||||
if isinstance(user_id, str):
|
||||
user_id = UUID(user_id)
|
||||
|
||||
access_token = create_access_token(user_id)
|
||||
refresh_token, jti = create_refresh_token(user_id)
|
||||
|
||||
# Store refresh token in Redis for revocation tracking
|
||||
expires_at = get_refresh_token_expiration()
|
||||
await token_store.store_refresh_token(user_id, jti, expires_at)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=get_token_expiration_seconds(),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Google OAuth
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/google")
|
||||
async def google_auth_redirect(
|
||||
redirect_uri: str = Query(..., description="URI to redirect to after OAuth completes"),
|
||||
) -> RedirectResponse:
|
||||
"""Start Google OAuth flow.
|
||||
|
||||
Redirects the user to Google's OAuth consent screen.
|
||||
|
||||
Args:
|
||||
redirect_uri: Where to redirect after successful authentication.
|
||||
This is YOUR app's callback, not the OAuth callback.
|
||||
|
||||
Returns:
|
||||
Redirect to Google OAuth authorization URL.
|
||||
|
||||
Raises:
|
||||
HTTPException: 501 if Google OAuth is not configured.
|
||||
"""
|
||||
if not google_oauth.is_configured():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Google OAuth is not configured",
|
||||
)
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(32)
|
||||
await _store_oauth_state(state, "google", redirect_uri)
|
||||
|
||||
# Build OAuth callback URL (our server endpoint)
|
||||
# The redirect_uri param here is where Google sends the code
|
||||
oauth_callback = "/api/auth/google/callback"
|
||||
|
||||
# Get authorization URL
|
||||
auth_url = google_oauth.get_authorization_url(
|
||||
redirect_uri=oauth_callback,
|
||||
state=state,
|
||||
)
|
||||
|
||||
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
|
||||
|
||||
|
||||
@router.get("/google/callback")
|
||||
async def google_auth_callback(
|
||||
db: DbSession,
|
||||
code: str = Query(..., description="Authorization code from Google"),
|
||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||
) -> TokenResponse:
|
||||
"""Handle Google OAuth callback.
|
||||
|
||||
Exchanges the authorization code for tokens and creates/fetches the user.
|
||||
|
||||
Args:
|
||||
code: Authorization code from Google.
|
||||
state: State parameter for CSRF validation.
|
||||
|
||||
Returns:
|
||||
JWT access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if state is invalid or OAuth fails.
|
||||
"""
|
||||
# Validate state
|
||||
redirect_uri = await _validate_oauth_state(state, "google")
|
||||
if redirect_uri is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired state parameter",
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange code for user info
|
||||
oauth_callback = "/api/auth/google/callback"
|
||||
user_info = await google_oauth.get_user_info(code, oauth_callback)
|
||||
|
||||
# Get or create user
|
||||
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||
|
||||
# Update last login
|
||||
await user_service.update_last_login(db, user)
|
||||
|
||||
# Create tokens
|
||||
return await _create_tokens_for_user(user.id)
|
||||
|
||||
except GoogleOAuthError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Google OAuth failed: {e}",
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Discord OAuth
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/discord")
|
||||
async def discord_auth_redirect(
|
||||
redirect_uri: str = Query(..., description="URI to redirect to after OAuth completes"),
|
||||
) -> RedirectResponse:
|
||||
"""Start Discord OAuth flow.
|
||||
|
||||
Redirects the user to Discord's OAuth consent screen.
|
||||
|
||||
Args:
|
||||
redirect_uri: Where to redirect after successful authentication.
|
||||
|
||||
Returns:
|
||||
Redirect to Discord OAuth authorization URL.
|
||||
|
||||
Raises:
|
||||
HTTPException: 501 if Discord OAuth is not configured.
|
||||
"""
|
||||
if not discord_oauth.is_configured():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Discord OAuth is not configured",
|
||||
)
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(32)
|
||||
await _store_oauth_state(state, "discord", redirect_uri)
|
||||
|
||||
# Build OAuth callback URL
|
||||
oauth_callback = "/api/auth/discord/callback"
|
||||
|
||||
# Get authorization URL
|
||||
auth_url = discord_oauth.get_authorization_url(
|
||||
redirect_uri=oauth_callback,
|
||||
state=state,
|
||||
)
|
||||
|
||||
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
|
||||
|
||||
|
||||
@router.get("/discord/callback")
|
||||
async def discord_auth_callback(
|
||||
db: DbSession,
|
||||
code: str = Query(..., description="Authorization code from Discord"),
|
||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||
) -> TokenResponse:
|
||||
"""Handle Discord OAuth callback.
|
||||
|
||||
Exchanges the authorization code for tokens and creates/fetches the user.
|
||||
|
||||
Args:
|
||||
code: Authorization code from Discord.
|
||||
state: State parameter for CSRF validation.
|
||||
|
||||
Returns:
|
||||
JWT access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if state is invalid or OAuth fails.
|
||||
"""
|
||||
# Validate state
|
||||
redirect_uri = await _validate_oauth_state(state, "discord")
|
||||
if redirect_uri is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired state parameter",
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange code for user info
|
||||
oauth_callback = "/api/auth/discord/callback"
|
||||
user_info = await discord_oauth.get_user_info(code, oauth_callback)
|
||||
|
||||
# Get or create user
|
||||
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||
|
||||
# Update last login
|
||||
await user_service.update_last_login(db, user)
|
||||
|
||||
# Create tokens
|
||||
return await _create_tokens_for_user(user.id)
|
||||
|
||||
except DiscordOAuthError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Discord OAuth failed: {e}",
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Token Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_tokens(
|
||||
db: DbSession,
|
||||
request: RefreshTokenRequest,
|
||||
) -> TokenResponse:
|
||||
"""Refresh access token using refresh token.
|
||||
|
||||
Validates the refresh token and issues a new access token.
|
||||
The refresh token itself is NOT rotated (same token can be used
|
||||
until it expires or is revoked).
|
||||
|
||||
Args:
|
||||
request: Contains the refresh token.
|
||||
|
||||
Returns:
|
||||
New JWT access token (same refresh token).
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if refresh token is invalid or revoked.
|
||||
"""
|
||||
# Verify refresh token
|
||||
result = verify_refresh_token(request.refresh_token)
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
|
||||
user_id, jti = result
|
||||
|
||||
# Check if token is revoked
|
||||
is_valid = await token_store.is_token_valid(user_id, jti)
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has been revoked",
|
||||
)
|
||||
|
||||
# Verify user still exists
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
# Create new access token (keep same refresh token)
|
||||
access_token = create_access_token(user_id)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=request.refresh_token, # Return same refresh token
|
||||
expires_in=get_token_expiration_seconds(),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def logout(
|
||||
request: RefreshTokenRequest,
|
||||
) -> None:
|
||||
"""Logout by revoking the refresh token.
|
||||
|
||||
The access token will continue to work until it expires,
|
||||
but the refresh token cannot be used to get new tokens.
|
||||
|
||||
Args:
|
||||
request: Contains the refresh token to revoke.
|
||||
"""
|
||||
# Verify refresh token to get user_id and jti
|
||||
result = verify_refresh_token(request.refresh_token)
|
||||
if result is None:
|
||||
# Token is invalid, but that's fine - user is effectively logged out
|
||||
return
|
||||
|
||||
user_id, jti = result
|
||||
|
||||
# Revoke the token
|
||||
await token_store.revoke_token(user_id, jti)
|
||||
|
||||
|
||||
@router.post("/logout-all", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def logout_all(
|
||||
user: CurrentUser,
|
||||
) -> None:
|
||||
"""Logout from all devices by revoking all refresh tokens.
|
||||
|
||||
Requires authentication (uses current access token).
|
||||
All refresh tokens for the user will be revoked.
|
||||
|
||||
Args:
|
||||
user: Current authenticated user.
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
await token_store.revoke_all_user_tokens(user_id)
|
||||
172
backend/app/api/deps.py
Normal file
172
backend/app/api/deps.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""FastAPI dependencies for Mantimon TCG API.
|
||||
|
||||
This module provides dependency injection functions for authentication
|
||||
and database access in API endpoints.
|
||||
|
||||
Usage:
|
||||
from app.api.deps import get_current_user, get_db
|
||||
|
||||
@router.get("/me")
|
||||
async def get_me(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return user
|
||||
|
||||
Dependencies:
|
||||
- get_db: Async database session
|
||||
- get_current_user: Authenticated user from JWT (required)
|
||||
- get_optional_user: Authenticated user or None
|
||||
- get_current_premium_user: User with active premium
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_session_dependency
|
||||
from app.db.models import User
|
||||
from app.services.jwt_service import verify_access_token
|
||||
from app.services.user_service import user_service
|
||||
|
||||
# OAuth2 scheme for extracting Bearer token from Authorization header
|
||||
# tokenUrl is for OpenAPI docs - points to where tokens are obtained
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="/api/auth/token", # For OpenAPI docs
|
||||
auto_error=True, # Raise 401 if no token
|
||||
)
|
||||
|
||||
oauth2_scheme_optional = OAuth2PasswordBearer(
|
||||
tokenUrl="/api/auth/token",
|
||||
auto_error=False, # Return None if no token
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
"""Get async database session.
|
||||
|
||||
Yields:
|
||||
AsyncSession for database operations.
|
||||
|
||||
Example:
|
||||
@router.get("/items")
|
||||
async def get_items(db: AsyncSession = Depends(get_db)):
|
||||
...
|
||||
"""
|
||||
async with get_session_dependency() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> User:
|
||||
"""Get the current authenticated user from JWT token.
|
||||
|
||||
Validates the access token and fetches the user from database.
|
||||
|
||||
Args:
|
||||
token: JWT access token from Authorization header.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
The authenticated User.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if token is invalid or user not found.
|
||||
|
||||
Example:
|
||||
@router.get("/me")
|
||||
async def get_me(user: User = Depends(get_current_user)):
|
||||
return {"email": user.email}
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Verify token and extract user ID
|
||||
user_id = verify_access_token(token)
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
|
||||
# Fetch user from database
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_optional_user(
|
||||
token: Annotated[str | None, Depends(oauth2_scheme_optional)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> User | None:
|
||||
"""Get the current user if authenticated, or None.
|
||||
|
||||
Useful for endpoints that work both with and without authentication,
|
||||
but may provide additional features for authenticated users.
|
||||
|
||||
Args:
|
||||
token: JWT access token or None.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
The authenticated User, or None if not authenticated.
|
||||
|
||||
Example:
|
||||
@router.get("/cards")
|
||||
async def get_cards(user: User | None = Depends(get_optional_user)):
|
||||
if user:
|
||||
# Show user's collection
|
||||
else:
|
||||
# Show public cards
|
||||
"""
|
||||
if token is None:
|
||||
return None
|
||||
|
||||
user_id = verify_access_token(token)
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
return await user_service.get_by_id(db, user_id)
|
||||
|
||||
|
||||
async def get_current_premium_user(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
) -> User:
|
||||
"""Get the current user and verify they have premium.
|
||||
|
||||
Args:
|
||||
user: The authenticated user.
|
||||
|
||||
Returns:
|
||||
The authenticated User with active premium.
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user doesn't have active premium.
|
||||
|
||||
Example:
|
||||
@router.post("/decks")
|
||||
async def create_unlimited_decks(
|
||||
user: User = Depends(get_current_premium_user)
|
||||
):
|
||||
# Only premium users can have unlimited decks
|
||||
...
|
||||
"""
|
||||
if not user.has_active_premium:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Premium subscription required",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
# Type aliases for cleaner endpoint signatures
|
||||
CurrentUser = Annotated[User, Depends(get_current_user)]
|
||||
OptionalUser = Annotated[User | None, Depends(get_optional_user)]
|
||||
PremiumUser = Annotated[User, Depends(get_current_premium_user)]
|
||||
DbSession = Annotated[AsyncSession, Depends(get_db)]
|
||||
129
backend/app/api/users.py
Normal file
129
backend/app/api/users.py
Normal file
@ -0,0 +1,129 @@
|
||||
"""User API router for Mantimon TCG.
|
||||
|
||||
This module provides endpoints for user profile management:
|
||||
- Get current user profile
|
||||
- Update profile (display name, avatar)
|
||||
- List linked OAuth accounts
|
||||
- Session management
|
||||
|
||||
All endpoints require authentication.
|
||||
|
||||
Example:
|
||||
# Get current user
|
||||
GET /api/users/me
|
||||
Authorization: Bearer <access_token>
|
||||
|
||||
# Update profile
|
||||
PATCH /api/users/me
|
||||
{"display_name": "NewName"}
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.deps import CurrentUser, DbSession
|
||||
from app.schemas.user import UserResponse, UserUpdate
|
||||
from app.services.token_store import token_store
|
||||
from app.services.user_service import user_service
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
class LinkedAccountResponse(BaseModel):
|
||||
"""Response for a linked OAuth account."""
|
||||
|
||||
provider: str = Field(..., description="OAuth provider name")
|
||||
email: str | None = Field(None, description="Email from this provider")
|
||||
linked_at: str = Field(..., description="When account was linked (ISO format)")
|
||||
|
||||
|
||||
class SessionsResponse(BaseModel):
|
||||
"""Response for active sessions count."""
|
||||
|
||||
active_sessions: int = Field(..., description="Number of active sessions")
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user_profile(
|
||||
user: CurrentUser,
|
||||
) -> UserResponse:
|
||||
"""Get the current user's profile.
|
||||
|
||||
Returns:
|
||||
User profile information.
|
||||
"""
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserResponse)
|
||||
async def update_current_user_profile(
|
||||
user: CurrentUser,
|
||||
db: DbSession,
|
||||
update_data: UserUpdate,
|
||||
) -> UserResponse:
|
||||
"""Update the current user's profile.
|
||||
|
||||
Only provided fields are updated.
|
||||
|
||||
Args:
|
||||
update_data: Fields to update (display_name, avatar_url).
|
||||
|
||||
Returns:
|
||||
Updated user profile.
|
||||
"""
|
||||
updated_user = await user_service.update(db, user, update_data)
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
|
||||
@router.get("/me/linked-accounts", response_model=list[LinkedAccountResponse])
|
||||
async def get_linked_accounts(
|
||||
user: CurrentUser,
|
||||
) -> list[LinkedAccountResponse]:
|
||||
"""Get all OAuth accounts linked to the current user.
|
||||
|
||||
Returns the primary OAuth provider plus any additional linked accounts.
|
||||
|
||||
Returns:
|
||||
List of linked OAuth accounts.
|
||||
"""
|
||||
accounts = []
|
||||
|
||||
# Add primary OAuth account
|
||||
accounts.append(
|
||||
LinkedAccountResponse(
|
||||
provider=user.oauth_provider,
|
||||
email=user.email,
|
||||
linked_at=user.created_at.isoformat(),
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional linked accounts
|
||||
for linked in user.linked_accounts:
|
||||
accounts.append(
|
||||
LinkedAccountResponse(
|
||||
provider=linked.provider,
|
||||
email=linked.email,
|
||||
linked_at=linked.linked_at.isoformat(),
|
||||
)
|
||||
)
|
||||
|
||||
return accounts
|
||||
|
||||
|
||||
@router.get("/me/sessions", response_model=SessionsResponse)
|
||||
async def get_active_sessions(
|
||||
user: CurrentUser,
|
||||
) -> SessionsResponse:
|
||||
"""Get the number of active sessions for the current user.
|
||||
|
||||
Each session corresponds to a valid refresh token.
|
||||
|
||||
Returns:
|
||||
Number of active sessions.
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
count = await token_store.get_active_session_count(user_id)
|
||||
|
||||
return SessionsResponse(active_sessions=count)
|
||||
@ -0,0 +1,69 @@
|
||||
"""add_oauth_linked_accounts
|
||||
|
||||
Revision ID: 5ce887128ab1
|
||||
Revises: ab8a0039fe55
|
||||
Create Date: 2026-01-27 16:42:12.335987
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "5ce887128ab1"
|
||||
down_revision: str | Sequence[str] | None = "ab8a0039fe55"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"oauth_linked_accounts",
|
||||
sa.Column("user_id", sa.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("provider", sa.String(length=20), nullable=False),
|
||||
sa.Column("oauth_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("email", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_name", sa.String(length=100), nullable=True),
|
||||
sa.Column("avatar_url", sa.String(length=500), nullable=True),
|
||||
sa.Column(
|
||||
"linked_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False
|
||||
),
|
||||
sa.Column("id", sa.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_linked_accounts_provider_oauth_id",
|
||||
"oauth_linked_accounts",
|
||||
["provider", "oauth_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_linked_accounts_user_id"), "oauth_linked_accounts", ["user_id"], unique=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_oauth_linked_accounts_user_id"), table_name="oauth_linked_accounts")
|
||||
op.drop_index("ix_oauth_linked_accounts_provider_oauth_id", table_name="oauth_linked_accounts")
|
||||
op.drop_table("oauth_linked_accounts")
|
||||
# ### end Alembic commands ###
|
||||
@ -25,11 +25,14 @@ from app.db.models.campaign import CampaignProgress
|
||||
from app.db.models.collection import CardSource, Collection
|
||||
from app.db.models.deck import Deck
|
||||
from app.db.models.game import ActiveGame, EndReason, GameHistory, GameType
|
||||
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||
from app.db.models.user import User
|
||||
|
||||
__all__ = [
|
||||
# User
|
||||
"User",
|
||||
# OAuth
|
||||
"OAuthLinkedAccount",
|
||||
# Collection
|
||||
"Collection",
|
||||
"CardSource",
|
||||
|
||||
122
backend/app/db/models/oauth_account.py
Normal file
122
backend/app/db/models/oauth_account.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""OAuth linked account model for Mantimon TCG.
|
||||
|
||||
This module defines the OAuthLinkedAccount model for supporting multiple
|
||||
OAuth providers per user (account linking).
|
||||
|
||||
A user can have multiple linked accounts (e.g., both Google and Discord),
|
||||
allowing them to log in with either provider.
|
||||
|
||||
Example:
|
||||
# User links Discord to their existing Google account
|
||||
linked_account = OAuthLinkedAccount(
|
||||
user_id=user.id,
|
||||
provider="discord",
|
||||
oauth_id="123456789",
|
||||
email="player@example.com"
|
||||
)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db.models.user import User
|
||||
|
||||
|
||||
class OAuthLinkedAccount(Base):
|
||||
"""Linked OAuth account for multi-provider authentication.
|
||||
|
||||
Allows users to link multiple OAuth providers to a single account,
|
||||
enabling login via any linked provider.
|
||||
|
||||
The User model still has oauth_provider/oauth_id for the "primary"
|
||||
provider (the one used to create the account). This table tracks
|
||||
additional linked providers.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier (UUID).
|
||||
user_id: Foreign key to the user who owns this linked account.
|
||||
provider: OAuth provider name ('google', 'discord').
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
email: Email address from this OAuth provider (may differ from user's primary email).
|
||||
display_name: Display name from this OAuth provider.
|
||||
avatar_url: Avatar URL from this OAuth provider.
|
||||
linked_at: When this account was linked.
|
||||
|
||||
Relationships:
|
||||
user: The User who owns this linked account.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_linked_accounts"
|
||||
|
||||
# Foreign key to user
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
doc="User who owns this linked account",
|
||||
)
|
||||
|
||||
# OAuth provider info
|
||||
provider: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
doc="OAuth provider name (google, discord)",
|
||||
)
|
||||
oauth_id: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
doc="Unique ID from OAuth provider",
|
||||
)
|
||||
|
||||
# Additional info from provider
|
||||
email: Mapped[str | None] = mapped_column(
|
||||
String(255),
|
||||
nullable=True,
|
||||
doc="Email from this OAuth provider",
|
||||
)
|
||||
display_name: Mapped[str | None] = mapped_column(
|
||||
String(100),
|
||||
nullable=True,
|
||||
doc="Display name from this OAuth provider",
|
||||
)
|
||||
avatar_url: Mapped[str | None] = mapped_column(
|
||||
String(500),
|
||||
nullable=True,
|
||||
doc="Avatar URL from this OAuth provider",
|
||||
)
|
||||
|
||||
# When linked
|
||||
linked_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
doc="When this account was linked",
|
||||
)
|
||||
|
||||
# Relationship back to user
|
||||
user: Mapped["User"] = relationship(
|
||||
"User",
|
||||
back_populates="linked_accounts",
|
||||
)
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
# Each OAuth provider+ID can only be linked to one user
|
||||
Index(
|
||||
"ix_oauth_linked_accounts_provider_oauth_id",
|
||||
"provider",
|
||||
"oauth_id",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<OAuthLinkedAccount(user_id={self.user_id!r}, provider={self.provider!r})>"
|
||||
@ -24,6 +24,7 @@ if TYPE_CHECKING:
|
||||
from app.db.models.campaign import CampaignProgress
|
||||
from app.db.models.collection import Collection
|
||||
from app.db.models.deck import Deck
|
||||
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||
|
||||
|
||||
class User(Base):
|
||||
@ -45,6 +46,7 @@ class User(Base):
|
||||
decks: User's deck configurations.
|
||||
collection: User's card collection.
|
||||
campaign_progress: User's campaign state.
|
||||
linked_accounts: Additional linked OAuth providers.
|
||||
"""
|
||||
|
||||
__tablename__ = "users"
|
||||
@ -120,6 +122,12 @@ class User(Base):
|
||||
uselist=False,
|
||||
lazy="selectin",
|
||||
)
|
||||
linked_accounts: Mapped[list["OAuthLinkedAccount"]] = relationship(
|
||||
"OAuthLinkedAccount",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (Index("ix_users_oauth", "oauth_provider", "oauth_id", unique=True),)
|
||||
|
||||
@ -18,6 +18,8 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.auth import router as auth_router
|
||||
from app.api.users import router as users_router
|
||||
from app.config import settings
|
||||
from app.db import close_db, init_db
|
||||
from app.db.redis import close_redis, init_redis
|
||||
@ -159,9 +161,11 @@ async def readiness_check() -> dict[str, str | int]:
|
||||
|
||||
|
||||
# === API Routers ===
|
||||
# TODO: Add routers in Phase 2
|
||||
# from app.api import auth, games, cards, decks, campaign
|
||||
# app.include_router(auth.router, prefix="/api/auth", tags=["auth"])
|
||||
app.include_router(auth_router, prefix="/api")
|
||||
app.include_router(users_router, prefix="/api")
|
||||
|
||||
# TODO: Add remaining routers in future phases
|
||||
# from app.api import cards, decks, games, campaign
|
||||
# app.include_router(cards.router, prefix="/api/cards", tags=["cards"])
|
||||
# app.include_router(decks.router, prefix="/api/decks", tags=["decks"])
|
||||
# app.include_router(games.router, prefix="/api/games", tags=["games"])
|
||||
|
||||
32
backend/app/schemas/__init__.py
Normal file
32
backend/app/schemas/__init__.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""Pydantic schemas for Mantimon TCG API.
|
||||
|
||||
This package contains request/response models for all API endpoints.
|
||||
"""
|
||||
|
||||
from app.schemas.auth import (
|
||||
OAuthState,
|
||||
RefreshTokenRequest,
|
||||
TokenPayload,
|
||||
TokenResponse,
|
||||
TokenType,
|
||||
)
|
||||
from app.schemas.user import (
|
||||
OAuthUserInfo,
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
UserUpdate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Auth schemas
|
||||
"TokenType",
|
||||
"TokenPayload",
|
||||
"TokenResponse",
|
||||
"RefreshTokenRequest",
|
||||
"OAuthState",
|
||||
# User schemas
|
||||
"UserResponse",
|
||||
"UserCreate",
|
||||
"UserUpdate",
|
||||
"OAuthUserInfo",
|
||||
]
|
||||
87
backend/app/schemas/auth.py
Normal file
87
backend/app/schemas/auth.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""Authentication schemas for Mantimon TCG.
|
||||
|
||||
This module defines Pydantic models for JWT tokens and authentication
|
||||
responses used throughout the auth system.
|
||||
|
||||
Example:
|
||||
token_payload = TokenPayload(
|
||||
sub="550e8400-e29b-41d4-a716-446655440000",
|
||||
exp=datetime.now(UTC) + timedelta(minutes=30),
|
||||
iat=datetime.now(UTC),
|
||||
type="access"
|
||||
)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TokenType(str, Enum):
|
||||
"""Type of JWT token."""
|
||||
|
||||
ACCESS = "access"
|
||||
REFRESH = "refresh"
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT token payload structure.
|
||||
|
||||
Attributes:
|
||||
sub: Subject - the user ID as a string (UUID format).
|
||||
exp: Expiration timestamp.
|
||||
iat: Issued-at timestamp.
|
||||
type: Token type (access or refresh).
|
||||
jti: JWT ID - unique identifier for refresh token tracking.
|
||||
"""
|
||||
|
||||
sub: str = Field(..., description="User ID (UUID as string)")
|
||||
exp: datetime = Field(..., description="Token expiration timestamp")
|
||||
iat: datetime = Field(..., description="Token issued-at timestamp")
|
||||
type: TokenType = Field(..., description="Token type (access/refresh)")
|
||||
jti: str | None = Field(default=None, description="JWT ID for refresh token tracking")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Response containing JWT tokens after successful authentication.
|
||||
|
||||
Attributes:
|
||||
access_token: Short-lived JWT for API authentication.
|
||||
refresh_token: Longer-lived JWT for obtaining new access tokens.
|
||||
token_type: Always "bearer" for OAuth 2.0 compatibility.
|
||||
expires_in: Access token lifetime in seconds.
|
||||
"""
|
||||
|
||||
access_token: str = Field(..., description="JWT access token")
|
||||
refresh_token: str = Field(..., description="JWT refresh token")
|
||||
token_type: str = Field(default="bearer", description="Token type (always bearer)")
|
||||
expires_in: int = Field(..., description="Access token lifetime in seconds")
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Request to refresh an access token.
|
||||
|
||||
Attributes:
|
||||
refresh_token: The refresh token to exchange for a new access token.
|
||||
"""
|
||||
|
||||
refresh_token: str = Field(..., description="Refresh token to exchange")
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
"""OAuth state parameter for CSRF protection.
|
||||
|
||||
Stored in Redis with short TTL during OAuth flow.
|
||||
|
||||
Attributes:
|
||||
state: Random string for CSRF protection.
|
||||
redirect_uri: Where to redirect after OAuth callback.
|
||||
provider: OAuth provider name.
|
||||
created_at: When the state was created.
|
||||
"""
|
||||
|
||||
state: str = Field(..., description="Random CSRF protection string")
|
||||
redirect_uri: str = Field(..., description="Post-auth redirect URI")
|
||||
provider: str = Field(..., description="OAuth provider (google, discord)")
|
||||
created_at: datetime = Field(..., description="State creation timestamp")
|
||||
114
backend/app/schemas/user.py
Normal file
114
backend/app/schemas/user.py
Normal file
@ -0,0 +1,114 @@
|
||||
"""User schemas for Mantimon TCG.
|
||||
|
||||
This module defines Pydantic models for user-related API requests
|
||||
and responses.
|
||||
|
||||
Example:
|
||||
user_response = UserResponse(
|
||||
id="550e8400-e29b-41d4-a716-446655440000",
|
||||
email="player@example.com",
|
||||
display_name="Player1",
|
||||
is_premium=False,
|
||||
created_at=datetime.now(UTC)
|
||||
)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""Public user information returned by API endpoints.
|
||||
|
||||
Attributes:
|
||||
id: User's unique identifier.
|
||||
email: User's email address.
|
||||
display_name: User's public display name.
|
||||
avatar_url: URL to user's avatar image.
|
||||
is_premium: Whether user has active premium subscription.
|
||||
premium_until: When premium subscription expires (if premium).
|
||||
created_at: When the account was created.
|
||||
"""
|
||||
|
||||
id: UUID = Field(..., description="User ID")
|
||||
email: EmailStr = Field(..., description="User's email address")
|
||||
display_name: str = Field(..., description="Public display name")
|
||||
avatar_url: str | None = Field(default=None, description="Avatar image URL")
|
||||
is_premium: bool = Field(default=False, description="Premium subscription status")
|
||||
premium_until: datetime | None = Field(default=None, description="Premium expiration date")
|
||||
created_at: datetime = Field(..., description="Account creation date")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
"""Internal schema for creating a user from OAuth data.
|
||||
|
||||
Not exposed via API - used internally by auth service.
|
||||
|
||||
Attributes:
|
||||
email: User's email from OAuth provider.
|
||||
display_name: User's name from OAuth provider.
|
||||
avatar_url: Avatar URL from OAuth provider.
|
||||
oauth_provider: OAuth provider name (google, discord).
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
"""
|
||||
|
||||
email: EmailStr = Field(..., description="Email from OAuth provider")
|
||||
display_name: str = Field(..., max_length=50, description="Display name")
|
||||
avatar_url: str | None = Field(default=None, description="Avatar URL")
|
||||
oauth_provider: str = Field(..., description="OAuth provider (google, discord)")
|
||||
oauth_id: str = Field(..., description="Unique ID from OAuth provider")
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""Schema for updating user profile.
|
||||
|
||||
All fields are optional - only provided fields are updated.
|
||||
|
||||
Attributes:
|
||||
display_name: New display name.
|
||||
avatar_url: New avatar URL.
|
||||
"""
|
||||
|
||||
display_name: str | None = Field(
|
||||
default=None, min_length=1, max_length=50, description="New display name"
|
||||
)
|
||||
avatar_url: str | None = Field(default=None, description="New avatar URL")
|
||||
|
||||
|
||||
class OAuthUserInfo(BaseModel):
|
||||
"""Normalized user information from OAuth providers.
|
||||
|
||||
This provides a consistent structure regardless of whether
|
||||
the user authenticated via Google or Discord.
|
||||
|
||||
Attributes:
|
||||
provider: OAuth provider name.
|
||||
oauth_id: Unique ID from the provider.
|
||||
email: User's email address.
|
||||
name: User's display name from provider.
|
||||
avatar_url: Avatar URL from provider.
|
||||
"""
|
||||
|
||||
provider: str = Field(..., description="OAuth provider (google, discord)")
|
||||
oauth_id: str = Field(..., description="Unique ID from provider")
|
||||
email: EmailStr = Field(..., description="User's email")
|
||||
name: str = Field(..., description="User's name from provider")
|
||||
avatar_url: str | None = Field(default=None, description="Avatar URL from provider")
|
||||
|
||||
def to_user_create(self) -> UserCreate:
|
||||
"""Convert OAuth info to UserCreate schema.
|
||||
|
||||
Returns:
|
||||
UserCreate instance ready for user creation.
|
||||
"""
|
||||
return UserCreate(
|
||||
email=self.email,
|
||||
display_name=self.name[:50], # Enforce max length
|
||||
avatar_url=self.avatar_url,
|
||||
oauth_provider=self.provider,
|
||||
oauth_id=self.oauth_id,
|
||||
)
|
||||
206
backend/app/services/jwt_service.py
Normal file
206
backend/app/services/jwt_service.py
Normal file
@ -0,0 +1,206 @@
|
||||
"""JWT token service for Mantimon TCG.
|
||||
|
||||
This module provides functions for creating and verifying JWT tokens
|
||||
used in the authentication system.
|
||||
|
||||
Token Types:
|
||||
- Access tokens: Short-lived (30 min default), used for API authentication
|
||||
- Refresh tokens: Longer-lived (7 days default), used to obtain new access tokens
|
||||
|
||||
Example:
|
||||
from app.services.jwt_service import create_access_token, verify_token
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(user_id)
|
||||
refresh_token, jti = create_refresh_token(user_id)
|
||||
|
||||
# Verify token
|
||||
user_id = verify_token(access_token)
|
||||
if user_id:
|
||||
print(f"Valid token for user {user_id}")
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.config import settings
|
||||
from app.schemas.auth import TokenPayload, TokenType
|
||||
|
||||
|
||||
def create_access_token(user_id: uuid.UUID) -> str:
|
||||
"""Create a short-lived JWT access token.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID to encode in the token.
|
||||
|
||||
Returns:
|
||||
Encoded JWT access token string.
|
||||
|
||||
Example:
|
||||
token = create_access_token(user.id)
|
||||
# Use token in Authorization header: Bearer {token}
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
expires = now + timedelta(minutes=settings.jwt_expire_minutes)
|
||||
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"exp": expires,
|
||||
"iat": now,
|
||||
"type": TokenType.ACCESS.value,
|
||||
}
|
||||
|
||||
return jwt.encode(
|
||||
payload,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithm=settings.jwt_algorithm,
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: uuid.UUID) -> tuple[str, str]:
|
||||
"""Create a longer-lived JWT refresh token.
|
||||
|
||||
Refresh tokens include a JTI (JWT ID) for tracking and revocation.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID to encode in the token.
|
||||
|
||||
Returns:
|
||||
Tuple of (encoded JWT refresh token, JTI string).
|
||||
The JTI should be stored in Redis for revocation tracking.
|
||||
|
||||
Example:
|
||||
token, jti = create_refresh_token(user.id)
|
||||
await token_store.store_refresh_token(user.id, jti, expires_at)
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
expires = now + timedelta(days=settings.jwt_refresh_expire_days)
|
||||
jti = str(uuid.uuid4())
|
||||
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"exp": expires,
|
||||
"iat": now,
|
||||
"type": TokenType.REFRESH.value,
|
||||
"jti": jti,
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithm=settings.jwt_algorithm,
|
||||
)
|
||||
|
||||
return token, jti
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload | None:
|
||||
"""Decode and validate a JWT token.
|
||||
|
||||
Args:
|
||||
token: The JWT token string to decode.
|
||||
|
||||
Returns:
|
||||
TokenPayload if valid, None if invalid or expired.
|
||||
|
||||
Example:
|
||||
payload = decode_token(token)
|
||||
if payload:
|
||||
user_id = UUID(payload.sub)
|
||||
"""
|
||||
try:
|
||||
payload_dict = jwt.decode(
|
||||
token,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
return TokenPayload(**payload_dict)
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def verify_access_token(token: str) -> uuid.UUID | None:
|
||||
"""Verify an access token and extract the user ID.
|
||||
|
||||
Args:
|
||||
token: The JWT access token to verify.
|
||||
|
||||
Returns:
|
||||
User UUID if valid access token, None otherwise.
|
||||
|
||||
Example:
|
||||
user_id = verify_access_token(token)
|
||||
if user_id:
|
||||
user = await user_service.get_user_by_id(db, user_id)
|
||||
"""
|
||||
payload = decode_token(token)
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
if payload.type != TokenType.ACCESS:
|
||||
return None
|
||||
|
||||
try:
|
||||
return uuid.UUID(payload.sub)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def verify_refresh_token(token: str) -> tuple[uuid.UUID, str] | None:
|
||||
"""Verify a refresh token and extract user ID and JTI.
|
||||
|
||||
The JTI should be checked against the token store to ensure
|
||||
the token hasn't been revoked.
|
||||
|
||||
Args:
|
||||
token: The JWT refresh token to verify.
|
||||
|
||||
Returns:
|
||||
Tuple of (user UUID, JTI) if valid refresh token, None otherwise.
|
||||
|
||||
Example:
|
||||
result = verify_refresh_token(token)
|
||||
if result:
|
||||
user_id, jti = result
|
||||
if await token_store.is_token_valid(user_id, jti):
|
||||
# Issue new access token
|
||||
"""
|
||||
payload = decode_token(token)
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
if payload.type != TokenType.REFRESH:
|
||||
return None
|
||||
|
||||
if payload.jti is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
user_id = uuid.UUID(payload.sub)
|
||||
return user_id, payload.jti
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def get_token_expiration_seconds() -> int:
|
||||
"""Get the access token expiration time in seconds.
|
||||
|
||||
Useful for the expires_in field in token responses.
|
||||
|
||||
Returns:
|
||||
Access token lifetime in seconds.
|
||||
"""
|
||||
return settings.jwt_expire_minutes * 60
|
||||
|
||||
|
||||
def get_refresh_token_expiration() -> datetime:
|
||||
"""Get the expiration datetime for a new refresh token.
|
||||
|
||||
Useful for setting TTL when storing in Redis.
|
||||
|
||||
Returns:
|
||||
Datetime when a refresh token created now would expire.
|
||||
"""
|
||||
return datetime.now(UTC) + timedelta(days=settings.jwt_refresh_expire_days)
|
||||
25
backend/app/services/oauth/__init__.py
Normal file
25
backend/app/services/oauth/__init__.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""OAuth provider services for Mantimon TCG.
|
||||
|
||||
This package contains OAuth integration services for supported providers.
|
||||
|
||||
Providers:
|
||||
- Google: OAuth 2.0 with Google accounts
|
||||
- Discord: OAuth 2.0 with Discord accounts
|
||||
|
||||
Example:
|
||||
from app.services.oauth import google_oauth, discord_oauth
|
||||
|
||||
# Get authorization URL
|
||||
auth_url = google_oauth.get_authorization_url(redirect_uri, state)
|
||||
|
||||
# Exchange code for user info
|
||||
user_info = await google_oauth.get_user_info(code, redirect_uri)
|
||||
"""
|
||||
|
||||
from app.services.oauth.discord import discord_oauth
|
||||
from app.services.oauth.google import google_oauth
|
||||
|
||||
__all__ = [
|
||||
"google_oauth",
|
||||
"discord_oauth",
|
||||
]
|
||||
242
backend/app/services/oauth/discord.py
Normal file
242
backend/app/services/oauth/discord.py
Normal file
@ -0,0 +1,242 @@
|
||||
"""Discord OAuth service for Mantimon TCG.
|
||||
|
||||
This module handles Discord OAuth 2.0 authentication flow:
|
||||
1. Generate authorization URL for user redirect
|
||||
2. Exchange authorization code for tokens
|
||||
3. Fetch user information from Discord
|
||||
|
||||
Discord OAuth Endpoints:
|
||||
- Authorization: https://discord.com/api/oauth2/authorize
|
||||
- Token: https://discord.com/api/oauth2/token
|
||||
- User Info: https://discord.com/api/users/@me
|
||||
|
||||
Example:
|
||||
from app.services.oauth.discord import discord_oauth
|
||||
|
||||
# Step 1: Redirect user to Discord
|
||||
auth_url = discord_oauth.get_authorization_url(
|
||||
redirect_uri="https://play.mantimon.com/api/auth/discord/callback",
|
||||
state="random-csrf-token"
|
||||
)
|
||||
|
||||
# Step 2: Handle callback and get user info
|
||||
user_info = await discord_oauth.get_user_info(code, redirect_uri)
|
||||
"""
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
from app.schemas.user import OAuthUserInfo
|
||||
|
||||
|
||||
class DiscordOAuthError(Exception):
|
||||
"""Exception raised for Discord OAuth errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DiscordOAuth:
|
||||
"""Discord OAuth 2.0 service.
|
||||
|
||||
Handles the OAuth flow for authenticating users with Discord accounts.
|
||||
"""
|
||||
|
||||
AUTHORIZATION_URL = "https://discord.com/api/oauth2/authorize"
|
||||
TOKEN_URL = "https://discord.com/api/oauth2/token"
|
||||
USER_INFO_URL = "https://discord.com/api/users/@me"
|
||||
CDN_URL = "https://cdn.discordapp.com"
|
||||
|
||||
# Scopes we request from Discord
|
||||
SCOPES = ["identify", "email"]
|
||||
|
||||
def get_authorization_url(self, redirect_uri: str, state: str) -> str:
|
||||
"""Generate the Discord OAuth authorization URL.
|
||||
|
||||
Args:
|
||||
redirect_uri: Where Discord should redirect after authorization.
|
||||
state: Random string for CSRF protection.
|
||||
|
||||
Returns:
|
||||
Full authorization URL to redirect user to.
|
||||
|
||||
Raises:
|
||||
DiscordOAuthError: If Discord OAuth is not configured.
|
||||
|
||||
Example:
|
||||
url = discord_oauth.get_authorization_url(
|
||||
redirect_uri="https://play.mantimon.com/api/auth/discord/callback",
|
||||
state="abc123"
|
||||
)
|
||||
# Redirect user to url
|
||||
"""
|
||||
if not settings.discord_client_id:
|
||||
raise DiscordOAuthError("Discord OAuth is not configured")
|
||||
|
||||
params = {
|
||||
"client_id": settings.discord_client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(self.SCOPES),
|
||||
"state": state,
|
||||
"prompt": "consent", # Always show consent screen
|
||||
}
|
||||
|
||||
return f"{self.AUTHORIZATION_URL}?{urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
"""Exchange authorization code for access tokens.
|
||||
|
||||
Args:
|
||||
code: Authorization code from Discord callback.
|
||||
redirect_uri: Same redirect_uri used in authorization request.
|
||||
|
||||
Returns:
|
||||
Token response containing access_token, refresh_token, etc.
|
||||
|
||||
Raises:
|
||||
DiscordOAuthError: If token exchange fails.
|
||||
"""
|
||||
if not settings.discord_client_id or not settings.discord_client_secret:
|
||||
raise DiscordOAuthError("Discord OAuth is not configured")
|
||||
|
||||
data = {
|
||||
"client_id": settings.discord_client_id,
|
||||
"client_secret": settings.discord_client_secret.get_secret_value(),
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.TOKEN_URL,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_data = response.json() if response.content else {}
|
||||
error_msg = error_data.get("error_description", response.text)
|
||||
raise DiscordOAuthError(f"Token exchange failed: {error_msg}")
|
||||
|
||||
return response.json()
|
||||
|
||||
async def fetch_user_info(self, access_token: str) -> dict:
|
||||
"""Fetch user information from Discord.
|
||||
|
||||
Args:
|
||||
access_token: Valid Discord access token.
|
||||
|
||||
Returns:
|
||||
User info dict with id, username, email, avatar, etc.
|
||||
|
||||
Raises:
|
||||
DiscordOAuthError: If fetching user info fails.
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self.USER_INFO_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise DiscordOAuthError(f"Failed to fetch user info: {response.text}")
|
||||
|
||||
return response.json()
|
||||
|
||||
def _build_avatar_url(self, user_id: str, avatar_hash: str | None) -> str | None:
|
||||
"""Build Discord avatar URL from user ID and avatar hash.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID.
|
||||
avatar_hash: Avatar hash from Discord API (can be None).
|
||||
|
||||
Returns:
|
||||
Full CDN URL for avatar, or None if no avatar.
|
||||
|
||||
Note:
|
||||
Discord avatar format: https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.png
|
||||
If avatar_hash starts with 'a_', it's animated (gif).
|
||||
"""
|
||||
if not avatar_hash:
|
||||
return None
|
||||
|
||||
# Animated avatars start with 'a_'
|
||||
extension = "gif" if avatar_hash.startswith("a_") else "png"
|
||||
return f"{self.CDN_URL}/avatars/{user_id}/{avatar_hash}.{extension}"
|
||||
|
||||
async def get_user_info(
|
||||
self,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthUserInfo:
|
||||
"""Complete OAuth flow: exchange code and fetch user info.
|
||||
|
||||
This is the main method to call after receiving the OAuth callback.
|
||||
It exchanges the authorization code for tokens, then fetches user info.
|
||||
|
||||
Args:
|
||||
code: Authorization code from Discord callback.
|
||||
redirect_uri: Same redirect_uri used in authorization request.
|
||||
|
||||
Returns:
|
||||
Normalized OAuthUserInfo ready for user creation/lookup.
|
||||
|
||||
Raises:
|
||||
DiscordOAuthError: If any step of the OAuth flow fails.
|
||||
|
||||
Example:
|
||||
# In your callback handler:
|
||||
user_info = await discord_oauth.get_user_info(code, redirect_uri)
|
||||
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||
"""
|
||||
# Exchange code for tokens
|
||||
tokens = await self.exchange_code_for_tokens(code, redirect_uri)
|
||||
access_token = tokens.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
raise DiscordOAuthError("No access token in response")
|
||||
|
||||
# Fetch user info
|
||||
user_data = await self.fetch_user_info(access_token)
|
||||
|
||||
# Discord requires email scope, but email can still be None if not verified
|
||||
email = user_data.get("email")
|
||||
if not email:
|
||||
raise DiscordOAuthError("Discord account does not have a verified email")
|
||||
|
||||
# Build display name: prefer global_name, then username
|
||||
display_name = user_data.get("global_name") or user_data["username"]
|
||||
|
||||
# Build avatar URL
|
||||
avatar_url = self._build_avatar_url(
|
||||
user_data["id"],
|
||||
user_data.get("avatar"),
|
||||
)
|
||||
|
||||
# Normalize to our schema
|
||||
return OAuthUserInfo(
|
||||
provider="discord",
|
||||
oauth_id=user_data["id"],
|
||||
email=email,
|
||||
name=display_name,
|
||||
avatar_url=avatar_url,
|
||||
)
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Discord OAuth is properly configured.
|
||||
|
||||
Returns:
|
||||
True if client ID and secret are set.
|
||||
"""
|
||||
return bool(settings.discord_client_id and settings.discord_client_secret)
|
||||
|
||||
|
||||
# Global instance
|
||||
discord_oauth = DiscordOAuth()
|
||||
207
backend/app/services/oauth/google.py
Normal file
207
backend/app/services/oauth/google.py
Normal file
@ -0,0 +1,207 @@
|
||||
"""Google OAuth service for Mantimon TCG.
|
||||
|
||||
This module handles Google OAuth 2.0 authentication flow:
|
||||
1. Generate authorization URL for user redirect
|
||||
2. Exchange authorization code for tokens
|
||||
3. Fetch user information from Google
|
||||
|
||||
Google OAuth Endpoints:
|
||||
- Authorization: https://accounts.google.com/o/oauth2/v2/auth
|
||||
- Token: https://oauth2.googleapis.com/token
|
||||
- User Info: https://www.googleapis.com/oauth2/v2/userinfo
|
||||
|
||||
Example:
|
||||
from app.services.oauth.google import google_oauth
|
||||
|
||||
# Step 1: Redirect user to Google
|
||||
auth_url = google_oauth.get_authorization_url(
|
||||
redirect_uri="https://play.mantimon.com/api/auth/google/callback",
|
||||
state="random-csrf-token"
|
||||
)
|
||||
|
||||
# Step 2: Handle callback and get user info
|
||||
user_info = await google_oauth.get_user_info(code, redirect_uri)
|
||||
"""
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
from app.schemas.user import OAuthUserInfo
|
||||
|
||||
|
||||
class GoogleOAuthError(Exception):
|
||||
"""Exception raised for Google OAuth errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GoogleOAuth:
|
||||
"""Google OAuth 2.0 service.
|
||||
|
||||
Handles the OAuth flow for authenticating users with Google accounts.
|
||||
"""
|
||||
|
||||
AUTHORIZATION_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
USER_INFO_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
# Scopes we request from Google
|
||||
SCOPES = ["openid", "email", "profile"]
|
||||
|
||||
def get_authorization_url(self, redirect_uri: str, state: str) -> str:
|
||||
"""Generate the Google OAuth authorization URL.
|
||||
|
||||
Args:
|
||||
redirect_uri: Where Google should redirect after authorization.
|
||||
state: Random string for CSRF protection.
|
||||
|
||||
Returns:
|
||||
Full authorization URL to redirect user to.
|
||||
|
||||
Raises:
|
||||
GoogleOAuthError: If Google OAuth is not configured.
|
||||
|
||||
Example:
|
||||
url = google_oauth.get_authorization_url(
|
||||
redirect_uri="https://play.mantimon.com/api/auth/google/callback",
|
||||
state="abc123"
|
||||
)
|
||||
# Redirect user to url
|
||||
"""
|
||||
if not settings.google_client_id:
|
||||
raise GoogleOAuthError("Google OAuth is not configured")
|
||||
|
||||
params = {
|
||||
"client_id": settings.google_client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(self.SCOPES),
|
||||
"state": state,
|
||||
"access_type": "offline", # Get refresh token
|
||||
"prompt": "select_account", # Always show account picker
|
||||
}
|
||||
|
||||
return f"{self.AUTHORIZATION_URL}?{urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
"""Exchange authorization code for access tokens.
|
||||
|
||||
Args:
|
||||
code: Authorization code from Google callback.
|
||||
redirect_uri: Same redirect_uri used in authorization request.
|
||||
|
||||
Returns:
|
||||
Token response containing access_token, id_token, etc.
|
||||
|
||||
Raises:
|
||||
GoogleOAuthError: If token exchange fails.
|
||||
"""
|
||||
if not settings.google_client_id or not settings.google_client_secret:
|
||||
raise GoogleOAuthError("Google OAuth is not configured")
|
||||
|
||||
data = {
|
||||
"client_id": settings.google_client_id,
|
||||
"client_secret": settings.google_client_secret.get_secret_value(),
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.TOKEN_URL,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_data = response.json() if response.content else {}
|
||||
error_msg = error_data.get("error_description", response.text)
|
||||
raise GoogleOAuthError(f"Token exchange failed: {error_msg}")
|
||||
|
||||
return response.json()
|
||||
|
||||
async def fetch_user_info(self, access_token: str) -> dict:
|
||||
"""Fetch user information from Google.
|
||||
|
||||
Args:
|
||||
access_token: Valid Google access token.
|
||||
|
||||
Returns:
|
||||
User info dict with id, email, name, picture, etc.
|
||||
|
||||
Raises:
|
||||
GoogleOAuthError: If fetching user info fails.
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self.USER_INFO_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise GoogleOAuthError(f"Failed to fetch user info: {response.text}")
|
||||
|
||||
return response.json()
|
||||
|
||||
async def get_user_info(
|
||||
self,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthUserInfo:
|
||||
"""Complete OAuth flow: exchange code and fetch user info.
|
||||
|
||||
This is the main method to call after receiving the OAuth callback.
|
||||
It exchanges the authorization code for tokens, then fetches user info.
|
||||
|
||||
Args:
|
||||
code: Authorization code from Google callback.
|
||||
redirect_uri: Same redirect_uri used in authorization request.
|
||||
|
||||
Returns:
|
||||
Normalized OAuthUserInfo ready for user creation/lookup.
|
||||
|
||||
Raises:
|
||||
GoogleOAuthError: If any step of the OAuth flow fails.
|
||||
|
||||
Example:
|
||||
# In your callback handler:
|
||||
user_info = await google_oauth.get_user_info(code, redirect_uri)
|
||||
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||
"""
|
||||
# Exchange code for tokens
|
||||
tokens = await self.exchange_code_for_tokens(code, redirect_uri)
|
||||
access_token = tokens.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
raise GoogleOAuthError("No access token in response")
|
||||
|
||||
# Fetch user info
|
||||
user_data = await self.fetch_user_info(access_token)
|
||||
|
||||
# Normalize to our schema
|
||||
return OAuthUserInfo(
|
||||
provider="google",
|
||||
oauth_id=user_data["id"],
|
||||
email=user_data["email"],
|
||||
name=user_data.get("name", user_data["email"].split("@")[0]),
|
||||
avatar_url=user_data.get("picture"),
|
||||
)
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Google OAuth is properly configured.
|
||||
|
||||
Returns:
|
||||
True if client ID and secret are set.
|
||||
"""
|
||||
return bool(settings.google_client_id and settings.google_client_secret)
|
||||
|
||||
|
||||
# Global instance
|
||||
google_oauth = GoogleOAuth()
|
||||
195
backend/app/services/token_store.py
Normal file
195
backend/app/services/token_store.py
Normal file
@ -0,0 +1,195 @@
|
||||
"""Refresh token storage for Mantimon TCG.
|
||||
|
||||
This module provides Redis-based storage for refresh token tracking
|
||||
and revocation. Each refresh token's JTI is stored in Redis with
|
||||
a TTL matching the token's expiration.
|
||||
|
||||
Key Pattern:
|
||||
refresh_token:{user_id}:{jti} -> "1" (exists = valid)
|
||||
|
||||
Revocation:
|
||||
- Single token: Delete the specific key
|
||||
- All user tokens: Delete all keys matching refresh_token:{user_id}:*
|
||||
|
||||
Example:
|
||||
from app.services.token_store import token_store
|
||||
|
||||
# Store a new refresh token
|
||||
await token_store.store_refresh_token(user_id, jti, expires_at)
|
||||
|
||||
# Check if token is valid (not revoked)
|
||||
if await token_store.is_token_valid(user_id, jti):
|
||||
# Issue new access token
|
||||
|
||||
# Revoke on logout
|
||||
await token_store.revoke_token(user_id, jti)
|
||||
|
||||
# Logout from all devices
|
||||
await token_store.revoke_all_user_tokens(user_id)
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from app.db.redis import get_redis
|
||||
|
||||
|
||||
class TokenStore:
|
||||
"""Redis-based refresh token storage for revocation support.
|
||||
|
||||
Tracks valid refresh tokens by storing their JTIs in Redis.
|
||||
Tokens can be revoked individually or all at once per user.
|
||||
"""
|
||||
|
||||
KEY_PREFIX = "refresh_token"
|
||||
|
||||
def _make_key(self, user_id: UUID, jti: str) -> str:
|
||||
"""Create Redis key for a refresh token.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
jti: The token's unique identifier.
|
||||
|
||||
Returns:
|
||||
Redis key string.
|
||||
"""
|
||||
return f"{self.KEY_PREFIX}:{user_id}:{jti}"
|
||||
|
||||
def _make_user_pattern(self, user_id: UUID) -> str:
|
||||
"""Create Redis key pattern for all user's tokens.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
Redis key pattern for SCAN/KEYS.
|
||||
"""
|
||||
return f"{self.KEY_PREFIX}:{user_id}:*"
|
||||
|
||||
async def store_refresh_token(
|
||||
self,
|
||||
user_id: UUID,
|
||||
jti: str,
|
||||
expires_at: datetime,
|
||||
) -> None:
|
||||
"""Store a refresh token's JTI in Redis.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
jti: The token's unique identifier (from JWT).
|
||||
expires_at: When the token expires (for TTL calculation).
|
||||
|
||||
Example:
|
||||
expires_at = datetime.now(UTC) + timedelta(days=7)
|
||||
await token_store.store_refresh_token(user_id, jti, expires_at)
|
||||
"""
|
||||
key = self._make_key(user_id, jti)
|
||||
|
||||
# Calculate TTL in seconds
|
||||
now = datetime.now(UTC)
|
||||
ttl_seconds = int((expires_at - now).total_seconds())
|
||||
|
||||
if ttl_seconds <= 0:
|
||||
# Token already expired, don't store
|
||||
return
|
||||
|
||||
async with get_redis() as redis:
|
||||
await redis.setex(key, ttl_seconds, "1")
|
||||
|
||||
async def is_token_valid(self, user_id: UUID, jti: str) -> bool:
|
||||
"""Check if a refresh token is valid (not revoked).
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
jti: The token's unique identifier.
|
||||
|
||||
Returns:
|
||||
True if token exists in store (valid), False if revoked or expired.
|
||||
|
||||
Example:
|
||||
if await token_store.is_token_valid(user_id, jti):
|
||||
# Token is valid, issue new access token
|
||||
else:
|
||||
# Token was revoked, require re-authentication
|
||||
"""
|
||||
key = self._make_key(user_id, jti)
|
||||
|
||||
async with get_redis() as redis:
|
||||
result = await redis.exists(key)
|
||||
return result > 0
|
||||
|
||||
async def revoke_token(self, user_id: UUID, jti: str) -> bool:
|
||||
"""Revoke a specific refresh token.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
jti: The token's unique identifier.
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if it didn't exist.
|
||||
|
||||
Example:
|
||||
# On logout
|
||||
await token_store.revoke_token(user_id, jti)
|
||||
"""
|
||||
key = self._make_key(user_id, jti)
|
||||
|
||||
async with get_redis() as redis:
|
||||
result = await redis.delete(key)
|
||||
return result > 0
|
||||
|
||||
async def revoke_all_user_tokens(self, user_id: UUID) -> int:
|
||||
"""Revoke all refresh tokens for a user.
|
||||
|
||||
Useful for "logout from all devices" or security incidents.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked.
|
||||
|
||||
Example:
|
||||
# Logout from all devices
|
||||
count = await token_store.revoke_all_user_tokens(user_id)
|
||||
print(f"Revoked {count} sessions")
|
||||
"""
|
||||
pattern = self._make_user_pattern(user_id)
|
||||
|
||||
async with get_redis() as redis:
|
||||
# Use SCAN to find all matching keys (safer than KEYS for large datasets)
|
||||
keys_to_delete = []
|
||||
async for key in redis.scan_iter(match=pattern):
|
||||
keys_to_delete.append(key)
|
||||
|
||||
if not keys_to_delete:
|
||||
return 0
|
||||
|
||||
# Delete all found keys
|
||||
result = await redis.delete(*keys_to_delete)
|
||||
return result
|
||||
|
||||
async def get_active_session_count(self, user_id: UUID) -> int:
|
||||
"""Get the number of active sessions (valid refresh tokens) for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
Number of active sessions.
|
||||
|
||||
Example:
|
||||
count = await token_store.get_active_session_count(user_id)
|
||||
print(f"User has {count} active sessions")
|
||||
"""
|
||||
pattern = self._make_user_pattern(user_id)
|
||||
|
||||
async with get_redis() as redis:
|
||||
count = 0
|
||||
async for _ in redis.scan_iter(match=pattern):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
# Global token store instance
|
||||
token_store = TokenStore()
|
||||
309
backend/app/services/user_service.py
Normal file
309
backend/app/services/user_service.py
Normal file
@ -0,0 +1,309 @@
|
||||
"""User service for Mantimon TCG.
|
||||
|
||||
This module provides async CRUD operations for user accounts,
|
||||
including OAuth-based user creation and premium status management.
|
||||
|
||||
All database operations use async SQLAlchemy sessions.
|
||||
|
||||
Example:
|
||||
from app.services.user_service import user_service
|
||||
|
||||
# Get user by ID
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
|
||||
# Create from OAuth
|
||||
user = await user_service.create_from_oauth(db, oauth_info)
|
||||
|
||||
# Update premium status
|
||||
user = await user_service.update_premium(db, user_id, premium_until)
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.user import User
|
||||
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
||||
|
||||
|
||||
class UserService:
|
||||
"""Service for user account operations.
|
||||
|
||||
Provides async methods for user CRUD, OAuth-based creation,
|
||||
and premium subscription management.
|
||||
"""
|
||||
|
||||
async def get_by_id(self, db: AsyncSession, user_id: UUID) -> User | None:
|
||||
"""Get a user by their ID.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise.
|
||||
|
||||
Example:
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
if user:
|
||||
print(f"Found user: {user.display_name}")
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
||||
"""Get a user by their email address.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
email: The user's email address.
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise.
|
||||
|
||||
Example:
|
||||
user = await user_service.get_by_email(db, "player@example.com")
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_oauth(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider: str,
|
||||
oauth_id: str,
|
||||
) -> User | None:
|
||||
"""Get a user by their OAuth provider and ID.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
provider: OAuth provider name (google, discord).
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise.
|
||||
|
||||
Example:
|
||||
user = await user_service.get_by_oauth(db, "google", "123456789")
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.oauth_provider == provider,
|
||||
User.oauth_id == oauth_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, db: AsyncSession, user_data: UserCreate) -> User:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user_data: User creation data.
|
||||
|
||||
Returns:
|
||||
The created User instance.
|
||||
|
||||
Example:
|
||||
user_data = UserCreate(
|
||||
email="player@example.com",
|
||||
display_name="Player1",
|
||||
oauth_provider="google",
|
||||
oauth_id="123456789"
|
||||
)
|
||||
user = await user_service.create(db, user_data)
|
||||
"""
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
display_name=user_data.display_name,
|
||||
avatar_url=user_data.avatar_url,
|
||||
oauth_provider=user_data.oauth_provider,
|
||||
oauth_id=user_data.oauth_id,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def create_from_oauth(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
oauth_info: OAuthUserInfo,
|
||||
) -> User:
|
||||
"""Create a new user from OAuth provider info.
|
||||
|
||||
Convenience method that converts OAuthUserInfo to UserCreate.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
oauth_info: Normalized OAuth user information.
|
||||
|
||||
Returns:
|
||||
The created User instance.
|
||||
|
||||
Example:
|
||||
oauth_info = OAuthUserInfo(
|
||||
provider="google",
|
||||
oauth_id="123456789",
|
||||
email="player@example.com",
|
||||
name="Player One",
|
||||
avatar_url="https://..."
|
||||
)
|
||||
user = await user_service.create_from_oauth(db, oauth_info)
|
||||
"""
|
||||
user_data = oauth_info.to_user_create()
|
||||
return await self.create(db, user_data)
|
||||
|
||||
async def get_or_create_from_oauth(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
oauth_info: OAuthUserInfo,
|
||||
) -> tuple[User, bool]:
|
||||
"""Get existing user or create new one from OAuth info.
|
||||
|
||||
First checks for existing user by OAuth provider+ID, then by email
|
||||
(for account linking), and finally creates a new user if not found.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
oauth_info: Normalized OAuth user information.
|
||||
|
||||
Returns:
|
||||
Tuple of (User, created) where created is True if new user.
|
||||
|
||||
Example:
|
||||
user, created = await user_service.get_or_create_from_oauth(db, oauth_info)
|
||||
if created:
|
||||
print("Welcome, new user!")
|
||||
else:
|
||||
print("Welcome back!")
|
||||
"""
|
||||
# First, check by OAuth provider + ID (exact match)
|
||||
user = await self.get_by_oauth(db, oauth_info.provider, oauth_info.oauth_id)
|
||||
if user:
|
||||
return user, False
|
||||
|
||||
# Check by email for potential account linking
|
||||
# If user exists with same email but different OAuth, update their OAuth
|
||||
user = await self.get_by_email(db, oauth_info.email)
|
||||
if user:
|
||||
# Update OAuth credentials for existing user
|
||||
# This links the new OAuth provider to the existing account
|
||||
user.oauth_provider = oauth_info.provider
|
||||
user.oauth_id = oauth_info.oauth_id
|
||||
# Optionally update avatar if not set
|
||||
if not user.avatar_url and oauth_info.avatar_url:
|
||||
user.avatar_url = oauth_info.avatar_url
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user, False
|
||||
|
||||
# Create new user
|
||||
user = await self.create_from_oauth(db, oauth_info)
|
||||
return user, True
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
update_data: UserUpdate,
|
||||
) -> User:
|
||||
"""Update user profile fields.
|
||||
|
||||
Only updates fields that are provided (not None).
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to update.
|
||||
update_data: Fields to update.
|
||||
|
||||
Returns:
|
||||
The updated User instance.
|
||||
|
||||
Example:
|
||||
update_data = UserUpdate(display_name="New Name")
|
||||
user = await user_service.update(db, user, update_data)
|
||||
"""
|
||||
if update_data.display_name is not None:
|
||||
user.display_name = update_data.display_name
|
||||
if update_data.avatar_url is not None:
|
||||
user.avatar_url = update_data.avatar_url
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def update_last_login(self, db: AsyncSession, user: User) -> User:
|
||||
"""Update the user's last login timestamp.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to update.
|
||||
|
||||
Returns:
|
||||
The updated User instance.
|
||||
|
||||
Example:
|
||||
user = await user_service.update_last_login(db, user)
|
||||
"""
|
||||
user.last_login = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def update_premium(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
premium_until: datetime | None,
|
||||
) -> User:
|
||||
"""Update user's premium subscription status.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to update.
|
||||
premium_until: When premium expires, or None to remove premium.
|
||||
|
||||
Returns:
|
||||
The updated User instance.
|
||||
|
||||
Example:
|
||||
# Grant 30 days of premium
|
||||
expires = datetime.now(UTC) + timedelta(days=30)
|
||||
user = await user_service.update_premium(db, user, expires)
|
||||
|
||||
# Remove premium
|
||||
user = await user_service.update_premium(db, user, None)
|
||||
"""
|
||||
if premium_until is not None:
|
||||
user.is_premium = True
|
||||
user.premium_until = premium_until
|
||||
else:
|
||||
user.is_premium = False
|
||||
user.premium_until = None
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def delete(self, db: AsyncSession, user: User) -> None:
|
||||
"""Delete a user account.
|
||||
|
||||
This will cascade delete all related data (decks, collection, etc.)
|
||||
based on the model relationships.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to delete.
|
||||
|
||||
Example:
|
||||
await user_service.delete(db, user)
|
||||
"""
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Global service instance
|
||||
user_service = UserService()
|
||||
@ -44,6 +44,18 @@ services:
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
adminer:
|
||||
image: adminer:latest
|
||||
restart: unless-stopped
|
||||
container_name: mantimon-adminer
|
||||
ports:
|
||||
- "8090:8080"
|
||||
environment:
|
||||
- ADMINER_DEFAULT_SERVER=mantimon-postgres
|
||||
- TZ=America/Chicago
|
||||
depends_on:
|
||||
- postgres
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
name: mantimon_postgres_data
|
||||
|
||||
@ -8,7 +8,9 @@ dependencies = [
|
||||
"alembic>=1.18.1",
|
||||
"asyncpg>=0.31.0",
|
||||
"bcrypt>=5.0.0",
|
||||
"email-validator>=2.3.0",
|
||||
"fastapi>=0.128.0",
|
||||
"httpx>=0.28.1",
|
||||
"passlib>=1.7.4",
|
||||
"psycopg2-binary>=2.9.11",
|
||||
"pydantic>=2.12.5",
|
||||
@ -24,12 +26,14 @@ dependencies = [
|
||||
dev = [
|
||||
"beautifulsoup4>=4.12.0",
|
||||
"black>=26.1.0",
|
||||
"fakeredis>=2.33.0",
|
||||
"httpx>=0.28.1",
|
||||
"mypy>=1.19.1",
|
||||
"pytest>=9.0.2",
|
||||
"pytest-asyncio>=1.3.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"requests>=2.31.0",
|
||||
"respx>=0.22.0",
|
||||
"ruff>=0.14.14",
|
||||
"testcontainers[postgres,redis]>=4.0.0",
|
||||
]
|
||||
|
||||
1
backend/tests/api/__init__.py
Normal file
1
backend/tests/api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""API endpoint tests."""
|
||||
133
backend/tests/api/conftest.py
Normal file
133
backend/tests/api/conftest.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""Test fixtures for API endpoint tests.
|
||||
|
||||
Provides fixtures for testing FastAPI endpoints with mocked dependencies.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import fakeredis.aioredis
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api import deps as api_deps
|
||||
from app.api.auth import router as auth_router
|
||||
from app.api.users import router as users_router
|
||||
from app.db.models import User
|
||||
from app.services.jwt_service import create_access_token, create_refresh_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_redis():
|
||||
"""Provide a fake Redis instance for testing."""
|
||||
return fakeredis.aioredis.FakeRedis(decode_responses=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_redis(fake_redis):
|
||||
"""Mock the get_redis context manager to use fake Redis."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _mock_get_redis():
|
||||
yield fake_redis
|
||||
|
||||
return _mock_get_redis
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user():
|
||||
"""Create a test user object.
|
||||
|
||||
Returns a User model instance that can be used in tests.
|
||||
The user is not persisted to database.
|
||||
"""
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
display_name="Test User",
|
||||
avatar_url="https://example.com/avatar.jpg",
|
||||
oauth_provider="google",
|
||||
oauth_id="google-123",
|
||||
is_premium=False,
|
||||
premium_until=None,
|
||||
)
|
||||
# Manually set the ID since we're not using database
|
||||
user.id = str(uuid4())
|
||||
user.created_at = datetime.now(UTC)
|
||||
user.updated_at = datetime.now(UTC)
|
||||
user.last_login = None
|
||||
user.linked_accounts = []
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def premium_user(test_user):
|
||||
"""Create a premium test user."""
|
||||
from datetime import timedelta
|
||||
|
||||
test_user.is_premium = True
|
||||
test_user.premium_until = datetime.now(UTC) + timedelta(days=30)
|
||||
return test_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def access_token(test_user):
|
||||
"""Create a valid access token for the test user."""
|
||||
from uuid import UUID
|
||||
|
||||
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
|
||||
return create_access_token(user_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def refresh_token_data(test_user):
|
||||
"""Create a valid refresh token and JTI for the test user."""
|
||||
from uuid import UUID
|
||||
|
||||
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
|
||||
token, jti = create_refresh_token(user_id)
|
||||
return {"token": token, "jti": jti, "user_id": user_id}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Create a mock database session."""
|
||||
return MagicMock(spec=AsyncSession)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_get_redis, mock_db_session):
|
||||
"""Create a test FastAPI app with mocked Redis and DB.
|
||||
|
||||
This creates a minimal app with just the auth and users routers,
|
||||
with Redis and database mocked.
|
||||
"""
|
||||
# Create test app (no lifespan since we're mocking everything)
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(auth_router, prefix="/api")
|
||||
test_app.include_router(users_router, prefix="/api")
|
||||
|
||||
# Override get_db dependency to return mock session
|
||||
async def override_get_db():
|
||||
yield mock_db_session
|
||||
|
||||
test_app.dependency_overrides[api_deps.get_db] = override_get_db
|
||||
|
||||
# Patch get_redis globally for this app
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
yield test_app
|
||||
|
||||
# Clean up overrides
|
||||
test_app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client for the app."""
|
||||
return TestClient(app)
|
||||
260
backend/tests/api/test_auth.py
Normal file
260
backend/tests/api/test_auth.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""Tests for auth API endpoints.
|
||||
|
||||
Tests the authentication endpoints including OAuth redirects,
|
||||
token refresh, and logout.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestGoogleAuthRedirect:
|
||||
"""Tests for GET /api/auth/google endpoint."""
|
||||
|
||||
def test_returns_501_when_not_configured(self, client: TestClient):
|
||||
"""Test that endpoint returns 501 when Google OAuth is not configured.
|
||||
|
||||
Without client credentials, OAuth flow cannot proceed.
|
||||
"""
|
||||
with patch("app.api.auth.google_oauth") as mock_oauth:
|
||||
mock_oauth.is_configured.return_value = False
|
||||
|
||||
response = client.get(
|
||||
"/api/auth/google",
|
||||
params={"redirect_uri": "http://localhost/callback"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_501_NOT_IMPLEMENTED
|
||||
assert "not configured" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestDiscordAuthRedirect:
|
||||
"""Tests for GET /api/auth/discord endpoint."""
|
||||
|
||||
def test_returns_501_when_not_configured(self, client: TestClient):
|
||||
"""Test that endpoint returns 501 when Discord OAuth is not configured."""
|
||||
with patch("app.api.auth.discord_oauth") as mock_oauth:
|
||||
mock_oauth.is_configured.return_value = False
|
||||
|
||||
response = client.get(
|
||||
"/api/auth/discord",
|
||||
params={"redirect_uri": "http://localhost/callback"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_501_NOT_IMPLEMENTED
|
||||
assert "not configured" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestRefreshTokens:
|
||||
"""Tests for POST /api/auth/refresh endpoint."""
|
||||
|
||||
def test_returns_new_access_token(
|
||||
self, client: TestClient, test_user, refresh_token_data, mock_get_redis
|
||||
):
|
||||
"""Test that refresh endpoint returns new access token for valid refresh token.
|
||||
|
||||
A valid, non-revoked refresh token should yield a new access token.
|
||||
"""
|
||||
# Store the refresh token in fake Redis
|
||||
import asyncio
|
||||
|
||||
async def setup_token():
|
||||
async with mock_get_redis() as redis:
|
||||
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
|
||||
await redis.setex(key, 86400, "1")
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(setup_token())
|
||||
|
||||
# Mock user service to return our test user
|
||||
with patch("app.api.auth.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["refresh_token"] == refresh_token_data["token"]
|
||||
assert data["token_type"] == "bearer"
|
||||
assert "expires_in" in data
|
||||
|
||||
def test_returns_401_for_invalid_token(self, client: TestClient):
|
||||
"""Test that refresh endpoint returns 401 for invalid refresh token."""
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": "invalid.token.here"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_returns_401_for_revoked_token(
|
||||
self, client: TestClient, refresh_token_data, mock_get_redis
|
||||
):
|
||||
"""Test that refresh endpoint returns 401 for revoked token.
|
||||
|
||||
A refresh token not in Redis (revoked/expired) should be rejected.
|
||||
"""
|
||||
# Don't store the token in Redis - simulating revocation
|
||||
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "revoked" in response.json()["detail"]
|
||||
|
||||
def test_returns_401_for_deleted_user(
|
||||
self, client: TestClient, refresh_token_data, mock_get_redis
|
||||
):
|
||||
"""Test that refresh endpoint returns 401 if user no longer exists."""
|
||||
# Store the token
|
||||
import asyncio
|
||||
|
||||
async def setup_token():
|
||||
async with mock_get_redis() as redis:
|
||||
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
|
||||
await redis.setex(key, 86400, "1")
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(setup_token())
|
||||
|
||||
# Mock user service to return None (user deleted)
|
||||
with patch("app.api.auth.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "User not found" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestLogout:
|
||||
"""Tests for POST /api/auth/logout endpoint."""
|
||||
|
||||
def test_revokes_token(self, client: TestClient, refresh_token_data, mock_get_redis):
|
||||
"""Test that logout revokes the refresh token.
|
||||
|
||||
After logout, the token should no longer be in Redis.
|
||||
"""
|
||||
# Store the token first
|
||||
import asyncio
|
||||
|
||||
async def setup_and_check():
|
||||
async with mock_get_redis() as redis:
|
||||
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
|
||||
await redis.setex(key, 86400, "1")
|
||||
return key
|
||||
|
||||
key = asyncio.get_event_loop().run_until_complete(setup_and_check())
|
||||
|
||||
# Logout
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/logout",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
# Verify token is gone
|
||||
async def verify_deleted():
|
||||
async with mock_get_redis() as redis:
|
||||
return await redis.exists(key)
|
||||
|
||||
exists = asyncio.get_event_loop().run_until_complete(verify_deleted())
|
||||
assert exists == 0
|
||||
|
||||
def test_succeeds_for_invalid_token(self, client: TestClient):
|
||||
"""Test that logout succeeds even for invalid tokens.
|
||||
|
||||
Invalid tokens are effectively "already logged out", so no error.
|
||||
"""
|
||||
response = client.post(
|
||||
"/api/auth/logout",
|
||||
json={"refresh_token": "invalid.token.here"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
|
||||
class TestLogoutAll:
|
||||
"""Tests for POST /api/auth/logout-all endpoint."""
|
||||
|
||||
def test_requires_authentication(self, client: TestClient):
|
||||
"""Test that logout-all requires a valid access token.
|
||||
|
||||
Without authentication, endpoint should return 401.
|
||||
"""
|
||||
response = client.post("/api/auth/logout-all")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_revokes_all_tokens(self, client: TestClient, test_user, access_token, mock_get_redis):
|
||||
"""Test that logout-all revokes all refresh tokens for user.
|
||||
|
||||
Should delete all tokens matching the user's ID pattern.
|
||||
"""
|
||||
|
||||
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
|
||||
|
||||
# Store multiple tokens
|
||||
import asyncio
|
||||
|
||||
async def setup_tokens():
|
||||
async with mock_get_redis() as redis:
|
||||
await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
|
||||
await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")
|
||||
await redis.setex(f"refresh_token:{user_id}:jti-3", 86400, "1")
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
||||
|
||||
# Mock dependencies
|
||||
with patch("app.api.deps.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
# Verify all tokens are gone
|
||||
async def count_remaining():
|
||||
async with mock_get_redis() as redis:
|
||||
count = 0
|
||||
async for _ in redis.scan_iter(match=f"refresh_token:{user_id}:*"):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
remaining = asyncio.get_event_loop().run_until_complete(count_remaining())
|
||||
assert remaining == 0
|
||||
172
backend/tests/api/test_users.py
Normal file
172
backend/tests/api/test_users.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""Tests for users API endpoints.
|
||||
|
||||
Tests the user profile management endpoints.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for GET /api/users/me endpoint."""
|
||||
|
||||
def test_returns_user_profile(self, client: TestClient, test_user, access_token):
|
||||
"""Test that endpoint returns user profile for authenticated user.
|
||||
|
||||
Should return the user's profile information.
|
||||
"""
|
||||
with patch("app.api.deps.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
|
||||
response = client.get(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["email"] == test_user.email
|
||||
assert data["display_name"] == test_user.display_name
|
||||
assert data["avatar_url"] == test_user.avatar_url
|
||||
assert data["is_premium"] == test_user.is_premium
|
||||
|
||||
def test_requires_authentication(self, client: TestClient):
|
||||
"""Test that endpoint returns 401 without authentication."""
|
||||
response = client.get("/api/users/me")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_returns_401_for_invalid_token(self, client: TestClient):
|
||||
"""Test that endpoint returns 401 for invalid access token."""
|
||||
response = client.get(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": "Bearer invalid.token.here"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestUpdateCurrentUser:
|
||||
"""Tests for PATCH /api/users/me endpoint."""
|
||||
|
||||
def test_updates_display_name(self, client: TestClient, test_user, access_token):
|
||||
"""Test that endpoint updates display_name when provided."""
|
||||
updated_user = test_user
|
||||
updated_user.display_name = "New Name"
|
||||
|
||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
|
||||
with patch("app.api.users.user_service") as mock_user_service:
|
||||
mock_user_service.update = AsyncMock(return_value=updated_user)
|
||||
|
||||
response = client.patch(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"display_name": "New Name"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["display_name"] == "New Name"
|
||||
|
||||
def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
|
||||
"""Test that endpoint updates avatar_url when provided."""
|
||||
updated_user = test_user
|
||||
updated_user.avatar_url = "https://new-avatar.com/img.jpg"
|
||||
|
||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
|
||||
with patch("app.api.users.user_service") as mock_user_service:
|
||||
mock_user_service.update = AsyncMock(return_value=updated_user)
|
||||
|
||||
response = client.patch(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"avatar_url": "https://new-avatar.com/img.jpg"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
|
||||
|
||||
def test_requires_authentication(self, client: TestClient):
|
||||
"""Test that endpoint returns 401 without authentication."""
|
||||
response = client.patch(
|
||||
"/api/users/me",
|
||||
json={"display_name": "New Name"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestGetLinkedAccounts:
|
||||
"""Tests for GET /api/users/me/linked-accounts endpoint."""
|
||||
|
||||
def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
|
||||
"""Test that endpoint returns list of linked OAuth accounts.
|
||||
|
||||
Should include the primary provider and any linked accounts.
|
||||
"""
|
||||
with patch("app.api.deps.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
|
||||
response = client.get(
|
||||
"/api/users/me/linked-accounts",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1 # At least primary account
|
||||
assert data[0]["provider"] == test_user.oauth_provider
|
||||
|
||||
def test_requires_authentication(self, client: TestClient):
|
||||
"""Test that endpoint returns 401 without authentication."""
|
||||
response = client.get("/api/users/me/linked-accounts")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestGetActiveSessions:
|
||||
"""Tests for GET /api/users/me/sessions endpoint."""
|
||||
|
||||
def test_returns_session_count(
|
||||
self, client: TestClient, test_user, access_token, mock_get_redis
|
||||
):
|
||||
"""Test that endpoint returns count of active sessions.
|
||||
|
||||
Should return the number of valid refresh tokens.
|
||||
"""
|
||||
|
||||
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
|
||||
|
||||
# Store some tokens
|
||||
import asyncio
|
||||
|
||||
async def setup_tokens():
|
||||
async with mock_get_redis() as redis:
|
||||
await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
|
||||
await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
||||
|
||||
with patch("app.api.deps.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
|
||||
with patch("app.services.token_store.get_redis", mock_get_redis):
|
||||
response = client.get(
|
||||
"/api/users/me/sessions",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "active_sessions" in data
|
||||
assert data["active_sessions"] == 2
|
||||
|
||||
def test_requires_authentication(self, client: TestClient):
|
||||
"""Test that endpoint returns 401 without authentication."""
|
||||
response = client.get("/api/users/me/sessions")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@ -70,6 +70,7 @@ TABLES_TO_TRUNCATE = [
|
||||
"campaign_progress",
|
||||
"collections",
|
||||
"decks",
|
||||
"oauth_linked_accounts",
|
||||
"users",
|
||||
]
|
||||
|
||||
|
||||
1
backend/tests/services/oauth/__init__.py
Normal file
1
backend/tests/services/oauth/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""OAuth service tests."""
|
||||
321
backend/tests/services/oauth/test_discord.py
Normal file
321
backend/tests/services/oauth/test_discord.py
Normal file
@ -0,0 +1,321 @@
|
||||
"""Tests for Discord OAuth service.
|
||||
|
||||
Tests the Discord OAuth flow with mocked HTTP responses using respx.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
from httpx import Response
|
||||
|
||||
from app.services.oauth.discord import DiscordOAuth, DiscordOAuthError
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestGetAuthorizationUrl:
|
||||
"""Tests for get_authorization_url method."""
|
||||
|
||||
def test_raises_when_not_configured(self):
|
||||
"""Test that get_authorization_url raises when Discord OAuth is not configured.
|
||||
|
||||
Without client ID, the method should raise DiscordOAuthError.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = None
|
||||
|
||||
with pytest.raises(DiscordOAuthError, match="not configured"):
|
||||
oauth.get_authorization_url("http://localhost/callback", "state123")
|
||||
|
||||
def test_returns_valid_url_when_configured(self):
|
||||
"""Test that get_authorization_url returns properly formatted URL.
|
||||
|
||||
The URL should include client_id, redirect_uri, state, and scopes.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = "test-client-id"
|
||||
mock_settings.discord_client_secret = "test-secret"
|
||||
|
||||
url = oauth.get_authorization_url("http://localhost/callback", "state123")
|
||||
|
||||
assert "discord.com/api/oauth2/authorize" in url
|
||||
assert "client_id=test-client-id" in url
|
||||
assert "redirect_uri=http" in url
|
||||
assert "state=state123" in url
|
||||
assert "scope=" in url
|
||||
assert "response_type=code" in url
|
||||
|
||||
|
||||
class TestExchangeCodeForTokens:
|
||||
"""Tests for exchange_code_for_tokens method."""
|
||||
|
||||
@respx.mock
|
||||
async def test_returns_tokens_on_success(self):
|
||||
"""Test that exchange_code_for_tokens returns tokens on success.
|
||||
|
||||
Mocks Discord's token endpoint to return valid tokens.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 604800,
|
||||
"token_type": "Bearer",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = "test-client-id"
|
||||
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
tokens = await oauth.exchange_code_for_tokens("auth-code", "http://localhost/callback")
|
||||
|
||||
assert tokens["access_token"] == "test-access-token"
|
||||
|
||||
@respx.mock
|
||||
async def test_raises_on_error_response(self):
|
||||
"""Test that exchange_code_for_tokens raises on error from Discord.
|
||||
|
||||
If Discord returns an error, DiscordOAuthError should be raised.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||
return_value=Response(
|
||||
400,
|
||||
json={
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Invalid code",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = "test-client-id"
|
||||
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
with pytest.raises(DiscordOAuthError, match="Token exchange failed"):
|
||||
await oauth.exchange_code_for_tokens("invalid-code", "http://localhost/callback")
|
||||
|
||||
|
||||
class TestFetchUserInfo:
|
||||
"""Tests for fetch_user_info method."""
|
||||
|
||||
@respx.mock
|
||||
async def test_returns_user_info_on_success(self):
|
||||
"""Test that fetch_user_info returns user data from Discord.
|
||||
|
||||
Mocks Discord's users/@me endpoint.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
respx.get("https://discord.com/api/users/@me").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"id": "discord-user-123",
|
||||
"username": "testuser",
|
||||
"global_name": "Test User",
|
||||
"email": "user@discord.com",
|
||||
"avatar": "abc123",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
user_info = await oauth.fetch_user_info("test-access-token")
|
||||
|
||||
assert user_info["id"] == "discord-user-123"
|
||||
assert user_info["email"] == "user@discord.com"
|
||||
|
||||
@respx.mock
|
||||
async def test_raises_on_error_response(self):
|
||||
"""Test that fetch_user_info raises on error from Discord."""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
respx.get("https://discord.com/api/users/@me").mock(
|
||||
return_value=Response(401, json={"message": "401: Unauthorized"})
|
||||
)
|
||||
|
||||
with pytest.raises(DiscordOAuthError, match="Failed to fetch user info"):
|
||||
await oauth.fetch_user_info("invalid-token")
|
||||
|
||||
|
||||
class TestBuildAvatarUrl:
|
||||
"""Tests for _build_avatar_url method."""
|
||||
|
||||
def test_returns_none_for_no_avatar(self):
|
||||
"""Test that _build_avatar_url returns None when avatar is None."""
|
||||
oauth = DiscordOAuth()
|
||||
result = oauth._build_avatar_url("123456", None)
|
||||
assert result is None
|
||||
|
||||
def test_builds_png_url_for_static_avatar(self):
|
||||
"""Test that _build_avatar_url builds PNG URL for static avatars."""
|
||||
oauth = DiscordOAuth()
|
||||
result = oauth._build_avatar_url("123456", "abcdef123")
|
||||
|
||||
assert result == "https://cdn.discordapp.com/avatars/123456/abcdef123.png"
|
||||
|
||||
def test_builds_gif_url_for_animated_avatar(self):
|
||||
"""Test that _build_avatar_url builds GIF URL for animated avatars.
|
||||
|
||||
Animated avatars start with 'a_' prefix.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
result = oauth._build_avatar_url("123456", "a_animated123")
|
||||
|
||||
assert result == "https://cdn.discordapp.com/avatars/123456/a_animated123.gif"
|
||||
|
||||
|
||||
class TestGetUserInfo:
|
||||
"""Tests for get_user_info method (full flow)."""
|
||||
|
||||
@respx.mock
|
||||
async def test_returns_oauth_user_info_on_success(self):
|
||||
"""Test that get_user_info completes full OAuth flow.
|
||||
|
||||
This tests the combined token exchange + user info fetch.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
# Mock token exchange
|
||||
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh",
|
||||
"expires_in": 604800,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Mock user info
|
||||
respx.get("https://discord.com/api/users/@me").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"id": "discord-user-456",
|
||||
"username": "fullflowuser",
|
||||
"global_name": "Full Flow User",
|
||||
"email": "fullflow@discord.com",
|
||||
"avatar": "avatar123",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = "test-client-id"
|
||||
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||
|
||||
assert result.provider == "discord"
|
||||
assert result.oauth_id == "discord-user-456"
|
||||
assert result.email == "fullflow@discord.com"
|
||||
assert result.name == "Full Flow User"
|
||||
assert "avatar123.png" in result.avatar_url
|
||||
|
||||
@respx.mock
|
||||
async def test_uses_username_when_no_global_name(self):
|
||||
"""Test that get_user_info falls back to username for display name.
|
||||
|
||||
Discord users may not have global_name set.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={"access_token": "test-access-token"},
|
||||
)
|
||||
)
|
||||
|
||||
respx.get("https://discord.com/api/users/@me").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"id": "discord-user-789",
|
||||
"username": "legacyuser",
|
||||
"global_name": None, # No global name
|
||||
"email": "legacy@discord.com",
|
||||
"avatar": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = "test-client-id"
|
||||
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||
|
||||
assert result.name == "legacyuser"
|
||||
assert result.avatar_url is None
|
||||
|
||||
@respx.mock
|
||||
async def test_raises_when_no_email(self):
|
||||
"""Test that get_user_info raises when Discord user has no email.
|
||||
|
||||
Email is required for account creation.
|
||||
"""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
respx.post("https://discord.com/api/oauth2/token").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={"access_token": "test-access-token"},
|
||||
)
|
||||
)
|
||||
|
||||
respx.get("https://discord.com/api/users/@me").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"id": "discord-user-noemail",
|
||||
"username": "noemailuser",
|
||||
"email": None, # No verified email
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = "test-client-id"
|
||||
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
with pytest.raises(DiscordOAuthError, match="verified email"):
|
||||
await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||
|
||||
|
||||
class TestIsConfigured:
|
||||
"""Tests for is_configured method."""
|
||||
|
||||
def test_returns_false_when_not_configured(self):
|
||||
"""Test that is_configured returns False without credentials."""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = None
|
||||
mock_settings.discord_client_secret = None
|
||||
|
||||
assert oauth.is_configured() is False
|
||||
|
||||
def test_returns_true_when_configured(self):
|
||||
"""Test that is_configured returns True with credentials."""
|
||||
oauth = DiscordOAuth()
|
||||
|
||||
with patch("app.services.oauth.discord.settings") as mock_settings:
|
||||
mock_settings.discord_client_id = "client-id"
|
||||
mock_settings.discord_client_secret = "client-secret"
|
||||
|
||||
assert oauth.is_configured() is True
|
||||
241
backend/tests/services/oauth/test_google.py
Normal file
241
backend/tests/services/oauth/test_google.py
Normal file
@ -0,0 +1,241 @@
|
||||
"""Tests for Google OAuth service.
|
||||
|
||||
Tests the Google OAuth flow with mocked HTTP responses using respx.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
from httpx import Response
|
||||
|
||||
from app.services.oauth.google import GoogleOAuth, GoogleOAuthError
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestGetAuthorizationUrl:
|
||||
"""Tests for get_authorization_url method."""
|
||||
|
||||
def test_raises_when_not_configured(self):
|
||||
"""Test that get_authorization_url raises when Google OAuth is not configured.
|
||||
|
||||
Without client ID, the method should raise GoogleOAuthError.
|
||||
"""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||
mock_settings.google_client_id = None
|
||||
|
||||
with pytest.raises(GoogleOAuthError, match="not configured"):
|
||||
oauth.get_authorization_url("http://localhost/callback", "state123")
|
||||
|
||||
def test_returns_valid_url_when_configured(self):
|
||||
"""Test that get_authorization_url returns properly formatted URL.
|
||||
|
||||
The URL should include client_id, redirect_uri, state, and scopes.
|
||||
"""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||
mock_settings.google_client_id = "test-client-id"
|
||||
mock_settings.google_client_secret = "test-secret"
|
||||
|
||||
url = oauth.get_authorization_url("http://localhost/callback", "state123")
|
||||
|
||||
assert "accounts.google.com/o/oauth2/v2/auth" in url
|
||||
assert "client_id=test-client-id" in url
|
||||
assert "redirect_uri=http" in url
|
||||
assert "state=state123" in url
|
||||
assert "scope=" in url
|
||||
assert "response_type=code" in url
|
||||
|
||||
|
||||
class TestExchangeCodeForTokens:
|
||||
"""Tests for exchange_code_for_tokens method."""
|
||||
|
||||
@respx.mock
|
||||
async def test_returns_tokens_on_success(self):
|
||||
"""Test that exchange_code_for_tokens returns tokens on success.
|
||||
|
||||
Mocks Google's token endpoint to return valid tokens.
|
||||
"""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
respx.post("https://oauth2.googleapis.com/token").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"access_token": "test-access-token",
|
||||
"id_token": "test-id-token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||
mock_settings.google_client_id = "test-client-id"
|
||||
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
tokens = await oauth.exchange_code_for_tokens("auth-code", "http://localhost/callback")
|
||||
|
||||
assert tokens["access_token"] == "test-access-token"
|
||||
|
||||
@respx.mock
|
||||
async def test_raises_on_error_response(self):
|
||||
"""Test that exchange_code_for_tokens raises on error from Google.
|
||||
|
||||
If Google returns an error, GoogleOAuthError should be raised.
|
||||
"""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
respx.post("https://oauth2.googleapis.com/token").mock(
|
||||
return_value=Response(
|
||||
400,
|
||||
json={
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Code has expired",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||
mock_settings.google_client_id = "test-client-id"
|
||||
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
with pytest.raises(GoogleOAuthError, match="Token exchange failed"):
|
||||
await oauth.exchange_code_for_tokens("expired-code", "http://localhost/callback")
|
||||
|
||||
|
||||
class TestFetchUserInfo:
|
||||
"""Tests for fetch_user_info method."""
|
||||
|
||||
@respx.mock
|
||||
async def test_returns_user_info_on_success(self):
|
||||
"""Test that fetch_user_info returns user data from Google.
|
||||
|
||||
Mocks Google's userinfo endpoint.
|
||||
"""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"id": "google-user-123",
|
||||
"email": "user@gmail.com",
|
||||
"name": "Test User",
|
||||
"picture": "https://google.com/avatar.jpg",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
user_info = await oauth.fetch_user_info("test-access-token")
|
||||
|
||||
assert user_info["id"] == "google-user-123"
|
||||
assert user_info["email"] == "user@gmail.com"
|
||||
|
||||
@respx.mock
|
||||
async def test_raises_on_error_response(self):
|
||||
"""Test that fetch_user_info raises on error from Google."""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
|
||||
return_value=Response(401, json={"error": "Invalid token"})
|
||||
)
|
||||
|
||||
with pytest.raises(GoogleOAuthError, match="Failed to fetch user info"):
|
||||
await oauth.fetch_user_info("invalid-token")
|
||||
|
||||
|
||||
class TestGetUserInfo:
|
||||
"""Tests for get_user_info method (full flow)."""
|
||||
|
||||
@respx.mock
|
||||
async def test_returns_oauth_user_info_on_success(self):
|
||||
"""Test that get_user_info completes full OAuth flow.
|
||||
|
||||
This tests the combined token exchange + user info fetch.
|
||||
"""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
# Mock token exchange
|
||||
respx.post("https://oauth2.googleapis.com/token").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"access_token": "test-access-token",
|
||||
"id_token": "test-id-token",
|
||||
"expires_in": 3600,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Mock user info
|
||||
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"id": "google-user-456",
|
||||
"email": "fullflow@gmail.com",
|
||||
"name": "Full Flow User",
|
||||
"picture": "https://google.com/fullflow.jpg",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||
mock_settings.google_client_id = "test-client-id"
|
||||
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||
|
||||
assert result.provider == "google"
|
||||
assert result.oauth_id == "google-user-456"
|
||||
assert result.email == "fullflow@gmail.com"
|
||||
assert result.name == "Full Flow User"
|
||||
assert result.avatar_url == "https://google.com/fullflow.jpg"
|
||||
|
||||
@respx.mock
|
||||
async def test_raises_when_no_access_token(self):
|
||||
"""Test that get_user_info raises when token response lacks access_token."""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
respx.post("https://oauth2.googleapis.com/token").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={"id_token": "only-id-token"}, # No access_token
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||
mock_settings.google_client_id = "test-client-id"
|
||||
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
|
||||
|
||||
with pytest.raises(GoogleOAuthError, match="No access token"):
|
||||
await oauth.get_user_info("auth-code", "http://localhost/callback")
|
||||
|
||||
|
||||
class TestIsConfigured:
|
||||
"""Tests for is_configured method."""
|
||||
|
||||
def test_returns_false_when_not_configured(self):
|
||||
"""Test that is_configured returns False without credentials."""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||
mock_settings.google_client_id = None
|
||||
mock_settings.google_client_secret = None
|
||||
|
||||
assert oauth.is_configured() is False
|
||||
|
||||
def test_returns_true_when_configured(self):
|
||||
"""Test that is_configured returns True with credentials."""
|
||||
oauth = GoogleOAuth()
|
||||
|
||||
with patch("app.services.oauth.google.settings") as mock_settings:
|
||||
mock_settings.google_client_id = "client-id"
|
||||
mock_settings.google_client_secret = "client-secret"
|
||||
|
||||
assert oauth.is_configured() is True
|
||||
370
backend/tests/services/test_jwt_service.py
Normal file
370
backend/tests/services/test_jwt_service.py
Normal file
@ -0,0 +1,370 @@
|
||||
"""Tests for JWT service.
|
||||
|
||||
Tests the JWT token creation and verification functions used for
|
||||
authentication throughout the application.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from jose import jwt
|
||||
|
||||
from app.config import settings
|
||||
from app.schemas.auth import TokenType
|
||||
from app.services.jwt_service import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
get_refresh_token_expiration,
|
||||
get_token_expiration_seconds,
|
||||
verify_access_token,
|
||||
verify_refresh_token,
|
||||
)
|
||||
|
||||
|
||||
class TestCreateAccessToken:
|
||||
"""Tests for create_access_token function."""
|
||||
|
||||
def test_creates_valid_jwt(self):
|
||||
"""Test that create_access_token returns a valid JWT string.
|
||||
|
||||
The returned token should be decodable and contain the expected
|
||||
claims including subject, expiration, and token type.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(user_id)
|
||||
|
||||
# Should be a valid JWT (three dot-separated parts)
|
||||
assert isinstance(token, str)
|
||||
assert token.count(".") == 2
|
||||
|
||||
# Should be decodable
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
assert payload["sub"] == str(user_id)
|
||||
assert payload["type"] == TokenType.ACCESS.value
|
||||
|
||||
def test_sets_correct_expiration(self):
|
||||
"""Test that access token expiration matches configured setting.
|
||||
|
||||
The token should expire approximately jwt_expire_minutes from now.
|
||||
JWT timestamps have second precision, so we allow 1 second tolerance.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
before = datetime.now(UTC)
|
||||
token = create_access_token(user_id)
|
||||
after = datetime.now(UTC)
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
|
||||
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
|
||||
expected_min = (
|
||||
before + timedelta(minutes=settings.jwt_expire_minutes) - timedelta(seconds=1)
|
||||
)
|
||||
expected_max = after + timedelta(minutes=settings.jwt_expire_minutes) + timedelta(seconds=1)
|
||||
|
||||
assert expected_min <= exp <= expected_max
|
||||
|
||||
def test_includes_issued_at(self):
|
||||
"""Test that access token includes iat (issued at) claim.
|
||||
|
||||
The iat claim should be set to approximately the current time.
|
||||
JWT timestamps have second precision, so we allow 1 second tolerance.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
before = datetime.now(UTC)
|
||||
token = create_access_token(user_id)
|
||||
after = datetime.now(UTC)
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
|
||||
iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
|
||||
assert before - timedelta(seconds=1) <= iat <= after + timedelta(seconds=1)
|
||||
|
||||
|
||||
class TestCreateRefreshToken:
|
||||
"""Tests for create_refresh_token function."""
|
||||
|
||||
def test_creates_valid_jwt_with_jti(self):
|
||||
"""Test that create_refresh_token returns a valid JWT and JTI.
|
||||
|
||||
The function should return a tuple of (token, jti) where the
|
||||
token contains the JTI for revocation tracking.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
token, jti = create_refresh_token(user_id)
|
||||
|
||||
# Should return token and jti
|
||||
assert isinstance(token, str)
|
||||
assert isinstance(jti, str)
|
||||
assert token.count(".") == 2
|
||||
|
||||
# JTI should be a valid UUID
|
||||
uuid.UUID(jti) # Will raise if invalid
|
||||
|
||||
# Token should contain the JTI
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
assert payload["jti"] == jti
|
||||
assert payload["type"] == TokenType.REFRESH.value
|
||||
|
||||
def test_sets_correct_expiration(self):
|
||||
"""Test that refresh token expiration matches configured setting.
|
||||
|
||||
The token should expire approximately jwt_refresh_expire_days from now.
|
||||
JWT timestamps have second precision, so we allow 1 second tolerance.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
before = datetime.now(UTC)
|
||||
token, _ = create_refresh_token(user_id)
|
||||
after = datetime.now(UTC)
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
|
||||
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
|
||||
expected_min = (
|
||||
before + timedelta(days=settings.jwt_refresh_expire_days) - timedelta(seconds=1)
|
||||
)
|
||||
expected_max = (
|
||||
after + timedelta(days=settings.jwt_refresh_expire_days) + timedelta(seconds=1)
|
||||
)
|
||||
|
||||
assert expected_min <= exp <= expected_max
|
||||
|
||||
def test_generates_unique_jti(self):
|
||||
"""Test that each refresh token gets a unique JTI.
|
||||
|
||||
Multiple calls should generate different JTIs to ensure
|
||||
each token can be individually revoked.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
_, jti1 = create_refresh_token(user_id)
|
||||
_, jti2 = create_refresh_token(user_id)
|
||||
|
||||
assert jti1 != jti2
|
||||
|
||||
|
||||
class TestDecodeToken:
|
||||
"""Tests for decode_token function."""
|
||||
|
||||
def test_decodes_valid_token(self):
|
||||
"""Test that decode_token returns TokenPayload for valid tokens.
|
||||
|
||||
A valid token should be decoded into a TokenPayload with
|
||||
all expected fields populated.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(user_id)
|
||||
|
||||
payload = decode_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload.sub == str(user_id)
|
||||
assert payload.type == TokenType.ACCESS
|
||||
assert payload.exp is not None
|
||||
assert payload.iat is not None
|
||||
|
||||
def test_returns_none_for_invalid_token(self):
|
||||
"""Test that decode_token returns None for malformed tokens.
|
||||
|
||||
Invalid JWT strings should not raise exceptions but return None.
|
||||
"""
|
||||
result = decode_token("invalid.token.here")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_wrong_signature(self):
|
||||
"""Test that decode_token returns None for tokens with wrong signature.
|
||||
|
||||
Tokens signed with a different key should be rejected.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
# Create token with different secret
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"iat": datetime.now(UTC),
|
||||
"type": "access",
|
||||
}
|
||||
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
|
||||
result = decode_token(token)
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_expired_token(self):
|
||||
"""Test that decode_token returns None for expired tokens.
|
||||
|
||||
Tokens past their expiration should be rejected.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"exp": datetime.now(UTC) - timedelta(hours=1), # Already expired
|
||||
"iat": datetime.now(UTC) - timedelta(hours=2),
|
||||
"type": "access",
|
||||
}
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithm=settings.jwt_algorithm,
|
||||
)
|
||||
|
||||
result = decode_token(token)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestVerifyAccessToken:
|
||||
"""Tests for verify_access_token function."""
|
||||
|
||||
def test_returns_user_id_for_valid_access_token(self):
|
||||
"""Test that verify_access_token returns user ID for valid tokens.
|
||||
|
||||
A valid access token should return the UUID of the user.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(user_id)
|
||||
|
||||
result = verify_access_token(token)
|
||||
|
||||
assert result == user_id
|
||||
|
||||
def test_returns_none_for_refresh_token(self):
|
||||
"""Test that verify_access_token rejects refresh tokens.
|
||||
|
||||
Even valid refresh tokens should be rejected when verifying
|
||||
as access tokens to prevent token type confusion.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
token, _ = create_refresh_token(user_id)
|
||||
|
||||
result = verify_access_token(token)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_invalid_token(self):
|
||||
"""Test that verify_access_token returns None for invalid tokens."""
|
||||
result = verify_access_token("invalid.token.here")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_invalid_uuid_subject(self):
|
||||
"""Test that verify_access_token returns None for non-UUID subject.
|
||||
|
||||
If the subject claim is not a valid UUID, the token should be rejected.
|
||||
"""
|
||||
payload = {
|
||||
"sub": "not-a-uuid",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
"iat": datetime.now(UTC),
|
||||
"type": "access",
|
||||
}
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithm=settings.jwt_algorithm,
|
||||
)
|
||||
|
||||
result = verify_access_token(token)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestVerifyRefreshToken:
|
||||
"""Tests for verify_refresh_token function."""
|
||||
|
||||
def test_returns_user_id_and_jti_for_valid_refresh_token(self):
|
||||
"""Test that verify_refresh_token returns user ID and JTI.
|
||||
|
||||
A valid refresh token should return both values needed for
|
||||
revocation checking.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
token, jti = create_refresh_token(user_id)
|
||||
|
||||
result = verify_refresh_token(token)
|
||||
|
||||
assert result is not None
|
||||
result_user_id, result_jti = result
|
||||
assert result_user_id == user_id
|
||||
assert result_jti == jti
|
||||
|
||||
def test_returns_none_for_access_token(self):
|
||||
"""Test that verify_refresh_token rejects access tokens.
|
||||
|
||||
Even valid access tokens should be rejected when verifying
|
||||
as refresh tokens.
|
||||
"""
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(user_id)
|
||||
|
||||
result = verify_refresh_token(token)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_token_without_jti(self):
|
||||
"""Test that verify_refresh_token rejects tokens missing JTI.
|
||||
|
||||
Refresh tokens must have a JTI for revocation tracking.
|
||||
"""
|
||||
payload = {
|
||||
"sub": str(uuid.uuid4()),
|
||||
"exp": datetime.now(UTC) + timedelta(days=7),
|
||||
"iat": datetime.now(UTC),
|
||||
"type": "refresh",
|
||||
# No jti
|
||||
}
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.secret_key.get_secret_value(),
|
||||
algorithm=settings.jwt_algorithm,
|
||||
)
|
||||
|
||||
result = verify_refresh_token(token)
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_invalid_token(self):
|
||||
"""Test that verify_refresh_token returns None for invalid tokens."""
|
||||
result = verify_refresh_token("invalid.token.here")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Tests for helper functions."""
|
||||
|
||||
def test_get_token_expiration_seconds(self):
|
||||
"""Test that get_token_expiration_seconds returns correct value.
|
||||
|
||||
Should return jwt_expire_minutes converted to seconds.
|
||||
"""
|
||||
result = get_token_expiration_seconds()
|
||||
assert result == settings.jwt_expire_minutes * 60
|
||||
|
||||
def test_get_refresh_token_expiration(self):
|
||||
"""Test that get_refresh_token_expiration returns future datetime.
|
||||
|
||||
Should return a datetime approximately jwt_refresh_expire_days
|
||||
in the future.
|
||||
"""
|
||||
before = datetime.now(UTC)
|
||||
result = get_refresh_token_expiration()
|
||||
after = datetime.now(UTC)
|
||||
|
||||
expected_min = before + timedelta(days=settings.jwt_refresh_expire_days)
|
||||
expected_max = after + timedelta(days=settings.jwt_refresh_expire_days)
|
||||
|
||||
assert expected_min <= result <= expected_max
|
||||
407
backend/tests/services/test_user_service.py
Normal file
407
backend/tests/services/test_user_service.py
Normal file
@ -0,0 +1,407 @@
|
||||
"""Tests for UserService.
|
||||
|
||||
Tests the user service CRUD operations and OAuth-based user creation.
|
||||
Uses real Postgres via the db_session fixture from conftest.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db.models import User
|
||||
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
||||
from app.services.user_service import user_service
|
||||
|
||||
# Import db_session fixture from db conftest
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestGetById:
|
||||
"""Tests for get_by_id method."""
|
||||
|
||||
async def test_returns_user_when_found(self, db_session):
|
||||
"""Test that get_by_id returns user when it exists.
|
||||
|
||||
Creates a user and verifies it can be retrieved by ID.
|
||||
"""
|
||||
# Create user directly
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
display_name="Test User",
|
||||
oauth_provider="google",
|
||||
oauth_id="123456",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
# Retrieve by ID
|
||||
from uuid import UUID
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
result = await user_service.get_by_id(db_session, user_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.email == "test@example.com"
|
||||
|
||||
async def test_returns_none_when_not_found(self, db_session):
|
||||
"""Test that get_by_id returns None for nonexistent users."""
|
||||
from uuid import uuid4
|
||||
|
||||
result = await user_service.get_by_id(db_session, uuid4())
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetByEmail:
|
||||
"""Tests for get_by_email method."""
|
||||
|
||||
async def test_returns_user_when_found(self, db_session):
|
||||
"""Test that get_by_email returns user when it exists."""
|
||||
user = User(
|
||||
email="findme@example.com",
|
||||
display_name="Find Me",
|
||||
oauth_provider="discord",
|
||||
oauth_id="discord123",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
result = await user_service.get_by_email(db_session, "findme@example.com")
|
||||
|
||||
assert result is not None
|
||||
assert result.display_name == "Find Me"
|
||||
|
||||
async def test_returns_none_when_not_found(self, db_session):
|
||||
"""Test that get_by_email returns None for nonexistent emails."""
|
||||
result = await user_service.get_by_email(db_session, "nobody@example.com")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetByOAuth:
|
||||
"""Tests for get_by_oauth method."""
|
||||
|
||||
async def test_returns_user_when_found(self, db_session):
|
||||
"""Test that get_by_oauth returns user for matching provider+id."""
|
||||
user = User(
|
||||
email="oauth@example.com",
|
||||
display_name="OAuth User",
|
||||
oauth_provider="google",
|
||||
oauth_id="google-unique-id",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
result = await user_service.get_by_oauth(db_session, "google", "google-unique-id")
|
||||
|
||||
assert result is not None
|
||||
assert result.email == "oauth@example.com"
|
||||
|
||||
async def test_returns_none_for_wrong_provider(self, db_session):
|
||||
"""Test that get_by_oauth returns None if provider doesn't match."""
|
||||
user = User(
|
||||
email="oauth2@example.com",
|
||||
display_name="OAuth User 2",
|
||||
oauth_provider="google",
|
||||
oauth_id="google-id-2",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
# Same ID, different provider
|
||||
result = await user_service.get_by_oauth(db_session, "discord", "google-id-2")
|
||||
assert result is None
|
||||
|
||||
async def test_returns_none_when_not_found(self, db_session):
|
||||
"""Test that get_by_oauth returns None for nonexistent OAuth."""
|
||||
result = await user_service.get_by_oauth(db_session, "google", "nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreate:
|
||||
"""Tests for create method."""
|
||||
|
||||
async def test_creates_user_with_all_fields(self, db_session):
|
||||
"""Test that create properly persists all user fields."""
|
||||
user_data = UserCreate(
|
||||
email="new@example.com",
|
||||
display_name="New User",
|
||||
avatar_url="https://example.com/avatar.jpg",
|
||||
oauth_provider="discord",
|
||||
oauth_id="discord-new-id",
|
||||
)
|
||||
|
||||
result = await user_service.create(db_session, user_data)
|
||||
|
||||
assert result.id is not None
|
||||
assert result.email == "new@example.com"
|
||||
assert result.display_name == "New User"
|
||||
assert result.avatar_url == "https://example.com/avatar.jpg"
|
||||
assert result.oauth_provider == "discord"
|
||||
assert result.oauth_id == "discord-new-id"
|
||||
assert result.is_premium is False
|
||||
assert result.premium_until is None
|
||||
|
||||
async def test_creates_user_without_avatar(self, db_session):
|
||||
"""Test that create works without optional avatar_url."""
|
||||
user_data = UserCreate(
|
||||
email="noavatar@example.com",
|
||||
display_name="No Avatar",
|
||||
oauth_provider="google",
|
||||
oauth_id="google-no-avatar",
|
||||
)
|
||||
|
||||
result = await user_service.create(db_session, user_data)
|
||||
|
||||
assert result.avatar_url is None
|
||||
|
||||
|
||||
class TestCreateFromOAuth:
|
||||
"""Tests for create_from_oauth method."""
|
||||
|
||||
async def test_creates_user_from_oauth_info(self, db_session):
|
||||
"""Test that create_from_oauth converts OAuthUserInfo to User."""
|
||||
oauth_info = OAuthUserInfo(
|
||||
provider="google",
|
||||
oauth_id="google-oauth-123",
|
||||
email="oauthcreate@example.com",
|
||||
name="OAuth Created User",
|
||||
avatar_url="https://google.com/avatar.jpg",
|
||||
)
|
||||
|
||||
result = await user_service.create_from_oauth(db_session, oauth_info)
|
||||
|
||||
assert result.email == "oauthcreate@example.com"
|
||||
assert result.display_name == "OAuth Created User"
|
||||
assert result.oauth_provider == "google"
|
||||
assert result.oauth_id == "google-oauth-123"
|
||||
|
||||
|
||||
class TestGetOrCreateFromOAuth:
|
||||
"""Tests for get_or_create_from_oauth method."""
|
||||
|
||||
async def test_returns_existing_user_by_oauth(self, db_session):
|
||||
"""Test that existing user is returned when OAuth matches.
|
||||
|
||||
Verifies the method returns (user, False) for existing users.
|
||||
"""
|
||||
# Create existing user
|
||||
existing = User(
|
||||
email="existing@example.com",
|
||||
display_name="Existing",
|
||||
oauth_provider="google",
|
||||
oauth_id="existing-oauth-id",
|
||||
)
|
||||
db_session.add(existing)
|
||||
await db_session.commit()
|
||||
|
||||
# Try to get or create with same OAuth
|
||||
oauth_info = OAuthUserInfo(
|
||||
provider="google",
|
||||
oauth_id="existing-oauth-id",
|
||||
email="existing@example.com",
|
||||
name="Existing",
|
||||
)
|
||||
|
||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||
|
||||
assert created is False
|
||||
assert result.id == existing.id
|
||||
|
||||
async def test_links_existing_user_by_email(self, db_session):
|
||||
"""Test that OAuth is linked when email matches existing user.
|
||||
|
||||
If a user exists with the same email but different OAuth,
|
||||
the new OAuth should be linked to the existing account.
|
||||
"""
|
||||
# Create user with Google
|
||||
existing = User(
|
||||
email="link@example.com",
|
||||
display_name="Link Me",
|
||||
oauth_provider="google",
|
||||
oauth_id="google-link-id",
|
||||
)
|
||||
db_session.add(existing)
|
||||
await db_session.commit()
|
||||
|
||||
# Login with Discord (same email)
|
||||
oauth_info = OAuthUserInfo(
|
||||
provider="discord",
|
||||
oauth_id="discord-link-id",
|
||||
email="link@example.com",
|
||||
name="Link Me",
|
||||
avatar_url="https://discord.com/avatar.jpg",
|
||||
)
|
||||
|
||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||
|
||||
assert created is False
|
||||
assert result.id == existing.id
|
||||
# OAuth should be updated to Discord
|
||||
assert result.oauth_provider == "discord"
|
||||
assert result.oauth_id == "discord-link-id"
|
||||
|
||||
async def test_creates_new_user_when_not_found(self, db_session):
|
||||
"""Test that new user is created when no match exists.
|
||||
|
||||
Verifies the method returns (user, True) for new users.
|
||||
"""
|
||||
oauth_info = OAuthUserInfo(
|
||||
provider="discord",
|
||||
oauth_id="brand-new-id",
|
||||
email="brandnew@example.com",
|
||||
name="Brand New",
|
||||
)
|
||||
|
||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||
|
||||
assert created is True
|
||||
assert result.email == "brandnew@example.com"
|
||||
|
||||
|
||||
class TestUpdate:
|
||||
"""Tests for update method."""
|
||||
|
||||
async def test_updates_display_name(self, db_session):
|
||||
"""Test that update changes display_name when provided."""
|
||||
user = User(
|
||||
email="update@example.com",
|
||||
display_name="Old Name",
|
||||
oauth_provider="google",
|
||||
oauth_id="update-id",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
update_data = UserUpdate(display_name="New Name")
|
||||
result = await user_service.update(db_session, user, update_data)
|
||||
|
||||
assert result.display_name == "New Name"
|
||||
|
||||
async def test_updates_avatar_url(self, db_session):
|
||||
"""Test that update changes avatar_url when provided."""
|
||||
user = User(
|
||||
email="avatar@example.com",
|
||||
display_name="Avatar User",
|
||||
oauth_provider="google",
|
||||
oauth_id="avatar-id",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
|
||||
result = await user_service.update(db_session, user, update_data)
|
||||
|
||||
assert result.avatar_url == "https://new-avatar.com/img.jpg"
|
||||
|
||||
async def test_ignores_none_values(self, db_session):
|
||||
"""Test that update doesn't change fields set to None.
|
||||
|
||||
Only explicitly provided fields should be updated.
|
||||
"""
|
||||
user = User(
|
||||
email="keep@example.com",
|
||||
display_name="Keep Me",
|
||||
avatar_url="https://keep.com/avatar.jpg",
|
||||
oauth_provider="google",
|
||||
oauth_id="keep-id",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
# Update only display_name, leave avatar alone
|
||||
update_data = UserUpdate(display_name="Changed")
|
||||
result = await user_service.update(db_session, user, update_data)
|
||||
|
||||
assert result.display_name == "Changed"
|
||||
assert result.avatar_url == "https://keep.com/avatar.jpg"
|
||||
|
||||
|
||||
class TestUpdateLastLogin:
|
||||
"""Tests for update_last_login method."""
|
||||
|
||||
async def test_updates_last_login_timestamp(self, db_session):
|
||||
"""Test that update_last_login sets current timestamp."""
|
||||
user = User(
|
||||
email="login@example.com",
|
||||
display_name="Login User",
|
||||
oauth_provider="google",
|
||||
oauth_id="login-id",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
assert user.last_login is None
|
||||
|
||||
before = datetime.now(UTC)
|
||||
result = await user_service.update_last_login(db_session, user)
|
||||
after = datetime.now(UTC)
|
||||
|
||||
assert result.last_login is not None
|
||||
# Allow 1 second tolerance
|
||||
assert before - timedelta(seconds=1) <= result.last_login <= after + timedelta(seconds=1)
|
||||
|
||||
|
||||
class TestUpdatePremium:
|
||||
"""Tests for update_premium method."""
|
||||
|
||||
async def test_grants_premium(self, db_session):
|
||||
"""Test that update_premium sets premium status and expiration."""
|
||||
user = User(
|
||||
email="premium@example.com",
|
||||
display_name="Premium User",
|
||||
oauth_provider="google",
|
||||
oauth_id="premium-id",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
assert user.is_premium is False
|
||||
|
||||
expires = datetime.now(UTC) + timedelta(days=30)
|
||||
result = await user_service.update_premium(db_session, user, expires)
|
||||
|
||||
assert result.is_premium is True
|
||||
assert result.premium_until == expires
|
||||
|
||||
async def test_removes_premium(self, db_session):
|
||||
"""Test that update_premium with None removes premium status."""
|
||||
user = User(
|
||||
email="unpremium@example.com",
|
||||
display_name="Unpremium User",
|
||||
oauth_provider="google",
|
||||
oauth_id="unpremium-id",
|
||||
is_premium=True,
|
||||
premium_until=datetime.now(UTC) + timedelta(days=30),
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
result = await user_service.update_premium(db_session, user, None)
|
||||
|
||||
assert result.is_premium is False
|
||||
assert result.premium_until is None
|
||||
|
||||
|
||||
class TestDelete:
|
||||
"""Tests for delete method."""
|
||||
|
||||
async def test_deletes_user(self, db_session):
|
||||
"""Test that delete removes user from database."""
|
||||
user = User(
|
||||
email="delete@example.com",
|
||||
display_name="Delete Me",
|
||||
oauth_provider="google",
|
||||
oauth_id="delete-id",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
user_id = user.id
|
||||
await user_service.delete(db_session, user)
|
||||
|
||||
# Verify user is gone
|
||||
from uuid import UUID
|
||||
|
||||
result = await user_service.get_by_id(
|
||||
db_session, UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
)
|
||||
assert result is None
|
||||
64
backend/uv.lock
generated
64
backend/uv.lock
generated
@ -368,6 +368,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/48/d9f421cb8da5afaa1a64570d9989e00fb7955e6acddc5a12979f7666ef60/coverage-7.13.1-py3-none-any.whl", hash = "sha256:2016745cb3ba554469d02819d78958b571792bb68e31302610e898f80dd3a573", size = 210722, upload-time = "2025-12-28T15:42:54.901Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dnspython"
|
||||
version = "2.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docker"
|
||||
version = "7.1.0"
|
||||
@ -394,6 +403,32 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "email-validator"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dnspython" },
|
||||
{ name = "idna" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fakeredis"
|
||||
version = "2.33.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "redis" },
|
||||
{ name = "sortedcontainers" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5f/f9/57464119936414d60697fcbd32f38909bb5688b616ae13de6e98384433e0/fakeredis-2.33.0.tar.gz", hash = "sha256:d7bc9a69d21df108a6451bbffee23b3eba432c21a654afc7ff2d295428ec5770", size = 175187, upload-time = "2025-12-16T19:45:52.269Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/78/a850fed8aeef96d4a99043c90b818b2ed5419cd5b24a4049fd7cfb9f1471/fakeredis-2.33.0-py3-none-any.whl", hash = "sha256:de535f3f9ccde1c56672ab2fdd6a8efbc4f2619fc2f1acc87b8737177d71c965", size = 119605, upload-time = "2025-12-16T19:45:51.08Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
@ -579,7 +614,9 @@ dependencies = [
|
||||
{ name = "alembic" },
|
||||
{ name = "asyncpg" },
|
||||
{ name = "bcrypt" },
|
||||
{ name = "email-validator" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "httpx" },
|
||||
{ name = "passlib" },
|
||||
{ name = "psycopg2-binary" },
|
||||
{ name = "pydantic" },
|
||||
@ -595,12 +632,14 @@ dependencies = [
|
||||
dev = [
|
||||
{ name = "beautifulsoup4" },
|
||||
{ name = "black" },
|
||||
{ name = "fakeredis" },
|
||||
{ name = "httpx" },
|
||||
{ name = "mypy" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "requests" },
|
||||
{ name = "respx" },
|
||||
{ name = "ruff" },
|
||||
{ name = "testcontainers", extra = ["redis"] },
|
||||
]
|
||||
@ -610,7 +649,9 @@ requires-dist = [
|
||||
{ name = "alembic", specifier = ">=1.18.1" },
|
||||
{ name = "asyncpg", specifier = ">=0.31.0" },
|
||||
{ name = "bcrypt", specifier = ">=5.0.0" },
|
||||
{ name = "email-validator", specifier = ">=2.3.0" },
|
||||
{ name = "fastapi", specifier = ">=0.128.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.1" },
|
||||
{ name = "passlib", specifier = ">=1.7.4" },
|
||||
{ name = "psycopg2-binary", specifier = ">=2.9.11" },
|
||||
{ name = "pydantic", specifier = ">=2.12.5" },
|
||||
@ -626,12 +667,14 @@ requires-dist = [
|
||||
dev = [
|
||||
{ name = "beautifulsoup4", specifier = ">=4.12.0" },
|
||||
{ name = "black", specifier = ">=26.1.0" },
|
||||
{ name = "fakeredis", specifier = ">=2.33.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.1" },
|
||||
{ name = "mypy", specifier = ">=1.19.1" },
|
||||
{ name = "pytest", specifier = ">=9.0.2" },
|
||||
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
|
||||
{ name = "pytest-cov", specifier = ">=7.0.0" },
|
||||
{ name = "requests", specifier = ">=2.31.0" },
|
||||
{ name = "respx", specifier = ">=0.22.0" },
|
||||
{ name = "ruff", specifier = ">=0.14.14" },
|
||||
{ name = "testcontainers", extras = ["postgres", "redis"], specifier = ">=4.0.0" },
|
||||
]
|
||||
@ -1105,6 +1148,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "respx"
|
||||
version = "0.22.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f4/7c/96bd0bc759cf009675ad1ee1f96535edcb11e9666b985717eb8c87192a95/respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91", size = 28439, upload-time = "2024-12-19T22:33:59.374Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/67/afbb0978d5399bc9ea200f1d4489a23c9a1dad4eee6376242b8182389c79/respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0", size = 25127, upload-time = "2024-12-19T22:33:57.837Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rsa"
|
||||
version = "4.9.1"
|
||||
@ -1164,6 +1219,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sortedcontainers"
|
||||
version = "2.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "soupsieve"
|
||||
version = "2.8.3"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user