Merge branch 'backend-phase2' - Complete Phase 2 Authentication

This commit is contained in:
Cal Corum 2026-01-27 22:08:42 -06:00
commit 4cdb544162
34 changed files with 5576 additions and 3 deletions

View File

@ -0,0 +1,14 @@
"""API routers and dependencies for Mantimon TCG.
This package contains FastAPI routers and common dependencies
for the REST API.
Routers:
- auth: OAuth login, token refresh, logout
- users: User profile management
Dependencies:
- get_current_user: Extract and validate user from JWT
- get_current_active_user: Ensure user exists
- get_current_premium_user: Require premium subscription
"""

651
backend/app/api/auth.py Normal file
View File

@ -0,0 +1,651 @@
"""Authentication API router for Mantimon TCG.
This module provides endpoints for OAuth authentication:
- OAuth login redirects (Google, Discord)
- OAuth callbacks (token exchange)
- Token refresh
- Logout
OAuth Flow:
1. Client calls GET /auth/{provider} to get redirect URL
2. Client redirects user to OAuth provider
3. Provider redirects back to GET /auth/{provider}/callback
4. Server exchanges code for tokens, creates/fetches user
5. Server returns JWT access + refresh tokens
Example:
# Start OAuth flow
GET /api/auth/google?redirect_uri=https://play.mantimon.com/login/callback
# After OAuth callback, refresh tokens
POST /api/auth/refresh
{"refresh_token": "..."}
# Logout
POST /api/auth/logout
{"refresh_token": "..."}
"""
import secrets
from fastapi import APIRouter, HTTPException, Query, status
from fastapi.responses import RedirectResponse
from app.api.deps import CurrentUser, DbSession
from app.config import settings
from app.db.redis import get_redis
from app.schemas.auth import RefreshTokenRequest, TokenResponse
from app.services.jwt_service import (
create_access_token,
create_refresh_token,
get_refresh_token_expiration,
get_token_expiration_seconds,
verify_refresh_token,
)
from app.services.oauth.discord import DiscordOAuthError, discord_oauth
from app.services.oauth.google import GoogleOAuthError, google_oauth
from app.services.token_store import token_store
from app.services.user_service import AccountLinkingError, user_service
router = APIRouter(prefix="/auth", tags=["auth"])
# OAuth state TTL (5 minutes)
OAUTH_STATE_TTL = 300
async def _store_oauth_state(state: str, provider: str, redirect_uri: str) -> None:
"""Store OAuth state in Redis for CSRF validation."""
async with get_redis() as redis:
key = f"oauth_state:{state}"
value = f"{provider}:{redirect_uri}"
await redis.setex(key, OAUTH_STATE_TTL, value)
async def _validate_oauth_state(state: str, provider: str) -> str | None:
"""Validate and consume OAuth state, returning redirect_uri if valid."""
async with get_redis() as redis:
key = f"oauth_state:{state}"
value = await redis.get(key)
if not value:
return None
# Delete state (one-time use)
await redis.delete(key)
# Parse and validate
stored_provider, redirect_uri = value.split(":", 1)
if stored_provider != provider:
return None
return redirect_uri
async def _create_tokens_for_user(user_id) -> TokenResponse:
"""Create access and refresh tokens for a user."""
from uuid import UUID
if isinstance(user_id, str):
user_id = UUID(user_id)
access_token = create_access_token(user_id)
refresh_token, jti = create_refresh_token(user_id)
# Store refresh token in Redis for revocation tracking
expires_at = get_refresh_token_expiration()
await token_store.store_refresh_token(user_id, jti, expires_at)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=get_token_expiration_seconds(),
)
# =============================================================================
# Google OAuth
# =============================================================================
@router.get("/google")
async def google_auth_redirect(
redirect_uri: str = Query(..., description="URI to redirect to after OAuth completes"),
) -> RedirectResponse:
"""Start Google OAuth flow.
Redirects the user to Google's OAuth consent screen.
Args:
redirect_uri: Where to redirect after successful authentication.
This is YOUR app's callback, not the OAuth callback.
Returns:
Redirect to Google OAuth authorization URL.
Raises:
HTTPException: 501 if Google OAuth is not configured.
"""
if not google_oauth.is_configured():
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Google OAuth is not configured",
)
# Generate state for CSRF protection
state = secrets.token_urlsafe(32)
await _store_oauth_state(state, "google", redirect_uri)
# Build OAuth callback URL (our server endpoint)
# The redirect_uri param here is where Google sends the code
# Must be an absolute URL for OAuth providers
oauth_callback = f"{settings.base_url}/api/auth/google/callback"
# Get authorization URL
auth_url = google_oauth.get_authorization_url(
redirect_uri=oauth_callback,
state=state,
)
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
@router.get("/google/callback")
async def google_auth_callback(
db: DbSession,
code: str = Query(..., description="Authorization code from Google"),
state: str = Query(..., description="State parameter for CSRF validation"),
) -> TokenResponse:
"""Handle Google OAuth callback.
Exchanges the authorization code for tokens and creates/fetches the user.
Args:
code: Authorization code from Google.
state: State parameter for CSRF validation.
Returns:
JWT access and refresh tokens.
Raises:
HTTPException: 400 if state is invalid or OAuth fails.
"""
# Validate state
redirect_uri = await _validate_oauth_state(state, "google")
if redirect_uri is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state parameter",
)
try:
# Exchange code for user info
oauth_callback = f"{settings.base_url}/api/auth/google/callback"
user_info = await google_oauth.get_user_info(code, oauth_callback)
# Get or create user
user, created = await user_service.get_or_create_from_oauth(db, user_info)
# Update last login
await user_service.update_last_login(db, user)
# Create tokens
return await _create_tokens_for_user(user.id)
except GoogleOAuthError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Google OAuth failed: {e}",
) from None
# =============================================================================
# Discord OAuth
# =============================================================================
@router.get("/discord")
async def discord_auth_redirect(
redirect_uri: str = Query(..., description="URI to redirect to after OAuth completes"),
) -> RedirectResponse:
"""Start Discord OAuth flow.
Redirects the user to Discord's OAuth consent screen.
Args:
redirect_uri: Where to redirect after successful authentication.
Returns:
Redirect to Discord OAuth authorization URL.
Raises:
HTTPException: 501 if Discord OAuth is not configured.
"""
if not discord_oauth.is_configured():
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Discord OAuth is not configured",
)
# Generate state for CSRF protection
state = secrets.token_urlsafe(32)
await _store_oauth_state(state, "discord", redirect_uri)
# Build OAuth callback URL (must be absolute for OAuth providers)
oauth_callback = f"{settings.base_url}/api/auth/discord/callback"
# Get authorization URL
auth_url = discord_oauth.get_authorization_url(
redirect_uri=oauth_callback,
state=state,
)
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
@router.get("/discord/callback")
async def discord_auth_callback(
db: DbSession,
code: str = Query(..., description="Authorization code from Discord"),
state: str = Query(..., description="State parameter for CSRF validation"),
) -> TokenResponse:
"""Handle Discord OAuth callback.
Exchanges the authorization code for tokens and creates/fetches the user.
Args:
code: Authorization code from Discord.
state: State parameter for CSRF validation.
Returns:
JWT access and refresh tokens.
Raises:
HTTPException: 400 if state is invalid or OAuth fails.
"""
# Validate state
redirect_uri = await _validate_oauth_state(state, "discord")
if redirect_uri is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state parameter",
)
try:
# Exchange code for user info
oauth_callback = f"{settings.base_url}/api/auth/discord/callback"
user_info = await discord_oauth.get_user_info(code, oauth_callback)
# Get or create user
user, created = await user_service.get_or_create_from_oauth(db, user_info)
# Update last login
await user_service.update_last_login(db, user)
# Create tokens
return await _create_tokens_for_user(user.id)
except DiscordOAuthError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Discord OAuth failed: {e}",
) from None
# =============================================================================
# Token Management
# =============================================================================
@router.post("/refresh", response_model=TokenResponse)
async def refresh_tokens(
db: DbSession,
request: RefreshTokenRequest,
) -> TokenResponse:
"""Refresh access token using refresh token.
Validates the refresh token and issues a new access token.
The refresh token itself is NOT rotated (same token can be used
until it expires or is revoked).
Args:
request: Contains the refresh token.
Returns:
New JWT access token (same refresh token).
Raises:
HTTPException: 401 if refresh token is invalid or revoked.
"""
# Verify refresh token
result = verify_refresh_token(request.refresh_token)
if result is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
user_id, jti = result
# Check if token is revoked
is_valid = await token_store.is_token_valid(user_id, jti)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has been revoked",
)
# Verify user still exists
user = await user_service.get_by_id(db, user_id)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
)
# Create new access token (keep same refresh token)
access_token = create_access_token(user_id)
return TokenResponse(
access_token=access_token,
refresh_token=request.refresh_token, # Return same refresh token
expires_in=get_token_expiration_seconds(),
)
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
request: RefreshTokenRequest,
) -> None:
"""Logout by revoking the refresh token.
The access token will continue to work until it expires,
but the refresh token cannot be used to get new tokens.
Args:
request: Contains the refresh token to revoke.
"""
# Verify refresh token to get user_id and jti
result = verify_refresh_token(request.refresh_token)
if result is None:
# Token is invalid, but that's fine - user is effectively logged out
return
user_id, jti = result
# Revoke the token
await token_store.revoke_token(user_id, jti)
@router.post("/logout-all", status_code=status.HTTP_204_NO_CONTENT)
async def logout_all(
user: CurrentUser,
) -> None:
"""Logout from all devices by revoking all refresh tokens.
Requires authentication (uses current access token).
All refresh tokens for the user will be revoked.
Args:
user: Current authenticated user.
"""
from uuid import UUID
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
await token_store.revoke_all_user_tokens(user_id)
# =============================================================================
# Account Linking
# =============================================================================
async def _store_link_state(state: str, provider: str, user_id: str, redirect_uri: str) -> None:
"""Store OAuth state for account linking (includes user_id)."""
async with get_redis() as redis:
key = f"oauth_link_state:{state}"
value = f"{provider}:{user_id}:{redirect_uri}"
await redis.setex(key, OAUTH_STATE_TTL, value)
async def _validate_link_state(state: str, provider: str) -> tuple[str, str] | None:
"""Validate and consume link state, returning (user_id, redirect_uri) if valid."""
async with get_redis() as redis:
key = f"oauth_link_state:{state}"
value = await redis.get(key)
if not value:
return None
# Delete state (one-time use)
await redis.delete(key)
# Parse and validate
parts = value.split(":", 2)
if len(parts) != 3:
return None
stored_provider, user_id, redirect_uri = parts
if stored_provider != provider:
return None
return user_id, redirect_uri
@router.get("/link/google")
async def google_link_redirect(
user: CurrentUser,
redirect_uri: str = Query(..., description="URI to redirect to after linking"),
) -> RedirectResponse:
"""Start Google OAuth flow for account linking.
Requires authentication. Links Google account to the current user.
Args:
redirect_uri: Where to redirect after linking completes.
Returns:
Redirect to Google OAuth authorization URL.
Raises:
HTTPException: 501 if Google OAuth is not configured.
HTTPException: 400 if Google is already the primary provider.
"""
if not google_oauth.is_configured():
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Google OAuth is not configured",
)
# Check if Google is already their primary provider
if user.oauth_provider == "google":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Google is already your primary login provider",
)
# Generate state for CSRF protection
state = secrets.token_urlsafe(32)
await _store_link_state(state, "google", str(user.id), redirect_uri)
# Build OAuth callback URL for linking
oauth_callback = f"{settings.base_url}/api/auth/link/google/callback"
# Get authorization URL
auth_url = google_oauth.get_authorization_url(
redirect_uri=oauth_callback,
state=state,
)
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
@router.get("/link/google/callback")
async def google_link_callback(
db: DbSession,
code: str = Query(..., description="Authorization code from Google"),
state: str = Query(..., description="State parameter for CSRF validation"),
) -> RedirectResponse:
"""Handle Google OAuth callback for account linking.
Exchanges the authorization code and links the Google account to the user.
Args:
code: Authorization code from Google.
state: State parameter for CSRF validation.
Returns:
Redirect to the original redirect_uri with success/error query params.
"""
# Validate state
result = await _validate_link_state(state, "google")
if result is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state parameter",
)
user_id_str, redirect_uri = result
try:
# Exchange code for user info
oauth_callback = f"{settings.base_url}/api/auth/link/google/callback"
oauth_info = await google_oauth.get_user_info(code, oauth_callback)
# Get the user
from uuid import UUID
user_id = UUID(user_id_str)
user = await user_service.get_by_id(db, user_id)
if user is None:
return RedirectResponse(
url=f"{redirect_uri}?error=user_not_found",
status_code=status.HTTP_302_FOUND,
)
# Link the account
await user_service.link_oauth_account(db, user, oauth_info)
return RedirectResponse(
url=f"{redirect_uri}?linked=google",
status_code=status.HTTP_302_FOUND,
)
except GoogleOAuthError as e:
return RedirectResponse(
url=f"{redirect_uri}?error=oauth_failed&message={e}",
status_code=status.HTTP_302_FOUND,
)
except AccountLinkingError as e:
return RedirectResponse(
url=f"{redirect_uri}?error=linking_failed&message={e}",
status_code=status.HTTP_302_FOUND,
)
@router.get("/link/discord")
async def discord_link_redirect(
user: CurrentUser,
redirect_uri: str = Query(..., description="URI to redirect to after linking"),
) -> RedirectResponse:
"""Start Discord OAuth flow for account linking.
Requires authentication. Links Discord account to the current user.
Args:
redirect_uri: Where to redirect after linking completes.
Returns:
Redirect to Discord OAuth authorization URL.
Raises:
HTTPException: 501 if Discord OAuth is not configured.
HTTPException: 400 if Discord is already the primary provider.
"""
if not discord_oauth.is_configured():
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Discord OAuth is not configured",
)
# Check if Discord is already their primary provider
if user.oauth_provider == "discord":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Discord is already your primary login provider",
)
# Generate state for CSRF protection
state = secrets.token_urlsafe(32)
await _store_link_state(state, "discord", str(user.id), redirect_uri)
# Build OAuth callback URL for linking
oauth_callback = f"{settings.base_url}/api/auth/link/discord/callback"
# Get authorization URL
auth_url = discord_oauth.get_authorization_url(
redirect_uri=oauth_callback,
state=state,
)
return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND)
@router.get("/link/discord/callback")
async def discord_link_callback(
db: DbSession,
code: str = Query(..., description="Authorization code from Discord"),
state: str = Query(..., description="State parameter for CSRF validation"),
) -> RedirectResponse:
"""Handle Discord OAuth callback for account linking.
Exchanges the authorization code and links the Discord account to the user.
Args:
code: Authorization code from Discord.
state: State parameter for CSRF validation.
Returns:
Redirect to the original redirect_uri with success/error query params.
"""
# Validate state
result = await _validate_link_state(state, "discord")
if result is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state parameter",
)
user_id_str, redirect_uri = result
try:
# Exchange code for user info
oauth_callback = f"{settings.base_url}/api/auth/link/discord/callback"
oauth_info = await discord_oauth.get_user_info(code, oauth_callback)
# Get the user
from uuid import UUID
user_id = UUID(user_id_str)
user = await user_service.get_by_id(db, user_id)
if user is None:
return RedirectResponse(
url=f"{redirect_uri}?error=user_not_found",
status_code=status.HTTP_302_FOUND,
)
# Link the account
await user_service.link_oauth_account(db, user, oauth_info)
return RedirectResponse(
url=f"{redirect_uri}?linked=discord",
status_code=status.HTTP_302_FOUND,
)
except DiscordOAuthError as e:
return RedirectResponse(
url=f"{redirect_uri}?error=oauth_failed&message={e}",
status_code=status.HTTP_302_FOUND,
)
except AccountLinkingError as e:
return RedirectResponse(
url=f"{redirect_uri}?error=linking_failed&message={e}",
status_code=status.HTTP_302_FOUND,
)

172
backend/app/api/deps.py Normal file
View File

@ -0,0 +1,172 @@
"""FastAPI dependencies for Mantimon TCG API.
This module provides dependency injection functions for authentication
and database access in API endpoints.
Usage:
from app.api.deps import get_current_user, get_db
@router.get("/me")
async def get_me(
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return user
Dependencies:
- get_db: Async database session
- get_current_user: Authenticated user from JWT (required)
- get_optional_user: Authenticated user or None
- get_current_premium_user: User with active premium
"""
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_session_dependency
from app.db.models import User
from app.services.jwt_service import verify_access_token
from app.services.user_service import user_service
# OAuth2 scheme for extracting Bearer token from Authorization header
# tokenUrl is for OpenAPI docs - points to where tokens are obtained
oauth2_scheme = OAuth2PasswordBearer(
tokenUrl="/api/auth/token", # For OpenAPI docs
auto_error=True, # Raise 401 if no token
)
oauth2_scheme_optional = OAuth2PasswordBearer(
tokenUrl="/api/auth/token",
auto_error=False, # Return None if no token
)
async def get_db() -> AsyncSession:
"""Get async database session.
Yields:
AsyncSession for database operations.
Example:
@router.get("/items")
async def get_items(db: AsyncSession = Depends(get_db)):
...
"""
async with get_session_dependency() as session:
yield session
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> User:
"""Get the current authenticated user from JWT token.
Validates the access token and fetches the user from database.
Args:
token: JWT access token from Authorization header.
db: Database session.
Returns:
The authenticated User.
Raises:
HTTPException: 401 if token is invalid or user not found.
Example:
@router.get("/me")
async def get_me(user: User = Depends(get_current_user)):
return {"email": user.email}
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
# Verify token and extract user ID
user_id = verify_access_token(token)
if user_id is None:
raise credentials_exception
# Fetch user from database
user = await user_service.get_by_id(db, user_id)
if user is None:
raise credentials_exception
return user
async def get_optional_user(
token: Annotated[str | None, Depends(oauth2_scheme_optional)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> User | None:
"""Get the current user if authenticated, or None.
Useful for endpoints that work both with and without authentication,
but may provide additional features for authenticated users.
Args:
token: JWT access token or None.
db: Database session.
Returns:
The authenticated User, or None if not authenticated.
Example:
@router.get("/cards")
async def get_cards(user: User | None = Depends(get_optional_user)):
if user:
# Show user's collection
else:
# Show public cards
"""
if token is None:
return None
user_id = verify_access_token(token)
if user_id is None:
return None
return await user_service.get_by_id(db, user_id)
async def get_current_premium_user(
user: Annotated[User, Depends(get_current_user)],
) -> User:
"""Get the current user and verify they have premium.
Args:
user: The authenticated user.
Returns:
The authenticated User with active premium.
Raises:
HTTPException: 403 if user doesn't have active premium.
Example:
@router.post("/decks")
async def create_unlimited_decks(
user: User = Depends(get_current_premium_user)
):
# Only premium users can have unlimited decks
...
"""
if not user.has_active_premium:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Premium subscription required",
)
return user
# Type aliases for cleaner endpoint signatures
CurrentUser = Annotated[User, Depends(get_current_user)]
OptionalUser = Annotated[User | None, Depends(get_optional_user)]
PremiumUser = Annotated[User, Depends(get_current_premium_user)]
DbSession = Annotated[AsyncSession, Depends(get_db)]

168
backend/app/api/users.py Normal file
View File

@ -0,0 +1,168 @@
"""User API router for Mantimon TCG.
This module provides endpoints for user profile management:
- Get current user profile
- Update profile (display name, avatar)
- List linked OAuth accounts
- Session management
All endpoints require authentication.
Example:
# Get current user
GET /api/users/me
Authorization: Bearer <access_token>
# Update profile
PATCH /api/users/me
{"display_name": "NewName"}
"""
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from app.api.deps import CurrentUser, DbSession
from app.schemas.user import UserResponse, UserUpdate
from app.services.token_store import token_store
from app.services.user_service import AccountLinkingError, user_service
router = APIRouter(prefix="/users", tags=["users"])
class LinkedAccountResponse(BaseModel):
"""Response for a linked OAuth account."""
provider: str = Field(..., description="OAuth provider name")
email: str | None = Field(None, description="Email from this provider")
linked_at: str = Field(..., description="When account was linked (ISO format)")
class SessionsResponse(BaseModel):
"""Response for active sessions count."""
active_sessions: int = Field(..., description="Number of active sessions")
@router.get("/me", response_model=UserResponse)
async def get_current_user_profile(
user: CurrentUser,
) -> UserResponse:
"""Get the current user's profile.
Returns:
User profile information.
"""
return UserResponse.model_validate(user)
@router.patch("/me", response_model=UserResponse)
async def update_current_user_profile(
user: CurrentUser,
db: DbSession,
update_data: UserUpdate,
) -> UserResponse:
"""Update the current user's profile.
Only provided fields are updated.
Args:
update_data: Fields to update (display_name, avatar_url).
Returns:
Updated user profile.
"""
updated_user = await user_service.update(db, user, update_data)
return UserResponse.model_validate(updated_user)
@router.get("/me/linked-accounts", response_model=list[LinkedAccountResponse])
async def get_linked_accounts(
user: CurrentUser,
) -> list[LinkedAccountResponse]:
"""Get all OAuth accounts linked to the current user.
Returns the primary OAuth provider plus any additional linked accounts.
Returns:
List of linked OAuth accounts.
"""
accounts = []
# Add primary OAuth account
accounts.append(
LinkedAccountResponse(
provider=user.oauth_provider,
email=user.email,
linked_at=user.created_at.isoformat(),
)
)
# Add additional linked accounts
for linked in user.linked_accounts:
accounts.append(
LinkedAccountResponse(
provider=linked.provider,
email=linked.email,
linked_at=linked.linked_at.isoformat(),
)
)
return accounts
@router.get("/me/sessions", response_model=SessionsResponse)
async def get_active_sessions(
user: CurrentUser,
) -> SessionsResponse:
"""Get the number of active sessions for the current user.
Each session corresponds to a valid refresh token.
Returns:
Number of active sessions.
"""
from uuid import UUID
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
count = await token_store.get_active_session_count(user_id)
return SessionsResponse(active_sessions=count)
@router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT)
async def unlink_oauth_account(
user: CurrentUser,
db: DbSession,
provider: str,
) -> None:
"""Unlink an OAuth provider from the current user's account.
Cannot unlink the primary OAuth provider (the one used to create the account).
Args:
provider: OAuth provider name to unlink ('google' or 'discord').
Raises:
HTTPException: 400 if trying to unlink primary provider.
HTTPException: 404 if provider is not linked.
"""
provider = provider.lower()
if provider not in ("google", "discord"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown provider: {provider}",
)
try:
unlinked = await user_service.unlink_oauth_account(db, user, provider)
if not unlinked:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"{provider.title()} is not linked to your account",
)
except AccountLinkingError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from None

View File

@ -154,6 +154,12 @@ class Settings(BaseSettings):
description="Allowed CORS origins",
)
# Base URL (for OAuth callbacks and external links)
base_url: str = Field(
default="http://localhost:8000",
description="Base URL of the API server (for OAuth callbacks)",
)
# Game Settings
turn_timeout_seconds: int = Field(
default=120,

View File

@ -0,0 +1,69 @@
"""add_oauth_linked_accounts
Revision ID: 5ce887128ab1
Revises: ab8a0039fe55
Create Date: 2026-01-27 16:42:12.335987
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "5ce887128ab1"
down_revision: str | Sequence[str] | None = "ab8a0039fe55"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"oauth_linked_accounts",
sa.Column("user_id", sa.UUID(as_uuid=False), nullable=False),
sa.Column("provider", sa.String(length=20), nullable=False),
sa.Column("oauth_id", sa.String(length=255), nullable=False),
sa.Column("email", sa.String(length=255), nullable=True),
sa.Column("display_name", sa.String(length=100), nullable=True),
sa.Column("avatar_url", sa.String(length=500), nullable=True),
sa.Column(
"linked_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False
),
sa.Column("id", sa.UUID(as_uuid=False), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_linked_accounts_provider_oauth_id",
"oauth_linked_accounts",
["provider", "oauth_id"],
unique=True,
)
op.create_index(
op.f("ix_oauth_linked_accounts_user_id"), "oauth_linked_accounts", ["user_id"], unique=False
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_oauth_linked_accounts_user_id"), table_name="oauth_linked_accounts")
op.drop_index("ix_oauth_linked_accounts_provider_oauth_id", table_name="oauth_linked_accounts")
op.drop_table("oauth_linked_accounts")
# ### end Alembic commands ###

View File

@ -25,11 +25,14 @@ from app.db.models.campaign import CampaignProgress
from app.db.models.collection import CardSource, Collection
from app.db.models.deck import Deck
from app.db.models.game import ActiveGame, EndReason, GameHistory, GameType
from app.db.models.oauth_account import OAuthLinkedAccount
from app.db.models.user import User
__all__ = [
# User
"User",
# OAuth
"OAuthLinkedAccount",
# Collection
"Collection",
"CardSource",

View File

@ -0,0 +1,122 @@
"""OAuth linked account model for Mantimon TCG.
This module defines the OAuthLinkedAccount model for supporting multiple
OAuth providers per user (account linking).
A user can have multiple linked accounts (e.g., both Google and Discord),
allowing them to log in with either provider.
Example:
# User links Discord to their existing Google account
linked_account = OAuthLinkedAccount(
user_id=user.id,
provider="discord",
oauth_id="123456789",
email="player@example.com"
)
"""
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, ForeignKey, Index, String, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
if TYPE_CHECKING:
from app.db.models.user import User
class OAuthLinkedAccount(Base):
"""Linked OAuth account for multi-provider authentication.
Allows users to link multiple OAuth providers to a single account,
enabling login via any linked provider.
The User model still has oauth_provider/oauth_id for the "primary"
provider (the one used to create the account). This table tracks
additional linked providers.
Attributes:
id: Unique identifier (UUID).
user_id: Foreign key to the user who owns this linked account.
provider: OAuth provider name ('google', 'discord').
oauth_id: Unique ID from the OAuth provider.
email: Email address from this OAuth provider (may differ from user's primary email).
display_name: Display name from this OAuth provider.
avatar_url: Avatar URL from this OAuth provider.
linked_at: When this account was linked.
Relationships:
user: The User who owns this linked account.
"""
__tablename__ = "oauth_linked_accounts"
# Foreign key to user
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
doc="User who owns this linked account",
)
# OAuth provider info
provider: Mapped[str] = mapped_column(
String(20),
nullable=False,
doc="OAuth provider name (google, discord)",
)
oauth_id: Mapped[str] = mapped_column(
String(255),
nullable=False,
doc="Unique ID from OAuth provider",
)
# Additional info from provider
email: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
doc="Email from this OAuth provider",
)
display_name: Mapped[str | None] = mapped_column(
String(100),
nullable=True,
doc="Display name from this OAuth provider",
)
avatar_url: Mapped[str | None] = mapped_column(
String(500),
nullable=True,
doc="Avatar URL from this OAuth provider",
)
# When linked
linked_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
doc="When this account was linked",
)
# Relationship back to user
user: Mapped["User"] = relationship(
"User",
back_populates="linked_accounts",
)
# Indexes and constraints
__table_args__ = (
# Each OAuth provider+ID can only be linked to one user
Index(
"ix_oauth_linked_accounts_provider_oauth_id",
"provider",
"oauth_id",
unique=True,
),
)
def __repr__(self) -> str:
return f"<OAuthLinkedAccount(user_id={self.user_id!r}, provider={self.provider!r})>"

View File

@ -24,6 +24,7 @@ if TYPE_CHECKING:
from app.db.models.campaign import CampaignProgress
from app.db.models.collection import Collection
from app.db.models.deck import Deck
from app.db.models.oauth_account import OAuthLinkedAccount
class User(Base):
@ -45,6 +46,7 @@ class User(Base):
decks: User's deck configurations.
collection: User's card collection.
campaign_progress: User's campaign state.
linked_accounts: Additional linked OAuth providers.
"""
__tablename__ = "users"
@ -120,6 +122,12 @@ class User(Base):
uselist=False,
lazy="selectin",
)
linked_accounts: Mapped[list["OAuthLinkedAccount"]] = relationship(
"OAuthLinkedAccount",
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
)
# Indexes
__table_args__ = (Index("ix_users_oauth", "oauth_provider", "oauth_id", unique=True),)

View File

@ -18,6 +18,8 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.auth import router as auth_router
from app.api.users import router as users_router
from app.config import settings
from app.db import close_db, init_db
from app.db.redis import close_redis, init_redis
@ -159,9 +161,11 @@ async def readiness_check() -> dict[str, str | int]:
# === API Routers ===
# TODO: Add routers in Phase 2
# from app.api import auth, games, cards, decks, campaign
# app.include_router(auth.router, prefix="/api/auth", tags=["auth"])
app.include_router(auth_router, prefix="/api")
app.include_router(users_router, prefix="/api")
# TODO: Add remaining routers in future phases
# from app.api import cards, decks, games, campaign
# app.include_router(cards.router, prefix="/api/cards", tags=["cards"])
# app.include_router(decks.router, prefix="/api/decks", tags=["decks"])
# app.include_router(games.router, prefix="/api/games", tags=["games"])

View File

@ -0,0 +1,32 @@
"""Pydantic schemas for Mantimon TCG API.
This package contains request/response models for all API endpoints.
"""
from app.schemas.auth import (
OAuthState,
RefreshTokenRequest,
TokenPayload,
TokenResponse,
TokenType,
)
from app.schemas.user import (
OAuthUserInfo,
UserCreate,
UserResponse,
UserUpdate,
)
__all__ = [
# Auth schemas
"TokenType",
"TokenPayload",
"TokenResponse",
"RefreshTokenRequest",
"OAuthState",
# User schemas
"UserResponse",
"UserCreate",
"UserUpdate",
"OAuthUserInfo",
]

View File

@ -0,0 +1,87 @@
"""Authentication schemas for Mantimon TCG.
This module defines Pydantic models for JWT tokens and authentication
responses used throughout the auth system.
Example:
token_payload = TokenPayload(
sub="550e8400-e29b-41d4-a716-446655440000",
exp=datetime.now(UTC) + timedelta(minutes=30),
iat=datetime.now(UTC),
type="access"
)
"""
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
class TokenType(str, Enum):
"""Type of JWT token."""
ACCESS = "access"
REFRESH = "refresh"
class TokenPayload(BaseModel):
"""JWT token payload structure.
Attributes:
sub: Subject - the user ID as a string (UUID format).
exp: Expiration timestamp.
iat: Issued-at timestamp.
type: Token type (access or refresh).
jti: JWT ID - unique identifier for refresh token tracking.
"""
sub: str = Field(..., description="User ID (UUID as string)")
exp: datetime = Field(..., description="Token expiration timestamp")
iat: datetime = Field(..., description="Token issued-at timestamp")
type: TokenType = Field(..., description="Token type (access/refresh)")
jti: str | None = Field(default=None, description="JWT ID for refresh token tracking")
class TokenResponse(BaseModel):
"""Response containing JWT tokens after successful authentication.
Attributes:
access_token: Short-lived JWT for API authentication.
refresh_token: Longer-lived JWT for obtaining new access tokens.
token_type: Always "bearer" for OAuth 2.0 compatibility.
expires_in: Access token lifetime in seconds.
"""
access_token: str = Field(..., description="JWT access token")
refresh_token: str = Field(..., description="JWT refresh token")
token_type: str = Field(default="bearer", description="Token type (always bearer)")
expires_in: int = Field(..., description="Access token lifetime in seconds")
class RefreshTokenRequest(BaseModel):
"""Request to refresh an access token.
Attributes:
refresh_token: The refresh token to exchange for a new access token.
"""
refresh_token: str = Field(..., description="Refresh token to exchange")
class OAuthState(BaseModel):
"""OAuth state parameter for CSRF protection.
Stored in Redis with short TTL during OAuth flow.
Attributes:
state: Random string for CSRF protection.
redirect_uri: Where to redirect after OAuth callback.
provider: OAuth provider name.
created_at: When the state was created.
"""
state: str = Field(..., description="Random CSRF protection string")
redirect_uri: str = Field(..., description="Post-auth redirect URI")
provider: str = Field(..., description="OAuth provider (google, discord)")
created_at: datetime = Field(..., description="State creation timestamp")

114
backend/app/schemas/user.py Normal file
View File

@ -0,0 +1,114 @@
"""User schemas for Mantimon TCG.
This module defines Pydantic models for user-related API requests
and responses.
Example:
user_response = UserResponse(
id="550e8400-e29b-41d4-a716-446655440000",
email="player@example.com",
display_name="Player1",
is_premium=False,
created_at=datetime.now(UTC)
)
"""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, EmailStr, Field
class UserResponse(BaseModel):
"""Public user information returned by API endpoints.
Attributes:
id: User's unique identifier.
email: User's email address.
display_name: User's public display name.
avatar_url: URL to user's avatar image.
is_premium: Whether user has active premium subscription.
premium_until: When premium subscription expires (if premium).
created_at: When the account was created.
"""
id: UUID = Field(..., description="User ID")
email: EmailStr = Field(..., description="User's email address")
display_name: str = Field(..., description="Public display name")
avatar_url: str | None = Field(default=None, description="Avatar image URL")
is_premium: bool = Field(default=False, description="Premium subscription status")
premium_until: datetime | None = Field(default=None, description="Premium expiration date")
created_at: datetime = Field(..., description="Account creation date")
model_config = {"from_attributes": True}
class UserCreate(BaseModel):
"""Internal schema for creating a user from OAuth data.
Not exposed via API - used internally by auth service.
Attributes:
email: User's email from OAuth provider.
display_name: User's name from OAuth provider.
avatar_url: Avatar URL from OAuth provider.
oauth_provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
"""
email: EmailStr = Field(..., description="Email from OAuth provider")
display_name: str = Field(..., max_length=50, description="Display name")
avatar_url: str | None = Field(default=None, description="Avatar URL")
oauth_provider: str = Field(..., description="OAuth provider (google, discord)")
oauth_id: str = Field(..., description="Unique ID from OAuth provider")
class UserUpdate(BaseModel):
"""Schema for updating user profile.
All fields are optional - only provided fields are updated.
Attributes:
display_name: New display name.
avatar_url: New avatar URL.
"""
display_name: str | None = Field(
default=None, min_length=1, max_length=50, description="New display name"
)
avatar_url: str | None = Field(default=None, description="New avatar URL")
class OAuthUserInfo(BaseModel):
"""Normalized user information from OAuth providers.
This provides a consistent structure regardless of whether
the user authenticated via Google or Discord.
Attributes:
provider: OAuth provider name.
oauth_id: Unique ID from the provider.
email: User's email address.
name: User's display name from provider.
avatar_url: Avatar URL from provider.
"""
provider: str = Field(..., description="OAuth provider (google, discord)")
oauth_id: str = Field(..., description="Unique ID from provider")
email: EmailStr = Field(..., description="User's email")
name: str = Field(..., description="User's name from provider")
avatar_url: str | None = Field(default=None, description="Avatar URL from provider")
def to_user_create(self) -> UserCreate:
"""Convert OAuth info to UserCreate schema.
Returns:
UserCreate instance ready for user creation.
"""
return UserCreate(
email=self.email,
display_name=self.name[:50], # Enforce max length
avatar_url=self.avatar_url,
oauth_provider=self.provider,
oauth_id=self.oauth_id,
)

View File

@ -0,0 +1,206 @@
"""JWT token service for Mantimon TCG.
This module provides functions for creating and verifying JWT tokens
used in the authentication system.
Token Types:
- Access tokens: Short-lived (30 min default), used for API authentication
- Refresh tokens: Longer-lived (7 days default), used to obtain new access tokens
Example:
from app.services.jwt_service import create_access_token, verify_token
# Create tokens
access_token = create_access_token(user_id)
refresh_token, jti = create_refresh_token(user_id)
# Verify token
user_id = verify_token(access_token)
if user_id:
print(f"Valid token for user {user_id}")
"""
import uuid
from datetime import UTC, datetime, timedelta
from jose import JWTError, jwt
from app.config import settings
from app.schemas.auth import TokenPayload, TokenType
def create_access_token(user_id: uuid.UUID) -> str:
"""Create a short-lived JWT access token.
Args:
user_id: The user's UUID to encode in the token.
Returns:
Encoded JWT access token string.
Example:
token = create_access_token(user.id)
# Use token in Authorization header: Bearer {token}
"""
now = datetime.now(UTC)
expires = now + timedelta(minutes=settings.jwt_expire_minutes)
payload = {
"sub": str(user_id),
"exp": expires,
"iat": now,
"type": TokenType.ACCESS.value,
}
return jwt.encode(
payload,
settings.secret_key.get_secret_value(),
algorithm=settings.jwt_algorithm,
)
def create_refresh_token(user_id: uuid.UUID) -> tuple[str, str]:
"""Create a longer-lived JWT refresh token.
Refresh tokens include a JTI (JWT ID) for tracking and revocation.
Args:
user_id: The user's UUID to encode in the token.
Returns:
Tuple of (encoded JWT refresh token, JTI string).
The JTI should be stored in Redis for revocation tracking.
Example:
token, jti = create_refresh_token(user.id)
await token_store.store_refresh_token(user.id, jti, expires_at)
"""
now = datetime.now(UTC)
expires = now + timedelta(days=settings.jwt_refresh_expire_days)
jti = str(uuid.uuid4())
payload = {
"sub": str(user_id),
"exp": expires,
"iat": now,
"type": TokenType.REFRESH.value,
"jti": jti,
}
token = jwt.encode(
payload,
settings.secret_key.get_secret_value(),
algorithm=settings.jwt_algorithm,
)
return token, jti
def decode_token(token: str) -> TokenPayload | None:
"""Decode and validate a JWT token.
Args:
token: The JWT token string to decode.
Returns:
TokenPayload if valid, None if invalid or expired.
Example:
payload = decode_token(token)
if payload:
user_id = UUID(payload.sub)
"""
try:
payload_dict = jwt.decode(
token,
settings.secret_key.get_secret_value(),
algorithms=[settings.jwt_algorithm],
)
return TokenPayload(**payload_dict)
except JWTError:
return None
def verify_access_token(token: str) -> uuid.UUID | None:
"""Verify an access token and extract the user ID.
Args:
token: The JWT access token to verify.
Returns:
User UUID if valid access token, None otherwise.
Example:
user_id = verify_access_token(token)
if user_id:
user = await user_service.get_user_by_id(db, user_id)
"""
payload = decode_token(token)
if payload is None:
return None
if payload.type != TokenType.ACCESS:
return None
try:
return uuid.UUID(payload.sub)
except ValueError:
return None
def verify_refresh_token(token: str) -> tuple[uuid.UUID, str] | None:
"""Verify a refresh token and extract user ID and JTI.
The JTI should be checked against the token store to ensure
the token hasn't been revoked.
Args:
token: The JWT refresh token to verify.
Returns:
Tuple of (user UUID, JTI) if valid refresh token, None otherwise.
Example:
result = verify_refresh_token(token)
if result:
user_id, jti = result
if await token_store.is_token_valid(user_id, jti):
# Issue new access token
"""
payload = decode_token(token)
if payload is None:
return None
if payload.type != TokenType.REFRESH:
return None
if payload.jti is None:
return None
try:
user_id = uuid.UUID(payload.sub)
return user_id, payload.jti
except ValueError:
return None
def get_token_expiration_seconds() -> int:
"""Get the access token expiration time in seconds.
Useful for the expires_in field in token responses.
Returns:
Access token lifetime in seconds.
"""
return settings.jwt_expire_minutes * 60
def get_refresh_token_expiration() -> datetime:
"""Get the expiration datetime for a new refresh token.
Useful for setting TTL when storing in Redis.
Returns:
Datetime when a refresh token created now would expire.
"""
return datetime.now(UTC) + timedelta(days=settings.jwt_refresh_expire_days)

View File

@ -0,0 +1,25 @@
"""OAuth provider services for Mantimon TCG.
This package contains OAuth integration services for supported providers.
Providers:
- Google: OAuth 2.0 with Google accounts
- Discord: OAuth 2.0 with Discord accounts
Example:
from app.services.oauth import google_oauth, discord_oauth
# Get authorization URL
auth_url = google_oauth.get_authorization_url(redirect_uri, state)
# Exchange code for user info
user_info = await google_oauth.get_user_info(code, redirect_uri)
"""
from app.services.oauth.discord import discord_oauth
from app.services.oauth.google import google_oauth
__all__ = [
"google_oauth",
"discord_oauth",
]

View File

@ -0,0 +1,242 @@
"""Discord OAuth service for Mantimon TCG.
This module handles Discord OAuth 2.0 authentication flow:
1. Generate authorization URL for user redirect
2. Exchange authorization code for tokens
3. Fetch user information from Discord
Discord OAuth Endpoints:
- Authorization: https://discord.com/api/oauth2/authorize
- Token: https://discord.com/api/oauth2/token
- User Info: https://discord.com/api/users/@me
Example:
from app.services.oauth.discord import discord_oauth
# Step 1: Redirect user to Discord
auth_url = discord_oauth.get_authorization_url(
redirect_uri="https://play.mantimon.com/api/auth/discord/callback",
state="random-csrf-token"
)
# Step 2: Handle callback and get user info
user_info = await discord_oauth.get_user_info(code, redirect_uri)
"""
from urllib.parse import urlencode
import httpx
from app.config import settings
from app.schemas.user import OAuthUserInfo
class DiscordOAuthError(Exception):
"""Exception raised for Discord OAuth errors."""
pass
class DiscordOAuth:
"""Discord OAuth 2.0 service.
Handles the OAuth flow for authenticating users with Discord accounts.
"""
AUTHORIZATION_URL = "https://discord.com/api/oauth2/authorize"
TOKEN_URL = "https://discord.com/api/oauth2/token"
USER_INFO_URL = "https://discord.com/api/users/@me"
CDN_URL = "https://cdn.discordapp.com"
# Scopes we request from Discord
SCOPES = ["identify", "email"]
def get_authorization_url(self, redirect_uri: str, state: str) -> str:
"""Generate the Discord OAuth authorization URL.
Args:
redirect_uri: Where Discord should redirect after authorization.
state: Random string for CSRF protection.
Returns:
Full authorization URL to redirect user to.
Raises:
DiscordOAuthError: If Discord OAuth is not configured.
Example:
url = discord_oauth.get_authorization_url(
redirect_uri="https://play.mantimon.com/api/auth/discord/callback",
state="abc123"
)
# Redirect user to url
"""
if not settings.discord_client_id:
raise DiscordOAuthError("Discord OAuth is not configured")
params = {
"client_id": settings.discord_client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"scope": " ".join(self.SCOPES),
"state": state,
"prompt": "consent", # Always show consent screen
}
return f"{self.AUTHORIZATION_URL}?{urlencode(params)}"
async def exchange_code_for_tokens(
self,
code: str,
redirect_uri: str,
) -> dict:
"""Exchange authorization code for access tokens.
Args:
code: Authorization code from Discord callback.
redirect_uri: Same redirect_uri used in authorization request.
Returns:
Token response containing access_token, refresh_token, etc.
Raises:
DiscordOAuthError: If token exchange fails.
"""
if not settings.discord_client_id or not settings.discord_client_secret:
raise DiscordOAuthError("Discord OAuth is not configured")
data = {
"client_id": settings.discord_client_id,
"client_secret": settings.discord_client_secret.get_secret_value(),
"code": code,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
}
async with httpx.AsyncClient() as client:
response = await client.post(
self.TOKEN_URL,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
error_data = response.json() if response.content else {}
error_msg = error_data.get("error_description", response.text)
raise DiscordOAuthError(f"Token exchange failed: {error_msg}")
return response.json()
async def fetch_user_info(self, access_token: str) -> dict:
"""Fetch user information from Discord.
Args:
access_token: Valid Discord access token.
Returns:
User info dict with id, username, email, avatar, etc.
Raises:
DiscordOAuthError: If fetching user info fails.
"""
async with httpx.AsyncClient() as client:
response = await client.get(
self.USER_INFO_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
if response.status_code != 200:
raise DiscordOAuthError(f"Failed to fetch user info: {response.text}")
return response.json()
def _build_avatar_url(self, user_id: str, avatar_hash: str | None) -> str | None:
"""Build Discord avatar URL from user ID and avatar hash.
Args:
user_id: Discord user ID.
avatar_hash: Avatar hash from Discord API (can be None).
Returns:
Full CDN URL for avatar, or None if no avatar.
Note:
Discord avatar format: https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.png
If avatar_hash starts with 'a_', it's animated (gif).
"""
if not avatar_hash:
return None
# Animated avatars start with 'a_'
extension = "gif" if avatar_hash.startswith("a_") else "png"
return f"{self.CDN_URL}/avatars/{user_id}/{avatar_hash}.{extension}"
async def get_user_info(
self,
code: str,
redirect_uri: str,
) -> OAuthUserInfo:
"""Complete OAuth flow: exchange code and fetch user info.
This is the main method to call after receiving the OAuth callback.
It exchanges the authorization code for tokens, then fetches user info.
Args:
code: Authorization code from Discord callback.
redirect_uri: Same redirect_uri used in authorization request.
Returns:
Normalized OAuthUserInfo ready for user creation/lookup.
Raises:
DiscordOAuthError: If any step of the OAuth flow fails.
Example:
# In your callback handler:
user_info = await discord_oauth.get_user_info(code, redirect_uri)
user, created = await user_service.get_or_create_from_oauth(db, user_info)
"""
# Exchange code for tokens
tokens = await self.exchange_code_for_tokens(code, redirect_uri)
access_token = tokens.get("access_token")
if not access_token:
raise DiscordOAuthError("No access token in response")
# Fetch user info
user_data = await self.fetch_user_info(access_token)
# Discord requires email scope, but email can still be None if not verified
email = user_data.get("email")
if not email:
raise DiscordOAuthError("Discord account does not have a verified email")
# Build display name: prefer global_name, then username
display_name = user_data.get("global_name") or user_data["username"]
# Build avatar URL
avatar_url = self._build_avatar_url(
user_data["id"],
user_data.get("avatar"),
)
# Normalize to our schema
return OAuthUserInfo(
provider="discord",
oauth_id=user_data["id"],
email=email,
name=display_name,
avatar_url=avatar_url,
)
def is_configured(self) -> bool:
"""Check if Discord OAuth is properly configured.
Returns:
True if client ID and secret are set.
"""
return bool(settings.discord_client_id and settings.discord_client_secret)
# Global instance
discord_oauth = DiscordOAuth()

View File

@ -0,0 +1,207 @@
"""Google OAuth service for Mantimon TCG.
This module handles Google OAuth 2.0 authentication flow:
1. Generate authorization URL for user redirect
2. Exchange authorization code for tokens
3. Fetch user information from Google
Google OAuth Endpoints:
- Authorization: https://accounts.google.com/o/oauth2/v2/auth
- Token: https://oauth2.googleapis.com/token
- User Info: https://www.googleapis.com/oauth2/v2/userinfo
Example:
from app.services.oauth.google import google_oauth
# Step 1: Redirect user to Google
auth_url = google_oauth.get_authorization_url(
redirect_uri="https://play.mantimon.com/api/auth/google/callback",
state="random-csrf-token"
)
# Step 2: Handle callback and get user info
user_info = await google_oauth.get_user_info(code, redirect_uri)
"""
from urllib.parse import urlencode
import httpx
from app.config import settings
from app.schemas.user import OAuthUserInfo
class GoogleOAuthError(Exception):
"""Exception raised for Google OAuth errors."""
pass
class GoogleOAuth:
"""Google OAuth 2.0 service.
Handles the OAuth flow for authenticating users with Google accounts.
"""
AUTHORIZATION_URL = "https://accounts.google.com/o/oauth2/v2/auth"
TOKEN_URL = "https://oauth2.googleapis.com/token"
USER_INFO_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
# Scopes we request from Google
SCOPES = ["openid", "email", "profile"]
def get_authorization_url(self, redirect_uri: str, state: str) -> str:
"""Generate the Google OAuth authorization URL.
Args:
redirect_uri: Where Google should redirect after authorization.
state: Random string for CSRF protection.
Returns:
Full authorization URL to redirect user to.
Raises:
GoogleOAuthError: If Google OAuth is not configured.
Example:
url = google_oauth.get_authorization_url(
redirect_uri="https://play.mantimon.com/api/auth/google/callback",
state="abc123"
)
# Redirect user to url
"""
if not settings.google_client_id:
raise GoogleOAuthError("Google OAuth is not configured")
params = {
"client_id": settings.google_client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"scope": " ".join(self.SCOPES),
"state": state,
"access_type": "offline", # Get refresh token
"prompt": "select_account", # Always show account picker
}
return f"{self.AUTHORIZATION_URL}?{urlencode(params)}"
async def exchange_code_for_tokens(
self,
code: str,
redirect_uri: str,
) -> dict:
"""Exchange authorization code for access tokens.
Args:
code: Authorization code from Google callback.
redirect_uri: Same redirect_uri used in authorization request.
Returns:
Token response containing access_token, id_token, etc.
Raises:
GoogleOAuthError: If token exchange fails.
"""
if not settings.google_client_id or not settings.google_client_secret:
raise GoogleOAuthError("Google OAuth is not configured")
data = {
"client_id": settings.google_client_id,
"client_secret": settings.google_client_secret.get_secret_value(),
"code": code,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
}
async with httpx.AsyncClient() as client:
response = await client.post(
self.TOKEN_URL,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
error_data = response.json() if response.content else {}
error_msg = error_data.get("error_description", response.text)
raise GoogleOAuthError(f"Token exchange failed: {error_msg}")
return response.json()
async def fetch_user_info(self, access_token: str) -> dict:
"""Fetch user information from Google.
Args:
access_token: Valid Google access token.
Returns:
User info dict with id, email, name, picture, etc.
Raises:
GoogleOAuthError: If fetching user info fails.
"""
async with httpx.AsyncClient() as client:
response = await client.get(
self.USER_INFO_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
if response.status_code != 200:
raise GoogleOAuthError(f"Failed to fetch user info: {response.text}")
return response.json()
async def get_user_info(
self,
code: str,
redirect_uri: str,
) -> OAuthUserInfo:
"""Complete OAuth flow: exchange code and fetch user info.
This is the main method to call after receiving the OAuth callback.
It exchanges the authorization code for tokens, then fetches user info.
Args:
code: Authorization code from Google callback.
redirect_uri: Same redirect_uri used in authorization request.
Returns:
Normalized OAuthUserInfo ready for user creation/lookup.
Raises:
GoogleOAuthError: If any step of the OAuth flow fails.
Example:
# In your callback handler:
user_info = await google_oauth.get_user_info(code, redirect_uri)
user, created = await user_service.get_or_create_from_oauth(db, user_info)
"""
# Exchange code for tokens
tokens = await self.exchange_code_for_tokens(code, redirect_uri)
access_token = tokens.get("access_token")
if not access_token:
raise GoogleOAuthError("No access token in response")
# Fetch user info
user_data = await self.fetch_user_info(access_token)
# Normalize to our schema
return OAuthUserInfo(
provider="google",
oauth_id=user_data["id"],
email=user_data["email"],
name=user_data.get("name", user_data["email"].split("@")[0]),
avatar_url=user_data.get("picture"),
)
def is_configured(self) -> bool:
"""Check if Google OAuth is properly configured.
Returns:
True if client ID and secret are set.
"""
return bool(settings.google_client_id and settings.google_client_secret)
# Global instance
google_oauth = GoogleOAuth()

View File

@ -0,0 +1,195 @@
"""Refresh token storage for Mantimon TCG.
This module provides Redis-based storage for refresh token tracking
and revocation. Each refresh token's JTI is stored in Redis with
a TTL matching the token's expiration.
Key Pattern:
refresh_token:{user_id}:{jti} -> "1" (exists = valid)
Revocation:
- Single token: Delete the specific key
- All user tokens: Delete all keys matching refresh_token:{user_id}:*
Example:
from app.services.token_store import token_store
# Store a new refresh token
await token_store.store_refresh_token(user_id, jti, expires_at)
# Check if token is valid (not revoked)
if await token_store.is_token_valid(user_id, jti):
# Issue new access token
# Revoke on logout
await token_store.revoke_token(user_id, jti)
# Logout from all devices
await token_store.revoke_all_user_tokens(user_id)
"""
from datetime import UTC, datetime
from uuid import UUID
from app.db.redis import get_redis
class TokenStore:
"""Redis-based refresh token storage for revocation support.
Tracks valid refresh tokens by storing their JTIs in Redis.
Tokens can be revoked individually or all at once per user.
"""
KEY_PREFIX = "refresh_token"
def _make_key(self, user_id: UUID, jti: str) -> str:
"""Create Redis key for a refresh token.
Args:
user_id: The user's UUID.
jti: The token's unique identifier.
Returns:
Redis key string.
"""
return f"{self.KEY_PREFIX}:{user_id}:{jti}"
def _make_user_pattern(self, user_id: UUID) -> str:
"""Create Redis key pattern for all user's tokens.
Args:
user_id: The user's UUID.
Returns:
Redis key pattern for SCAN/KEYS.
"""
return f"{self.KEY_PREFIX}:{user_id}:*"
async def store_refresh_token(
self,
user_id: UUID,
jti: str,
expires_at: datetime,
) -> None:
"""Store a refresh token's JTI in Redis.
Args:
user_id: The user's UUID.
jti: The token's unique identifier (from JWT).
expires_at: When the token expires (for TTL calculation).
Example:
expires_at = datetime.now(UTC) + timedelta(days=7)
await token_store.store_refresh_token(user_id, jti, expires_at)
"""
key = self._make_key(user_id, jti)
# Calculate TTL in seconds
now = datetime.now(UTC)
ttl_seconds = int((expires_at - now).total_seconds())
if ttl_seconds <= 0:
# Token already expired, don't store
return
async with get_redis() as redis:
await redis.setex(key, ttl_seconds, "1")
async def is_token_valid(self, user_id: UUID, jti: str) -> bool:
"""Check if a refresh token is valid (not revoked).
Args:
user_id: The user's UUID.
jti: The token's unique identifier.
Returns:
True if token exists in store (valid), False if revoked or expired.
Example:
if await token_store.is_token_valid(user_id, jti):
# Token is valid, issue new access token
else:
# Token was revoked, require re-authentication
"""
key = self._make_key(user_id, jti)
async with get_redis() as redis:
result = await redis.exists(key)
return result > 0
async def revoke_token(self, user_id: UUID, jti: str) -> bool:
"""Revoke a specific refresh token.
Args:
user_id: The user's UUID.
jti: The token's unique identifier.
Returns:
True if token was revoked, False if it didn't exist.
Example:
# On logout
await token_store.revoke_token(user_id, jti)
"""
key = self._make_key(user_id, jti)
async with get_redis() as redis:
result = await redis.delete(key)
return result > 0
async def revoke_all_user_tokens(self, user_id: UUID) -> int:
"""Revoke all refresh tokens for a user.
Useful for "logout from all devices" or security incidents.
Args:
user_id: The user's UUID.
Returns:
Number of tokens revoked.
Example:
# Logout from all devices
count = await token_store.revoke_all_user_tokens(user_id)
print(f"Revoked {count} sessions")
"""
pattern = self._make_user_pattern(user_id)
async with get_redis() as redis:
# Use SCAN to find all matching keys (safer than KEYS for large datasets)
keys_to_delete = []
async for key in redis.scan_iter(match=pattern):
keys_to_delete.append(key)
if not keys_to_delete:
return 0
# Delete all found keys
result = await redis.delete(*keys_to_delete)
return result
async def get_active_session_count(self, user_id: UUID) -> int:
"""Get the number of active sessions (valid refresh tokens) for a user.
Args:
user_id: The user's UUID.
Returns:
Number of active sessions.
Example:
count = await token_store.get_active_session_count(user_id)
print(f"User has {count} active sessions")
"""
pattern = self._make_user_pattern(user_id)
async with get_redis() as redis:
count = 0
async for _ in redis.scan_iter(match=pattern):
count += 1
return count
# Global token store instance
token_store = TokenStore()

View File

@ -0,0 +1,439 @@
"""User service for Mantimon TCG.
This module provides async CRUD operations for user accounts,
including OAuth-based user creation and premium status management.
All database operations use async SQLAlchemy sessions.
Example:
from app.services.user_service import user_service
# Get user by ID
user = await user_service.get_by_id(db, user_id)
# Create from OAuth
user = await user_service.create_from_oauth(db, oauth_info)
# Update premium status
user = await user_service.update_premium(db, user_id, premium_until)
"""
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.oauth_account import OAuthLinkedAccount
from app.db.models.user import User
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
class AccountLinkingError(Exception):
"""Error during account linking operation."""
pass
class UserService:
"""Service for user account operations.
Provides async methods for user CRUD, OAuth-based creation,
and premium subscription management.
"""
async def get_by_id(self, db: AsyncSession, user_id: UUID) -> User | None:
"""Get a user by their ID.
Args:
db: Async database session.
user_id: The user's UUID.
Returns:
User if found, None otherwise.
Example:
user = await user_service.get_by_id(db, user_id)
if user:
print(f"Found user: {user.display_name}")
"""
result = await db.execute(select(User).where(User.id == user_id))
return result.scalar_one_or_none()
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
"""Get a user by their email address.
Args:
db: Async database session.
email: The user's email address.
Returns:
User if found, None otherwise.
Example:
user = await user_service.get_by_email(db, "player@example.com")
"""
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
async def get_by_oauth(
self,
db: AsyncSession,
provider: str,
oauth_id: str,
) -> User | None:
"""Get a user by their OAuth provider and ID.
Args:
db: Async database session.
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
User if found, None otherwise.
Example:
user = await user_service.get_by_oauth(db, "google", "123456789")
"""
result = await db.execute(
select(User).where(
User.oauth_provider == provider,
User.oauth_id == oauth_id,
)
)
return result.scalar_one_or_none()
async def create(self, db: AsyncSession, user_data: UserCreate) -> User:
"""Create a new user.
Args:
db: Async database session.
user_data: User creation data.
Returns:
The created User instance.
Example:
user_data = UserCreate(
email="player@example.com",
display_name="Player1",
oauth_provider="google",
oauth_id="123456789"
)
user = await user_service.create(db, user_data)
"""
user = User(
email=user_data.email,
display_name=user_data.display_name,
avatar_url=user_data.avatar_url,
oauth_provider=user_data.oauth_provider,
oauth_id=user_data.oauth_id,
)
db.add(user)
await db.commit()
await db.refresh(user)
return user
async def create_from_oauth(
self,
db: AsyncSession,
oauth_info: OAuthUserInfo,
) -> User:
"""Create a new user from OAuth provider info.
Convenience method that converts OAuthUserInfo to UserCreate.
Args:
db: Async database session.
oauth_info: Normalized OAuth user information.
Returns:
The created User instance.
Example:
oauth_info = OAuthUserInfo(
provider="google",
oauth_id="123456789",
email="player@example.com",
name="Player One",
avatar_url="https://..."
)
user = await user_service.create_from_oauth(db, oauth_info)
"""
user_data = oauth_info.to_user_create()
return await self.create(db, user_data)
async def get_or_create_from_oauth(
self,
db: AsyncSession,
oauth_info: OAuthUserInfo,
) -> tuple[User, bool]:
"""Get existing user or create new one from OAuth info.
First checks for existing user by OAuth provider+ID, then by email
(for account linking), and finally creates a new user if not found.
Args:
db: Async database session.
oauth_info: Normalized OAuth user information.
Returns:
Tuple of (User, created) where created is True if new user.
Example:
user, created = await user_service.get_or_create_from_oauth(db, oauth_info)
if created:
print("Welcome, new user!")
else:
print("Welcome back!")
"""
# First, check by OAuth provider + ID (exact match)
user = await self.get_by_oauth(db, oauth_info.provider, oauth_info.oauth_id)
if user:
return user, False
# Check by email for potential account linking
# If user exists with same email but different OAuth, update their OAuth
user = await self.get_by_email(db, oauth_info.email)
if user:
# Update OAuth credentials for existing user
# This links the new OAuth provider to the existing account
user.oauth_provider = oauth_info.provider
user.oauth_id = oauth_info.oauth_id
# Optionally update avatar if not set
if not user.avatar_url and oauth_info.avatar_url:
user.avatar_url = oauth_info.avatar_url
await db.commit()
await db.refresh(user)
return user, False
# Create new user
user = await self.create_from_oauth(db, oauth_info)
return user, True
async def update(
self,
db: AsyncSession,
user: User,
update_data: UserUpdate,
) -> User:
"""Update user profile fields.
Only updates fields that are provided (not None).
Args:
db: Async database session.
user: The user to update.
update_data: Fields to update.
Returns:
The updated User instance.
Example:
update_data = UserUpdate(display_name="New Name")
user = await user_service.update(db, user, update_data)
"""
if update_data.display_name is not None:
user.display_name = update_data.display_name
if update_data.avatar_url is not None:
user.avatar_url = update_data.avatar_url
await db.commit()
await db.refresh(user)
return user
async def update_last_login(self, db: AsyncSession, user: User) -> User:
"""Update the user's last login timestamp.
Args:
db: Async database session.
user: The user to update.
Returns:
The updated User instance.
Example:
user = await user_service.update_last_login(db, user)
"""
user.last_login = datetime.now(UTC)
await db.commit()
await db.refresh(user)
return user
async def update_premium(
self,
db: AsyncSession,
user: User,
premium_until: datetime | None,
) -> User:
"""Update user's premium subscription status.
Args:
db: Async database session.
user: The user to update.
premium_until: When premium expires, or None to remove premium.
Returns:
The updated User instance.
Example:
# Grant 30 days of premium
expires = datetime.now(UTC) + timedelta(days=30)
user = await user_service.update_premium(db, user, expires)
# Remove premium
user = await user_service.update_premium(db, user, None)
"""
if premium_until is not None:
user.is_premium = True
user.premium_until = premium_until
else:
user.is_premium = False
user.premium_until = None
await db.commit()
await db.refresh(user)
return user
async def delete(self, db: AsyncSession, user: User) -> None:
"""Delete a user account.
This will cascade delete all related data (decks, collection, etc.)
based on the model relationships.
Args:
db: Async database session.
user: The user to delete.
Example:
await user_service.delete(db, user)
"""
await db.delete(user)
await db.commit()
async def get_linked_account(
self,
db: AsyncSession,
provider: str,
oauth_id: str,
) -> OAuthLinkedAccount | None:
"""Get a linked account by provider and OAuth ID.
Args:
db: Async database session.
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
OAuthLinkedAccount if found, None otherwise.
"""
result = await db.execute(
select(OAuthLinkedAccount).where(
OAuthLinkedAccount.provider == provider,
OAuthLinkedAccount.oauth_id == oauth_id,
)
)
return result.scalar_one_or_none()
async def link_oauth_account(
self,
db: AsyncSession,
user: User,
oauth_info: OAuthUserInfo,
) -> OAuthLinkedAccount:
"""Link an additional OAuth provider to a user account.
Args:
db: Async database session.
user: The user to link the account to.
oauth_info: OAuth information from the provider.
Returns:
The created OAuthLinkedAccount.
Raises:
AccountLinkingError: If provider is already linked to this or another user.
Example:
linked = await user_service.link_oauth_account(db, user, discord_info)
"""
# Check if this provider+oauth_id is already linked to any user
existing = await self.get_linked_account(db, oauth_info.provider, oauth_info.oauth_id)
if existing:
if str(existing.user_id) == str(user.id):
raise AccountLinkingError(
f"{oauth_info.provider.title()} account is already linked to your account"
)
raise AccountLinkingError(
f"This {oauth_info.provider.title()} account is already linked to another user"
)
# Check if this is the user's primary OAuth provider
if user.oauth_provider == oauth_info.provider:
raise AccountLinkingError(
f"{oauth_info.provider.title()} is your primary login provider"
)
# Check if user already has this provider linked
for linked in user.linked_accounts:
if linked.provider == oauth_info.provider:
raise AccountLinkingError(
f"You already have a {oauth_info.provider.title()} account linked"
)
# Create the linked account
linked_account = OAuthLinkedAccount(
user_id=str(user.id),
provider=oauth_info.provider,
oauth_id=oauth_info.oauth_id,
email=oauth_info.email,
display_name=oauth_info.name,
avatar_url=oauth_info.avatar_url,
)
db.add(linked_account)
await db.commit()
await db.refresh(linked_account)
return linked_account
async def unlink_oauth_account(
self,
db: AsyncSession,
user: User,
provider: str,
) -> bool:
"""Unlink an OAuth provider from a user account.
Cannot unlink the primary OAuth provider.
Args:
db: Async database session.
user: The user to unlink from.
provider: OAuth provider name to unlink.
Returns:
True if unlinked, False if provider wasn't linked.
Raises:
AccountLinkingError: If trying to unlink the primary provider.
Example:
success = await user_service.unlink_oauth_account(db, user, "discord")
"""
# Cannot unlink primary provider
if user.oauth_provider == provider:
raise AccountLinkingError(
f"Cannot unlink {provider.title()} - it is your primary login provider"
)
# Find and delete the linked account
for linked in user.linked_accounts:
if linked.provider == provider:
await db.delete(linked)
await db.commit()
return True
return False
# Global service instance
user_service = UserService()

View File

@ -43,6 +43,18 @@ services:
timeout: 5s
retries: 5
restart: unless-stopped
adminer:
image: adminer:latest
restart: unless-stopped
container_name: mantimon-adminer
ports:
- "8090:8080"
environment:
- ADMINER_DEFAULT_SERVER=mantimon-postgres
- TZ=America/Chicago
depends_on:
- postgres
volumes:
postgres_data:

View File

@ -0,0 +1,478 @@
{
"meta": {
"version": "1.0.0",
"created": "2026-01-27",
"lastUpdated": "2026-01-27",
"planType": "phase",
"phaseId": "PHASE_2",
"phaseName": "Authentication",
"description": "OAuth login (Google, Discord), JWT session management, user management, premium tier tracking",
"totalEstimatedHours": 24,
"totalTasks": 15,
"completedTasks": 15,
"status": "complete",
"masterPlan": "../PROJECT_PLAN_MASTER.json"
},
"goals": [
"Implement OAuth 2.0 authentication with Google and Discord providers",
"Create JWT-based session management with access/refresh token pattern",
"Build user management service with create, read, update operations",
"Implement FastAPI dependencies for protected endpoints",
"Support account linking (multiple OAuth providers per user)",
"Track premium subscription status with expiration dates"
],
"architectureNotes": {
"tokenStrategy": {
"accessToken": "Short-lived JWT (30 min default), contains user_id",
"refreshToken": "Longer-lived JWT (7 days default), stored in Redis for revocation",
"storage": "Refresh tokens tracked in Redis for logout/revocation support"
},
"oauthFlow": {
"pattern": "Authorization Code Flow (PKCE deferred)",
"callback": "Backend receives code, exchanges for tokens, creates/updates user",
"security": "Never store OAuth provider tokens, only OAuth ID"
},
"accountLinking": {
"strategy": "Email-based matching + explicit linking via OAuth flow",
"flow": "If user exists with same email, add OAuth provider to existing account. Users can also explicitly link additional providers via /auth/link/{provider}"
},
"existingInfrastructure": {
"config": "JWT, OAuth, and base_url settings in app/config.py Settings class",
"dependencies": "python-jose, passlib, bcrypt, httpx already installed",
"userModel": "User model with OAuth fields in app/db/models/user.py"
}
},
"directoryStructure": {
"schemas": "backend/app/schemas/",
"services": "backend/app/services/",
"api": "backend/app/api/",
"tests": "backend/tests/api/, backend/tests/services/"
},
"tasks": [
{
"id": "AUTH-001",
"name": "Create Pydantic schemas for auth",
"description": "Define request/response models for authentication flows",
"category": "critical",
"priority": 1,
"completed": true,
"dependencies": [],
"files": [
{"path": "app/schemas/__init__.py", "status": "complete"},
{"path": "app/schemas/auth.py", "status": "complete"},
{"path": "app/schemas/user.py", "status": "complete"}
],
"details": [
"TokenPayload: sub (user_id), exp, iat, type (access/refresh)",
"TokenResponse: access_token, refresh_token, token_type, expires_in",
"UserResponse: id, email, display_name, avatar_url, is_premium, premium_until",
"UserCreate: internal model for user creation from OAuth",
"OAuthUserInfo: normalized structure for OAuth provider data"
],
"estimatedHours": 1.5
},
{
"id": "AUTH-002",
"name": "Create JWT utilities service",
"description": "Functions for creating and verifying JWT tokens",
"category": "critical",
"priority": 2,
"completed": true,
"dependencies": ["AUTH-001"],
"files": [
{"path": "app/services/jwt_service.py", "status": "complete"}
],
"details": [
"create_access_token(user_id: UUID) -> str - Uses settings.jwt_expire_minutes",
"create_refresh_token(user_id: UUID) -> tuple[str, str] - Returns token and jti",
"verify_access_token(token: str) -> UUID | None - Returns user_id or None",
"verify_refresh_token(token: str) -> tuple[UUID, str] | None - Returns (user_id, jti) or None",
"Uses python-jose with HS256 algorithm",
"All timing uses datetime.now(UTC) per project standards"
],
"estimatedHours": 1.5
},
{
"id": "AUTH-003",
"name": "Create refresh token Redis storage",
"description": "Redis-based storage for refresh token tracking and revocation",
"category": "high",
"priority": 3,
"completed": true,
"dependencies": ["AUTH-002"],
"files": [
{"path": "app/services/token_store.py", "status": "complete"}
],
"details": [
"Key format: refresh_token:{user_id}:{jti} -> '1' with TTL",
"store_refresh_token(user_id, jti, expires_at) - Store with TTL",
"is_token_valid(user_id, jti) -> bool - Check if not revoked",
"revoke_token(user_id, jti) - Delete specific token",
"revoke_all_user_tokens(user_id) - Logout from all devices",
"get_active_session_count(user_id) - Count valid tokens",
"Uses existing Redis connection from app/db/redis.py"
],
"estimatedHours": 1.5
},
{
"id": "AUTH-004",
"name": "Create UserService",
"description": "Service layer for user CRUD operations",
"category": "critical",
"priority": 4,
"completed": true,
"dependencies": ["AUTH-001"],
"files": [
{"path": "app/services/user_service.py", "status": "complete"}
],
"details": [
"get_by_id(db, user_id: UUID) -> User | None",
"get_by_email(db, email: str) -> User | None",
"get_by_oauth(db, provider: str, oauth_id: str) -> User | None",
"create(db, user_data: UserCreate) -> User",
"create_from_oauth(db, oauth_info: OAuthUserInfo) -> User",
"get_or_create_from_oauth(db, oauth_info: OAuthUserInfo) -> tuple[User, bool]",
"update(db, user: User, update_data: UserUpdate) -> User",
"update_last_login(db, user: User) -> User",
"update_premium(db, user: User, premium_until: datetime | None) -> User",
"link_oauth_account(db, user: User, oauth_info: OAuthUserInfo) -> OAuthLinkedAccount",
"unlink_oauth_account(db, user: User, provider: str) -> bool",
"All operations are async using SQLAlchemy async session"
],
"estimatedHours": 2
},
{
"id": "AUTH-005",
"name": "Create OAuthLinkedAccount model",
"description": "Database model for multiple OAuth providers per user (account linking)",
"category": "high",
"priority": 5,
"completed": true,
"dependencies": [],
"files": [
{"path": "app/db/models/oauth_account.py", "status": "complete"},
{"path": "app/db/migrations/versions/5ce887128ab1_add_oauth_linked_accounts.py", "status": "complete"}
],
"details": [
"Fields: id, user_id (FK), provider, oauth_id, email, display_name, avatar_url, linked_at",
"Unique constraint on (provider, oauth_id)",
"User.oauth_provider/oauth_id kept as 'primary' provider",
"Relationship: User.linked_accounts -> list[OAuthLinkedAccount]"
],
"estimatedHours": 2
},
{
"id": "AUTH-006",
"name": "Create Google OAuth service",
"description": "Handle Google OAuth authorization code flow",
"category": "critical",
"priority": 6,
"completed": true,
"dependencies": ["AUTH-004"],
"files": [
{"path": "app/services/oauth/google.py", "status": "complete"},
{"path": "app/services/oauth/__init__.py", "status": "complete"}
],
"details": [
"get_authorization_url(redirect_uri, state) -> str",
"get_user_info(code, redirect_uri) -> OAuthUserInfo",
"is_configured() -> bool",
"Uses httpx for async HTTP requests",
"Google OAuth endpoints: accounts.google.com/o/oauth2/v2/auth, oauth2.googleapis.com/token",
"User info endpoint: www.googleapis.com/oauth2/v2/userinfo"
],
"estimatedHours": 2
},
{
"id": "AUTH-007",
"name": "Create Discord OAuth service",
"description": "Handle Discord OAuth authorization code flow",
"category": "critical",
"priority": 7,
"completed": true,
"dependencies": ["AUTH-004"],
"files": [
{"path": "app/services/oauth/discord.py", "status": "complete"}
],
"details": [
"get_authorization_url(redirect_uri, state) -> str",
"get_user_info(code, redirect_uri) -> OAuthUserInfo",
"is_configured() -> bool",
"Uses httpx for async HTTP requests",
"Discord OAuth endpoints: discord.com/oauth2/authorize, discord.com/api/oauth2/token",
"User info endpoint: discord.com/api/users/@me",
"Avatar URL construction from user ID and avatar hash"
],
"estimatedHours": 2
},
{
"id": "AUTH-008",
"name": "Create FastAPI auth dependencies",
"description": "Dependency injection for protected endpoints",
"category": "critical",
"priority": 8,
"completed": true,
"dependencies": ["AUTH-002", "AUTH-003", "AUTH-004"],
"files": [
{"path": "app/api/__init__.py", "status": "complete"},
{"path": "app/api/deps.py", "status": "complete"}
],
"details": [
"OAuth2PasswordBearer scheme for token extraction",
"get_current_user(token, db) -> User - Validates token, fetches user",
"CurrentUser type alias with Annotated for dependency injection",
"DbSession type alias for database dependency",
"Proper error responses: 401 Unauthorized"
],
"estimatedHours": 1.5
},
{
"id": "AUTH-009",
"name": "Create auth API router",
"description": "REST endpoints for OAuth login, token refresh, logout",
"category": "critical",
"priority": 9,
"completed": true,
"dependencies": ["AUTH-006", "AUTH-007", "AUTH-008"],
"files": [
{"path": "app/api/auth.py", "status": "complete"}
],
"details": [
"GET /auth/google - Redirects to Google OAuth consent screen",
"GET /auth/google/callback - Handles OAuth callback, returns tokens",
"GET /auth/discord - Redirects to Discord OAuth consent screen",
"GET /auth/discord/callback - Handles OAuth callback, returns tokens",
"GET /auth/link/google - Start account linking for Google (requires auth)",
"GET /auth/link/google/callback - Handle account linking callback",
"GET /auth/link/discord - Start account linking for Discord (requires auth)",
"GET /auth/link/discord/callback - Handle account linking callback",
"POST /auth/refresh - Exchange refresh token for new access token",
"POST /auth/logout - Revoke refresh token",
"POST /auth/logout-all - Revoke all refresh tokens (requires auth)",
"State parameter stored in Redis with short TTL for CSRF protection",
"Uses settings.base_url for absolute OAuth callback URLs"
],
"estimatedHours": 3
},
{
"id": "AUTH-010",
"name": "Create user API router",
"description": "REST endpoints for user profile and account management",
"category": "high",
"priority": 10,
"completed": true,
"dependencies": ["AUTH-008", "AUTH-004"],
"files": [
{"path": "app/api/users.py", "status": "complete"}
],
"details": [
"GET /users/me - Get current user profile",
"PATCH /users/me - Update display_name, avatar_url",
"GET /users/me/linked-accounts - List linked OAuth providers",
"DELETE /users/me/link/{provider} - Unlink OAuth provider",
"GET /users/me/sessions - Get active session count",
"All endpoints require authentication via CurrentUser dependency"
],
"estimatedHours": 2
},
{
"id": "AUTH-011",
"name": "Integrate routers in main.py",
"description": "Mount auth and user routers",
"category": "high",
"priority": 11,
"completed": true,
"dependencies": ["AUTH-009", "AUTH-010"],
"files": [
{"path": "app/main.py", "status": "complete"}
],
"details": [
"Include auth router: prefix='/api'",
"Include users router: prefix='/api'"
],
"estimatedHours": 0.5
},
{
"id": "AUTH-012",
"name": "Create JWT service tests",
"description": "Unit tests for token creation and verification",
"category": "high",
"priority": 12,
"completed": true,
"dependencies": ["AUTH-002"],
"files": [
{"path": "tests/services/test_jwt_service.py", "status": "complete"}
],
"details": [
"Test create_access_token returns valid JWT",
"Test create_refresh_token returns valid JWT with jti",
"Test verify_access_token extracts correct user_id",
"Test verify_refresh_token returns user_id and jti",
"Test expired tokens return None",
"Test invalid signatures return None",
"20 tests covering all token operations"
],
"estimatedHours": 1.5
},
{
"id": "AUTH-013",
"name": "Create UserService tests",
"description": "Integration tests for user CRUD operations",
"category": "high",
"priority": 13,
"completed": true,
"dependencies": ["AUTH-004"],
"files": [
{"path": "tests/services/test_user_service.py", "status": "complete"}
],
"details": [
"Test get_by_id returns user or None",
"Test get_by_email returns user or None",
"Test get_by_oauth finds by provider+oauth_id",
"Test create creates user with correct fields",
"Test create_from_oauth creates from OAuthUserInfo",
"Test get_or_create_from_oauth handles all scenarios",
"Test update updates profile fields",
"Test update_last_login updates timestamp",
"Test update_premium manages subscription status",
"Test link_oauth_account links new providers",
"Test unlink_oauth_account removes linked providers",
"29 tests using real Postgres via testcontainers"
],
"estimatedHours": 2
},
{
"id": "AUTH-014",
"name": "Create OAuth service tests",
"description": "Unit tests for OAuth flows with mocked HTTP",
"category": "high",
"priority": 14,
"completed": true,
"dependencies": ["AUTH-006", "AUTH-007"],
"files": [
{"path": "tests/services/oauth/test_google.py", "status": "complete"},
{"path": "tests/services/oauth/test_discord.py", "status": "complete"}
],
"details": [
"Mock httpx responses for token exchange",
"Mock httpx responses for user info",
"Test authorization URL construction",
"Test error handling for invalid codes",
"Test OAuthUserInfo normalization",
"10 tests for Google, 14 tests for Discord",
"Uses respx for httpx mocking"
],
"estimatedHours": 2
},
{
"id": "AUTH-015",
"name": "Create auth API endpoint tests",
"description": "Integration tests for auth endpoints",
"category": "high",
"priority": 15,
"completed": true,
"dependencies": ["AUTH-009", "AUTH-010"],
"files": [
{"path": "tests/api/__init__.py", "status": "complete"},
{"path": "tests/api/conftest.py", "status": "complete"},
{"path": "tests/api/test_auth.py", "status": "complete"},
{"path": "tests/api/test_users.py", "status": "complete"}
],
"details": [
"Test OAuth redirect returns correct URL",
"Test refresh endpoint returns new access token",
"Test logout revokes refresh token",
"Test /users/me returns current user",
"Test /users/me update works",
"Test /users/me/linked-accounts returns accounts",
"Test /users/me/sessions returns count",
"Test DELETE /users/me/link/{provider} unlinks account",
"10 tests for auth, 15 tests for users",
"Uses TestClient with dependency overrides and fakeredis"
],
"estimatedHours": 3
}
],
"testingStrategy": {
"approach": "Unit tests for services, integration tests for API endpoints",
"mocking": "httpx responses mocked with respx for OAuth providers, fakeredis for token store",
"database": "Real Postgres via testcontainers for service tests",
"coverage": "1072 total tests, 98 tests for auth system"
},
"acceptanceCriteria": [
{"criterion": "User can login with Google OAuth and receive JWT tokens", "met": true},
{"criterion": "User can login with Discord OAuth and receive JWT tokens", "met": true},
{"criterion": "Access tokens expire after configured time", "met": true},
{"criterion": "Refresh tokens can be used to get new access tokens", "met": true},
{"criterion": "Logout revokes refresh token (cannot be reused)", "met": true},
{"criterion": "Protected endpoints return 401 without valid token", "met": true},
{"criterion": "User can link multiple OAuth providers to one account", "met": true},
{"criterion": "Premium status is tracked with expiration date", "met": true},
{"criterion": "All tests pass with high coverage", "met": true}
],
"securityConsiderations": [
"OAuth state parameter validated to prevent CSRF attacks",
"OAuth provider tokens never stored (only OAuth ID)",
"JWT secret key loaded from environment, never hardcoded",
"Refresh tokens stored in Redis with TTL for revocation support",
"Access tokens short-lived (30 min) to limit exposure",
"OAuth callbacks use absolute URLs (base_url config setting)",
"HTTPS required in production for all auth endpoints"
],
"deferredItems": [
{
"item": "PKCE for OAuth",
"reason": "Not strictly required for server-side OAuth flow",
"priority": "low"
},
{
"item": "Rate limiting on auth endpoints",
"reason": "Can be added as infrastructure concern later",
"priority": "medium"
},
{
"item": "Refresh token rotation",
"reason": "Current implementation is secure; rotation adds complexity",
"priority": "low"
}
],
"dependencies": {
"existing": [
"python-jose>=3.5.0 (already installed)",
"passlib>=1.7.4 (already installed)",
"bcrypt>=5.0.0 (already installed)"
],
"added": [
"email-validator (for Pydantic EmailStr)",
"fakeredis (dev - Redis mocking in tests)",
"respx (dev - httpx mocking in tests)"
]
},
"phase1Prerequisites": {
"met": [
"JWT configuration in Settings class",
"OAuth configuration in Settings class",
"User model with OAuth fields",
"python-jose dependency installed",
"Database session management",
"Redis connection utilities"
]
},
"completionNotes": {
"totalNewTests": 98,
"totalTestsAfter": 1072,
"commit": "996c43f - Implement Phase 2: Authentication system",
"additionalCommitNeeded": "Fix OAuth absolute URLs and add account linking endpoints"
}
}

View File

@ -8,7 +8,9 @@ dependencies = [
"alembic>=1.18.1",
"asyncpg>=0.31.0",
"bcrypt>=5.0.0",
"email-validator>=2.3.0",
"fastapi>=0.128.0",
"httpx>=0.28.1",
"passlib>=1.7.4",
"psycopg2-binary>=2.9.11",
"pydantic>=2.12.5",
@ -24,12 +26,14 @@ dependencies = [
dev = [
"beautifulsoup4>=4.12.0",
"black>=26.1.0",
"fakeredis>=2.33.0",
"httpx>=0.28.1",
"mypy>=1.19.1",
"pytest>=9.0.2",
"pytest-asyncio>=1.3.0",
"pytest-cov>=7.0.0",
"requests>=2.31.0",
"respx>=0.22.0",
"ruff>=0.14.14",
"testcontainers[postgres,redis]>=4.0.0",
]

View File

@ -0,0 +1 @@
"""API endpoint tests."""

View File

@ -0,0 +1,133 @@
"""Test fixtures for API endpoint tests.
Provides fixtures for testing FastAPI endpoints with mocked dependencies.
"""
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
import fakeredis.aioredis
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.api import deps as api_deps
from app.api.auth import router as auth_router
from app.api.users import router as users_router
from app.db.models import User
from app.services.jwt_service import create_access_token, create_refresh_token
@pytest.fixture
def fake_redis():
"""Provide a fake Redis instance for testing."""
return fakeredis.aioredis.FakeRedis(decode_responses=True)
@pytest.fixture
def mock_get_redis(fake_redis):
"""Mock the get_redis context manager to use fake Redis."""
@asynccontextmanager
async def _mock_get_redis():
yield fake_redis
return _mock_get_redis
@pytest.fixture
def test_user():
"""Create a test user object.
Returns a User model instance that can be used in tests.
The user is not persisted to database.
"""
user = User(
email="test@example.com",
display_name="Test User",
avatar_url="https://example.com/avatar.jpg",
oauth_provider="google",
oauth_id="google-123",
is_premium=False,
premium_until=None,
)
# Manually set the ID since we're not using database
user.id = str(uuid4())
user.created_at = datetime.now(UTC)
user.updated_at = datetime.now(UTC)
user.last_login = None
user.linked_accounts = []
return user
@pytest.fixture
def premium_user(test_user):
"""Create a premium test user."""
from datetime import timedelta
test_user.is_premium = True
test_user.premium_until = datetime.now(UTC) + timedelta(days=30)
return test_user
@pytest.fixture
def access_token(test_user):
"""Create a valid access token for the test user."""
from uuid import UUID
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
return create_access_token(user_id)
@pytest.fixture
def refresh_token_data(test_user):
"""Create a valid refresh token and JTI for the test user."""
from uuid import UUID
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
token, jti = create_refresh_token(user_id)
return {"token": token, "jti": jti, "user_id": user_id}
@pytest.fixture
def mock_db_session():
"""Create a mock database session."""
return MagicMock(spec=AsyncSession)
@pytest.fixture
def app(mock_get_redis, mock_db_session):
"""Create a test FastAPI app with mocked Redis and DB.
This creates a minimal app with just the auth and users routers,
with Redis and database mocked.
"""
# Create test app (no lifespan since we're mocking everything)
test_app = FastAPI()
test_app.include_router(auth_router, prefix="/api")
test_app.include_router(users_router, prefix="/api")
# Override get_db dependency to return mock session
async def override_get_db():
yield mock_db_session
test_app.dependency_overrides[api_deps.get_db] = override_get_db
# Patch get_redis globally for this app
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
yield test_app
# Clean up overrides
test_app.dependency_overrides.clear()
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)

View File

@ -0,0 +1,260 @@
"""Tests for auth API endpoints.
Tests the authentication endpoints including OAuth redirects,
token refresh, and logout.
"""
from unittest.mock import AsyncMock, patch
from uuid import UUID
from fastapi import status
from fastapi.testclient import TestClient
class TestGoogleAuthRedirect:
"""Tests for GET /api/auth/google endpoint."""
def test_returns_501_when_not_configured(self, client: TestClient):
"""Test that endpoint returns 501 when Google OAuth is not configured.
Without client credentials, OAuth flow cannot proceed.
"""
with patch("app.api.auth.google_oauth") as mock_oauth:
mock_oauth.is_configured.return_value = False
response = client.get(
"/api/auth/google",
params={"redirect_uri": "http://localhost/callback"},
follow_redirects=False,
)
assert response.status_code == status.HTTP_501_NOT_IMPLEMENTED
assert "not configured" in response.json()["detail"]
class TestDiscordAuthRedirect:
"""Tests for GET /api/auth/discord endpoint."""
def test_returns_501_when_not_configured(self, client: TestClient):
"""Test that endpoint returns 501 when Discord OAuth is not configured."""
with patch("app.api.auth.discord_oauth") as mock_oauth:
mock_oauth.is_configured.return_value = False
response = client.get(
"/api/auth/discord",
params={"redirect_uri": "http://localhost/callback"},
follow_redirects=False,
)
assert response.status_code == status.HTTP_501_NOT_IMPLEMENTED
assert "not configured" in response.json()["detail"]
class TestRefreshTokens:
"""Tests for POST /api/auth/refresh endpoint."""
def test_returns_new_access_token(
self, client: TestClient, test_user, refresh_token_data, mock_get_redis
):
"""Test that refresh endpoint returns new access token for valid refresh token.
A valid, non-revoked refresh token should yield a new access token.
"""
# Store the refresh token in fake Redis
import asyncio
async def setup_token():
async with mock_get_redis() as redis:
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
await redis.setex(key, 86400, "1")
asyncio.get_event_loop().run_until_complete(setup_token())
# Mock user service to return our test user
with patch("app.api.auth.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "access_token" in data
assert data["refresh_token"] == refresh_token_data["token"]
assert data["token_type"] == "bearer"
assert "expires_in" in data
def test_returns_401_for_invalid_token(self, client: TestClient):
"""Test that refresh endpoint returns 401 for invalid refresh token."""
response = client.post(
"/api/auth/refresh",
json={"refresh_token": "invalid.token.here"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_returns_401_for_revoked_token(
self, client: TestClient, refresh_token_data, mock_get_redis
):
"""Test that refresh endpoint returns 401 for revoked token.
A refresh token not in Redis (revoked/expired) should be rejected.
"""
# Don't store the token in Redis - simulating revocation
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert "revoked" in response.json()["detail"]
def test_returns_401_for_deleted_user(
self, client: TestClient, refresh_token_data, mock_get_redis
):
"""Test that refresh endpoint returns 401 if user no longer exists."""
# Store the token
import asyncio
async def setup_token():
async with mock_get_redis() as redis:
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
await redis.setex(key, 86400, "1")
asyncio.get_event_loop().run_until_complete(setup_token())
# Mock user service to return None (user deleted)
with patch("app.api.auth.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=None)
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert "User not found" in response.json()["detail"]
class TestLogout:
"""Tests for POST /api/auth/logout endpoint."""
def test_revokes_token(self, client: TestClient, refresh_token_data, mock_get_redis):
"""Test that logout revokes the refresh token.
After logout, the token should no longer be in Redis.
"""
# Store the token first
import asyncio
async def setup_and_check():
async with mock_get_redis() as redis:
key = f"refresh_token:{refresh_token_data['user_id']}:{refresh_token_data['jti']}"
await redis.setex(key, 86400, "1")
return key
key = asyncio.get_event_loop().run_until_complete(setup_and_check())
# Logout
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/logout",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_204_NO_CONTENT
# Verify token is gone
async def verify_deleted():
async with mock_get_redis() as redis:
return await redis.exists(key)
exists = asyncio.get_event_loop().run_until_complete(verify_deleted())
assert exists == 0
def test_succeeds_for_invalid_token(self, client: TestClient):
"""Test that logout succeeds even for invalid tokens.
Invalid tokens are effectively "already logged out", so no error.
"""
response = client.post(
"/api/auth/logout",
json={"refresh_token": "invalid.token.here"},
)
assert response.status_code == status.HTTP_204_NO_CONTENT
class TestLogoutAll:
"""Tests for POST /api/auth/logout-all endpoint."""
def test_requires_authentication(self, client: TestClient):
"""Test that logout-all requires a valid access token.
Without authentication, endpoint should return 401.
"""
response = client.post("/api/auth/logout-all")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_revokes_all_tokens(self, client: TestClient, test_user, access_token, mock_get_redis):
"""Test that logout-all revokes all refresh tokens for user.
Should delete all tokens matching the user's ID pattern.
"""
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
# Store multiple tokens
import asyncio
async def setup_tokens():
async with mock_get_redis() as redis:
await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")
await redis.setex(f"refresh_token:{user_id}:jti-3", 86400, "1")
asyncio.get_event_loop().run_until_complete(setup_tokens())
# Mock dependencies
with patch("app.api.deps.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/logout-all",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_204_NO_CONTENT
# Verify all tokens are gone
async def count_remaining():
async with mock_get_redis() as redis:
count = 0
async for _ in redis.scan_iter(match=f"refresh_token:{user_id}:*"):
count += 1
return count
remaining = asyncio.get_event_loop().run_until_complete(count_remaining())
assert remaining == 0

View File

@ -0,0 +1,259 @@
"""Tests for users API endpoints.
Tests the user profile management endpoints.
"""
from unittest.mock import AsyncMock, patch
from uuid import UUID
from fastapi import status
from fastapi.testclient import TestClient
from app.services.user_service import AccountLinkingError
class TestGetCurrentUser:
"""Tests for GET /api/users/me endpoint."""
def test_returns_user_profile(self, client: TestClient, test_user, access_token):
"""Test that endpoint returns user profile for authenticated user.
Should return the user's profile information.
"""
with patch("app.api.deps.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
response = client.get(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["email"] == test_user.email
assert data["display_name"] == test_user.display_name
assert data["avatar_url"] == test_user.avatar_url
assert data["is_premium"] == test_user.is_premium
def test_requires_authentication(self, client: TestClient):
"""Test that endpoint returns 401 without authentication."""
response = client.get("/api/users/me")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_returns_401_for_invalid_token(self, client: TestClient):
"""Test that endpoint returns 401 for invalid access token."""
response = client.get(
"/api/users/me",
headers={"Authorization": "Bearer invalid.token.here"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
class TestUpdateCurrentUser:
"""Tests for PATCH /api/users/me endpoint."""
def test_updates_display_name(self, client: TestClient, test_user, access_token):
"""Test that endpoint updates display_name when provided."""
updated_user = test_user
updated_user.display_name = "New Name"
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.update = AsyncMock(return_value=updated_user)
response = client.patch(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
json={"display_name": "New Name"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["display_name"] == "New Name"
def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
"""Test that endpoint updates avatar_url when provided."""
updated_user = test_user
updated_user.avatar_url = "https://new-avatar.com/img.jpg"
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.update = AsyncMock(return_value=updated_user)
response = client.patch(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
json={"avatar_url": "https://new-avatar.com/img.jpg"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
def test_requires_authentication(self, client: TestClient):
"""Test that endpoint returns 401 without authentication."""
response = client.patch(
"/api/users/me",
json={"display_name": "New Name"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
class TestGetLinkedAccounts:
"""Tests for GET /api/users/me/linked-accounts endpoint."""
def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
"""Test that endpoint returns list of linked OAuth accounts.
Should include the primary provider and any linked accounts.
"""
with patch("app.api.deps.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
response = client.get(
"/api/users/me/linked-accounts",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
assert len(data) >= 1 # At least primary account
assert data[0]["provider"] == test_user.oauth_provider
def test_requires_authentication(self, client: TestClient):
"""Test that endpoint returns 401 without authentication."""
response = client.get("/api/users/me/linked-accounts")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
class TestGetActiveSessions:
"""Tests for GET /api/users/me/sessions endpoint."""
def test_returns_session_count(
self, client: TestClient, test_user, access_token, mock_get_redis
):
"""Test that endpoint returns count of active sessions.
Should return the number of valid refresh tokens.
"""
user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id
# Store some tokens
import asyncio
async def setup_tokens():
async with mock_get_redis() as redis:
await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")
asyncio.get_event_loop().run_until_complete(setup_tokens())
with patch("app.api.deps.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
with patch("app.services.token_store.get_redis", mock_get_redis):
response = client.get(
"/api/users/me/sessions",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "active_sessions" in data
assert data["active_sessions"] == 2
def test_requires_authentication(self, client: TestClient):
"""Test that endpoint returns 401 without authentication."""
response = client.get("/api/users/me/sessions")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
class TestUnlinkOAuthAccount:
"""Tests for DELETE /api/users/me/link/{provider} endpoint."""
def test_unlinks_provider_successfully(self, client: TestClient, test_user, access_token):
"""Test that endpoint successfully unlinks a provider.
Should return 204 when provider is unlinked.
"""
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.unlink_oauth_account = AsyncMock(return_value=True)
response = client.delete(
"/api/users/me/link/discord",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_204_NO_CONTENT
def test_returns_404_if_not_linked(self, client: TestClient, test_user, access_token):
"""Test that endpoint returns 404 if provider isn't linked.
Should return 404 when trying to unlink a provider that isn't linked.
"""
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.unlink_oauth_account = AsyncMock(return_value=False)
response = client.delete(
"/api/users/me/link/discord",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not linked" in response.json()["detail"].lower()
def test_returns_400_for_primary_provider(self, client: TestClient, test_user, access_token):
"""Test that endpoint returns 400 when trying to unlink primary provider.
Cannot unlink the provider used to create the account.
"""
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.unlink_oauth_account = AsyncMock(
side_effect=AccountLinkingError(
"Cannot unlink Google - it is your primary login provider"
)
)
response = client.delete(
"/api/users/me/link/google",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "primary" in response.json()["detail"].lower()
def test_returns_400_for_unknown_provider(self, client: TestClient, test_user, access_token):
"""Test that endpoint returns 400 for unknown provider.
Only 'google' and 'discord' are valid providers.
"""
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
response = client.delete(
"/api/users/me/link/twitter",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "unknown provider" in response.json()["detail"].lower()
def test_requires_authentication(self, client: TestClient):
"""Test that endpoint returns 401 without authentication."""
response = client.delete("/api/users/me/link/discord")
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@ -70,6 +70,7 @@ TABLES_TO_TRUNCATE = [
"campaign_progress",
"collections",
"decks",
"oauth_linked_accounts",
"users",
]

View File

@ -0,0 +1 @@
"""OAuth service tests."""

View File

@ -0,0 +1,321 @@
"""Tests for Discord OAuth service.
Tests the Discord OAuth flow with mocked HTTP responses using respx.
"""
from unittest.mock import patch
import pytest
import respx
from httpx import Response
from app.services.oauth.discord import DiscordOAuth, DiscordOAuthError
pytestmark = pytest.mark.asyncio
class TestGetAuthorizationUrl:
"""Tests for get_authorization_url method."""
def test_raises_when_not_configured(self):
"""Test that get_authorization_url raises when Discord OAuth is not configured.
Without client ID, the method should raise DiscordOAuthError.
"""
oauth = DiscordOAuth()
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = None
with pytest.raises(DiscordOAuthError, match="not configured"):
oauth.get_authorization_url("http://localhost/callback", "state123")
def test_returns_valid_url_when_configured(self):
"""Test that get_authorization_url returns properly formatted URL.
The URL should include client_id, redirect_uri, state, and scopes.
"""
oauth = DiscordOAuth()
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = "test-client-id"
mock_settings.discord_client_secret = "test-secret"
url = oauth.get_authorization_url("http://localhost/callback", "state123")
assert "discord.com/api/oauth2/authorize" in url
assert "client_id=test-client-id" in url
assert "redirect_uri=http" in url
assert "state=state123" in url
assert "scope=" in url
assert "response_type=code" in url
class TestExchangeCodeForTokens:
"""Tests for exchange_code_for_tokens method."""
@respx.mock
async def test_returns_tokens_on_success(self):
"""Test that exchange_code_for_tokens returns tokens on success.
Mocks Discord's token endpoint to return valid tokens.
"""
oauth = DiscordOAuth()
respx.post("https://discord.com/api/oauth2/token").mock(
return_value=Response(
200,
json={
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 604800,
"token_type": "Bearer",
},
)
)
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = "test-client-id"
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
tokens = await oauth.exchange_code_for_tokens("auth-code", "http://localhost/callback")
assert tokens["access_token"] == "test-access-token"
@respx.mock
async def test_raises_on_error_response(self):
"""Test that exchange_code_for_tokens raises on error from Discord.
If Discord returns an error, DiscordOAuthError should be raised.
"""
oauth = DiscordOAuth()
respx.post("https://discord.com/api/oauth2/token").mock(
return_value=Response(
400,
json={
"error": "invalid_grant",
"error_description": "Invalid code",
},
)
)
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = "test-client-id"
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
with pytest.raises(DiscordOAuthError, match="Token exchange failed"):
await oauth.exchange_code_for_tokens("invalid-code", "http://localhost/callback")
class TestFetchUserInfo:
"""Tests for fetch_user_info method."""
@respx.mock
async def test_returns_user_info_on_success(self):
"""Test that fetch_user_info returns user data from Discord.
Mocks Discord's users/@me endpoint.
"""
oauth = DiscordOAuth()
respx.get("https://discord.com/api/users/@me").mock(
return_value=Response(
200,
json={
"id": "discord-user-123",
"username": "testuser",
"global_name": "Test User",
"email": "user@discord.com",
"avatar": "abc123",
},
)
)
user_info = await oauth.fetch_user_info("test-access-token")
assert user_info["id"] == "discord-user-123"
assert user_info["email"] == "user@discord.com"
@respx.mock
async def test_raises_on_error_response(self):
"""Test that fetch_user_info raises on error from Discord."""
oauth = DiscordOAuth()
respx.get("https://discord.com/api/users/@me").mock(
return_value=Response(401, json={"message": "401: Unauthorized"})
)
with pytest.raises(DiscordOAuthError, match="Failed to fetch user info"):
await oauth.fetch_user_info("invalid-token")
class TestBuildAvatarUrl:
"""Tests for _build_avatar_url method."""
def test_returns_none_for_no_avatar(self):
"""Test that _build_avatar_url returns None when avatar is None."""
oauth = DiscordOAuth()
result = oauth._build_avatar_url("123456", None)
assert result is None
def test_builds_png_url_for_static_avatar(self):
"""Test that _build_avatar_url builds PNG URL for static avatars."""
oauth = DiscordOAuth()
result = oauth._build_avatar_url("123456", "abcdef123")
assert result == "https://cdn.discordapp.com/avatars/123456/abcdef123.png"
def test_builds_gif_url_for_animated_avatar(self):
"""Test that _build_avatar_url builds GIF URL for animated avatars.
Animated avatars start with 'a_' prefix.
"""
oauth = DiscordOAuth()
result = oauth._build_avatar_url("123456", "a_animated123")
assert result == "https://cdn.discordapp.com/avatars/123456/a_animated123.gif"
class TestGetUserInfo:
"""Tests for get_user_info method (full flow)."""
@respx.mock
async def test_returns_oauth_user_info_on_success(self):
"""Test that get_user_info completes full OAuth flow.
This tests the combined token exchange + user info fetch.
"""
oauth = DiscordOAuth()
# Mock token exchange
respx.post("https://discord.com/api/oauth2/token").mock(
return_value=Response(
200,
json={
"access_token": "test-access-token",
"refresh_token": "test-refresh",
"expires_in": 604800,
},
)
)
# Mock user info
respx.get("https://discord.com/api/users/@me").mock(
return_value=Response(
200,
json={
"id": "discord-user-456",
"username": "fullflowuser",
"global_name": "Full Flow User",
"email": "fullflow@discord.com",
"avatar": "avatar123",
},
)
)
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = "test-client-id"
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
assert result.provider == "discord"
assert result.oauth_id == "discord-user-456"
assert result.email == "fullflow@discord.com"
assert result.name == "Full Flow User"
assert "avatar123.png" in result.avatar_url
@respx.mock
async def test_uses_username_when_no_global_name(self):
"""Test that get_user_info falls back to username for display name.
Discord users may not have global_name set.
"""
oauth = DiscordOAuth()
respx.post("https://discord.com/api/oauth2/token").mock(
return_value=Response(
200,
json={"access_token": "test-access-token"},
)
)
respx.get("https://discord.com/api/users/@me").mock(
return_value=Response(
200,
json={
"id": "discord-user-789",
"username": "legacyuser",
"global_name": None, # No global name
"email": "legacy@discord.com",
"avatar": None,
},
)
)
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = "test-client-id"
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
assert result.name == "legacyuser"
assert result.avatar_url is None
@respx.mock
async def test_raises_when_no_email(self):
"""Test that get_user_info raises when Discord user has no email.
Email is required for account creation.
"""
oauth = DiscordOAuth()
respx.post("https://discord.com/api/oauth2/token").mock(
return_value=Response(
200,
json={"access_token": "test-access-token"},
)
)
respx.get("https://discord.com/api/users/@me").mock(
return_value=Response(
200,
json={
"id": "discord-user-noemail",
"username": "noemailuser",
"email": None, # No verified email
},
)
)
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = "test-client-id"
mock_settings.discord_client_secret.get_secret_value.return_value = "test-secret"
with pytest.raises(DiscordOAuthError, match="verified email"):
await oauth.get_user_info("auth-code", "http://localhost/callback")
class TestIsConfigured:
"""Tests for is_configured method."""
def test_returns_false_when_not_configured(self):
"""Test that is_configured returns False without credentials."""
oauth = DiscordOAuth()
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = None
mock_settings.discord_client_secret = None
assert oauth.is_configured() is False
def test_returns_true_when_configured(self):
"""Test that is_configured returns True with credentials."""
oauth = DiscordOAuth()
with patch("app.services.oauth.discord.settings") as mock_settings:
mock_settings.discord_client_id = "client-id"
mock_settings.discord_client_secret = "client-secret"
assert oauth.is_configured() is True

View File

@ -0,0 +1,241 @@
"""Tests for Google OAuth service.
Tests the Google OAuth flow with mocked HTTP responses using respx.
"""
from unittest.mock import patch
import pytest
import respx
from httpx import Response
from app.services.oauth.google import GoogleOAuth, GoogleOAuthError
pytestmark = pytest.mark.asyncio
class TestGetAuthorizationUrl:
"""Tests for get_authorization_url method."""
def test_raises_when_not_configured(self):
"""Test that get_authorization_url raises when Google OAuth is not configured.
Without client ID, the method should raise GoogleOAuthError.
"""
oauth = GoogleOAuth()
with patch("app.services.oauth.google.settings") as mock_settings:
mock_settings.google_client_id = None
with pytest.raises(GoogleOAuthError, match="not configured"):
oauth.get_authorization_url("http://localhost/callback", "state123")
def test_returns_valid_url_when_configured(self):
"""Test that get_authorization_url returns properly formatted URL.
The URL should include client_id, redirect_uri, state, and scopes.
"""
oauth = GoogleOAuth()
with patch("app.services.oauth.google.settings") as mock_settings:
mock_settings.google_client_id = "test-client-id"
mock_settings.google_client_secret = "test-secret"
url = oauth.get_authorization_url("http://localhost/callback", "state123")
assert "accounts.google.com/o/oauth2/v2/auth" in url
assert "client_id=test-client-id" in url
assert "redirect_uri=http" in url
assert "state=state123" in url
assert "scope=" in url
assert "response_type=code" in url
class TestExchangeCodeForTokens:
"""Tests for exchange_code_for_tokens method."""
@respx.mock
async def test_returns_tokens_on_success(self):
"""Test that exchange_code_for_tokens returns tokens on success.
Mocks Google's token endpoint to return valid tokens.
"""
oauth = GoogleOAuth()
respx.post("https://oauth2.googleapis.com/token").mock(
return_value=Response(
200,
json={
"access_token": "test-access-token",
"id_token": "test-id-token",
"expires_in": 3600,
"token_type": "Bearer",
},
)
)
with patch("app.services.oauth.google.settings") as mock_settings:
mock_settings.google_client_id = "test-client-id"
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
tokens = await oauth.exchange_code_for_tokens("auth-code", "http://localhost/callback")
assert tokens["access_token"] == "test-access-token"
@respx.mock
async def test_raises_on_error_response(self):
"""Test that exchange_code_for_tokens raises on error from Google.
If Google returns an error, GoogleOAuthError should be raised.
"""
oauth = GoogleOAuth()
respx.post("https://oauth2.googleapis.com/token").mock(
return_value=Response(
400,
json={
"error": "invalid_grant",
"error_description": "Code has expired",
},
)
)
with patch("app.services.oauth.google.settings") as mock_settings:
mock_settings.google_client_id = "test-client-id"
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
with pytest.raises(GoogleOAuthError, match="Token exchange failed"):
await oauth.exchange_code_for_tokens("expired-code", "http://localhost/callback")
class TestFetchUserInfo:
"""Tests for fetch_user_info method."""
@respx.mock
async def test_returns_user_info_on_success(self):
"""Test that fetch_user_info returns user data from Google.
Mocks Google's userinfo endpoint.
"""
oauth = GoogleOAuth()
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
return_value=Response(
200,
json={
"id": "google-user-123",
"email": "user@gmail.com",
"name": "Test User",
"picture": "https://google.com/avatar.jpg",
},
)
)
user_info = await oauth.fetch_user_info("test-access-token")
assert user_info["id"] == "google-user-123"
assert user_info["email"] == "user@gmail.com"
@respx.mock
async def test_raises_on_error_response(self):
"""Test that fetch_user_info raises on error from Google."""
oauth = GoogleOAuth()
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
return_value=Response(401, json={"error": "Invalid token"})
)
with pytest.raises(GoogleOAuthError, match="Failed to fetch user info"):
await oauth.fetch_user_info("invalid-token")
class TestGetUserInfo:
"""Tests for get_user_info method (full flow)."""
@respx.mock
async def test_returns_oauth_user_info_on_success(self):
"""Test that get_user_info completes full OAuth flow.
This tests the combined token exchange + user info fetch.
"""
oauth = GoogleOAuth()
# Mock token exchange
respx.post("https://oauth2.googleapis.com/token").mock(
return_value=Response(
200,
json={
"access_token": "test-access-token",
"id_token": "test-id-token",
"expires_in": 3600,
},
)
)
# Mock user info
respx.get("https://www.googleapis.com/oauth2/v2/userinfo").mock(
return_value=Response(
200,
json={
"id": "google-user-456",
"email": "fullflow@gmail.com",
"name": "Full Flow User",
"picture": "https://google.com/fullflow.jpg",
},
)
)
with patch("app.services.oauth.google.settings") as mock_settings:
mock_settings.google_client_id = "test-client-id"
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
result = await oauth.get_user_info("auth-code", "http://localhost/callback")
assert result.provider == "google"
assert result.oauth_id == "google-user-456"
assert result.email == "fullflow@gmail.com"
assert result.name == "Full Flow User"
assert result.avatar_url == "https://google.com/fullflow.jpg"
@respx.mock
async def test_raises_when_no_access_token(self):
"""Test that get_user_info raises when token response lacks access_token."""
oauth = GoogleOAuth()
respx.post("https://oauth2.googleapis.com/token").mock(
return_value=Response(
200,
json={"id_token": "only-id-token"}, # No access_token
)
)
with patch("app.services.oauth.google.settings") as mock_settings:
mock_settings.google_client_id = "test-client-id"
mock_settings.google_client_secret.get_secret_value.return_value = "test-secret"
with pytest.raises(GoogleOAuthError, match="No access token"):
await oauth.get_user_info("auth-code", "http://localhost/callback")
class TestIsConfigured:
"""Tests for is_configured method."""
def test_returns_false_when_not_configured(self):
"""Test that is_configured returns False without credentials."""
oauth = GoogleOAuth()
with patch("app.services.oauth.google.settings") as mock_settings:
mock_settings.google_client_id = None
mock_settings.google_client_secret = None
assert oauth.is_configured() is False
def test_returns_true_when_configured(self):
"""Test that is_configured returns True with credentials."""
oauth = GoogleOAuth()
with patch("app.services.oauth.google.settings") as mock_settings:
mock_settings.google_client_id = "client-id"
mock_settings.google_client_secret = "client-secret"
assert oauth.is_configured() is True

View File

@ -0,0 +1,370 @@
"""Tests for JWT service.
Tests the JWT token creation and verification functions used for
authentication throughout the application.
"""
import uuid
from datetime import UTC, datetime, timedelta
from jose import jwt
from app.config import settings
from app.schemas.auth import TokenType
from app.services.jwt_service import (
create_access_token,
create_refresh_token,
decode_token,
get_refresh_token_expiration,
get_token_expiration_seconds,
verify_access_token,
verify_refresh_token,
)
class TestCreateAccessToken:
"""Tests for create_access_token function."""
def test_creates_valid_jwt(self):
"""Test that create_access_token returns a valid JWT string.
The returned token should be decodable and contain the expected
claims including subject, expiration, and token type.
"""
user_id = uuid.uuid4()
token = create_access_token(user_id)
# Should be a valid JWT (three dot-separated parts)
assert isinstance(token, str)
assert token.count(".") == 2
# Should be decodable
payload = jwt.decode(
token,
settings.secret_key.get_secret_value(),
algorithms=[settings.jwt_algorithm],
)
assert payload["sub"] == str(user_id)
assert payload["type"] == TokenType.ACCESS.value
def test_sets_correct_expiration(self):
"""Test that access token expiration matches configured setting.
The token should expire approximately jwt_expire_minutes from now.
JWT timestamps have second precision, so we allow 1 second tolerance.
"""
user_id = uuid.uuid4()
before = datetime.now(UTC)
token = create_access_token(user_id)
after = datetime.now(UTC)
payload = jwt.decode(
token,
settings.secret_key.get_secret_value(),
algorithms=[settings.jwt_algorithm],
)
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
expected_min = (
before + timedelta(minutes=settings.jwt_expire_minutes) - timedelta(seconds=1)
)
expected_max = after + timedelta(minutes=settings.jwt_expire_minutes) + timedelta(seconds=1)
assert expected_min <= exp <= expected_max
def test_includes_issued_at(self):
"""Test that access token includes iat (issued at) claim.
The iat claim should be set to approximately the current time.
JWT timestamps have second precision, so we allow 1 second tolerance.
"""
user_id = uuid.uuid4()
before = datetime.now(UTC)
token = create_access_token(user_id)
after = datetime.now(UTC)
payload = jwt.decode(
token,
settings.secret_key.get_secret_value(),
algorithms=[settings.jwt_algorithm],
)
iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
assert before - timedelta(seconds=1) <= iat <= after + timedelta(seconds=1)
class TestCreateRefreshToken:
"""Tests for create_refresh_token function."""
def test_creates_valid_jwt_with_jti(self):
"""Test that create_refresh_token returns a valid JWT and JTI.
The function should return a tuple of (token, jti) where the
token contains the JTI for revocation tracking.
"""
user_id = uuid.uuid4()
token, jti = create_refresh_token(user_id)
# Should return token and jti
assert isinstance(token, str)
assert isinstance(jti, str)
assert token.count(".") == 2
# JTI should be a valid UUID
uuid.UUID(jti) # Will raise if invalid
# Token should contain the JTI
payload = jwt.decode(
token,
settings.secret_key.get_secret_value(),
algorithms=[settings.jwt_algorithm],
)
assert payload["jti"] == jti
assert payload["type"] == TokenType.REFRESH.value
def test_sets_correct_expiration(self):
"""Test that refresh token expiration matches configured setting.
The token should expire approximately jwt_refresh_expire_days from now.
JWT timestamps have second precision, so we allow 1 second tolerance.
"""
user_id = uuid.uuid4()
before = datetime.now(UTC)
token, _ = create_refresh_token(user_id)
after = datetime.now(UTC)
payload = jwt.decode(
token,
settings.secret_key.get_secret_value(),
algorithms=[settings.jwt_algorithm],
)
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
expected_min = (
before + timedelta(days=settings.jwt_refresh_expire_days) - timedelta(seconds=1)
)
expected_max = (
after + timedelta(days=settings.jwt_refresh_expire_days) + timedelta(seconds=1)
)
assert expected_min <= exp <= expected_max
def test_generates_unique_jti(self):
"""Test that each refresh token gets a unique JTI.
Multiple calls should generate different JTIs to ensure
each token can be individually revoked.
"""
user_id = uuid.uuid4()
_, jti1 = create_refresh_token(user_id)
_, jti2 = create_refresh_token(user_id)
assert jti1 != jti2
class TestDecodeToken:
"""Tests for decode_token function."""
def test_decodes_valid_token(self):
"""Test that decode_token returns TokenPayload for valid tokens.
A valid token should be decoded into a TokenPayload with
all expected fields populated.
"""
user_id = uuid.uuid4()
token = create_access_token(user_id)
payload = decode_token(token)
assert payload is not None
assert payload.sub == str(user_id)
assert payload.type == TokenType.ACCESS
assert payload.exp is not None
assert payload.iat is not None
def test_returns_none_for_invalid_token(self):
"""Test that decode_token returns None for malformed tokens.
Invalid JWT strings should not raise exceptions but return None.
"""
result = decode_token("invalid.token.here")
assert result is None
def test_returns_none_for_wrong_signature(self):
"""Test that decode_token returns None for tokens with wrong signature.
Tokens signed with a different key should be rejected.
"""
user_id = uuid.uuid4()
# Create token with different secret
payload = {
"sub": str(user_id),
"exp": datetime.now(UTC) + timedelta(hours=1),
"iat": datetime.now(UTC),
"type": "access",
}
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
result = decode_token(token)
assert result is None
def test_returns_none_for_expired_token(self):
"""Test that decode_token returns None for expired tokens.
Tokens past their expiration should be rejected.
"""
user_id = uuid.uuid4()
payload = {
"sub": str(user_id),
"exp": datetime.now(UTC) - timedelta(hours=1), # Already expired
"iat": datetime.now(UTC) - timedelta(hours=2),
"type": "access",
}
token = jwt.encode(
payload,
settings.secret_key.get_secret_value(),
algorithm=settings.jwt_algorithm,
)
result = decode_token(token)
assert result is None
class TestVerifyAccessToken:
"""Tests for verify_access_token function."""
def test_returns_user_id_for_valid_access_token(self):
"""Test that verify_access_token returns user ID for valid tokens.
A valid access token should return the UUID of the user.
"""
user_id = uuid.uuid4()
token = create_access_token(user_id)
result = verify_access_token(token)
assert result == user_id
def test_returns_none_for_refresh_token(self):
"""Test that verify_access_token rejects refresh tokens.
Even valid refresh tokens should be rejected when verifying
as access tokens to prevent token type confusion.
"""
user_id = uuid.uuid4()
token, _ = create_refresh_token(user_id)
result = verify_access_token(token)
assert result is None
def test_returns_none_for_invalid_token(self):
"""Test that verify_access_token returns None for invalid tokens."""
result = verify_access_token("invalid.token.here")
assert result is None
def test_returns_none_for_invalid_uuid_subject(self):
"""Test that verify_access_token returns None for non-UUID subject.
If the subject claim is not a valid UUID, the token should be rejected.
"""
payload = {
"sub": "not-a-uuid",
"exp": datetime.now(UTC) + timedelta(hours=1),
"iat": datetime.now(UTC),
"type": "access",
}
token = jwt.encode(
payload,
settings.secret_key.get_secret_value(),
algorithm=settings.jwt_algorithm,
)
result = verify_access_token(token)
assert result is None
class TestVerifyRefreshToken:
"""Tests for verify_refresh_token function."""
def test_returns_user_id_and_jti_for_valid_refresh_token(self):
"""Test that verify_refresh_token returns user ID and JTI.
A valid refresh token should return both values needed for
revocation checking.
"""
user_id = uuid.uuid4()
token, jti = create_refresh_token(user_id)
result = verify_refresh_token(token)
assert result is not None
result_user_id, result_jti = result
assert result_user_id == user_id
assert result_jti == jti
def test_returns_none_for_access_token(self):
"""Test that verify_refresh_token rejects access tokens.
Even valid access tokens should be rejected when verifying
as refresh tokens.
"""
user_id = uuid.uuid4()
token = create_access_token(user_id)
result = verify_refresh_token(token)
assert result is None
def test_returns_none_for_token_without_jti(self):
"""Test that verify_refresh_token rejects tokens missing JTI.
Refresh tokens must have a JTI for revocation tracking.
"""
payload = {
"sub": str(uuid.uuid4()),
"exp": datetime.now(UTC) + timedelta(days=7),
"iat": datetime.now(UTC),
"type": "refresh",
# No jti
}
token = jwt.encode(
payload,
settings.secret_key.get_secret_value(),
algorithm=settings.jwt_algorithm,
)
result = verify_refresh_token(token)
assert result is None
def test_returns_none_for_invalid_token(self):
"""Test that verify_refresh_token returns None for invalid tokens."""
result = verify_refresh_token("invalid.token.here")
assert result is None
class TestHelperFunctions:
"""Tests for helper functions."""
def test_get_token_expiration_seconds(self):
"""Test that get_token_expiration_seconds returns correct value.
Should return jwt_expire_minutes converted to seconds.
"""
result = get_token_expiration_seconds()
assert result == settings.jwt_expire_minutes * 60
def test_get_refresh_token_expiration(self):
"""Test that get_refresh_token_expiration returns future datetime.
Should return a datetime approximately jwt_refresh_expire_days
in the future.
"""
before = datetime.now(UTC)
result = get_refresh_token_expiration()
after = datetime.now(UTC)
expected_min = before + timedelta(days=settings.jwt_refresh_expire_days)
expected_max = after + timedelta(days=settings.jwt_refresh_expire_days)
assert expected_min <= result <= expected_max

View File

@ -0,0 +1,664 @@
"""Tests for UserService.
Tests the user service CRUD operations and OAuth-based user creation.
Uses real Postgres via the db_session fixture from conftest.
"""
from datetime import UTC, datetime, timedelta
import pytest
from app.db.models import User
from app.db.models.oauth_account import OAuthLinkedAccount
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
from app.services.user_service import AccountLinkingError, user_service
# Import db_session fixture from db conftest
pytestmark = pytest.mark.asyncio
class TestGetById:
"""Tests for get_by_id method."""
async def test_returns_user_when_found(self, db_session):
"""Test that get_by_id returns user when it exists.
Creates a user and verifies it can be retrieved by ID.
"""
# Create user directly
user = User(
email="test@example.com",
display_name="Test User",
oauth_provider="google",
oauth_id="123456",
)
db_session.add(user)
await db_session.commit()
# Retrieve by ID
from uuid import UUID
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
result = await user_service.get_by_id(db_session, user_id)
assert result is not None
assert result.email == "test@example.com"
async def test_returns_none_when_not_found(self, db_session):
"""Test that get_by_id returns None for nonexistent users."""
from uuid import uuid4
result = await user_service.get_by_id(db_session, uuid4())
assert result is None
class TestGetByEmail:
"""Tests for get_by_email method."""
async def test_returns_user_when_found(self, db_session):
"""Test that get_by_email returns user when it exists."""
user = User(
email="findme@example.com",
display_name="Find Me",
oauth_provider="discord",
oauth_id="discord123",
)
db_session.add(user)
await db_session.commit()
result = await user_service.get_by_email(db_session, "findme@example.com")
assert result is not None
assert result.display_name == "Find Me"
async def test_returns_none_when_not_found(self, db_session):
"""Test that get_by_email returns None for nonexistent emails."""
result = await user_service.get_by_email(db_session, "nobody@example.com")
assert result is None
class TestGetByOAuth:
"""Tests for get_by_oauth method."""
async def test_returns_user_when_found(self, db_session):
"""Test that get_by_oauth returns user for matching provider+id."""
user = User(
email="oauth@example.com",
display_name="OAuth User",
oauth_provider="google",
oauth_id="google-unique-id",
)
db_session.add(user)
await db_session.commit()
result = await user_service.get_by_oauth(db_session, "google", "google-unique-id")
assert result is not None
assert result.email == "oauth@example.com"
async def test_returns_none_for_wrong_provider(self, db_session):
"""Test that get_by_oauth returns None if provider doesn't match."""
user = User(
email="oauth2@example.com",
display_name="OAuth User 2",
oauth_provider="google",
oauth_id="google-id-2",
)
db_session.add(user)
await db_session.commit()
# Same ID, different provider
result = await user_service.get_by_oauth(db_session, "discord", "google-id-2")
assert result is None
async def test_returns_none_when_not_found(self, db_session):
"""Test that get_by_oauth returns None for nonexistent OAuth."""
result = await user_service.get_by_oauth(db_session, "google", "nonexistent")
assert result is None
class TestCreate:
"""Tests for create method."""
async def test_creates_user_with_all_fields(self, db_session):
"""Test that create properly persists all user fields."""
user_data = UserCreate(
email="new@example.com",
display_name="New User",
avatar_url="https://example.com/avatar.jpg",
oauth_provider="discord",
oauth_id="discord-new-id",
)
result = await user_service.create(db_session, user_data)
assert result.id is not None
assert result.email == "new@example.com"
assert result.display_name == "New User"
assert result.avatar_url == "https://example.com/avatar.jpg"
assert result.oauth_provider == "discord"
assert result.oauth_id == "discord-new-id"
assert result.is_premium is False
assert result.premium_until is None
async def test_creates_user_without_avatar(self, db_session):
"""Test that create works without optional avatar_url."""
user_data = UserCreate(
email="noavatar@example.com",
display_name="No Avatar",
oauth_provider="google",
oauth_id="google-no-avatar",
)
result = await user_service.create(db_session, user_data)
assert result.avatar_url is None
class TestCreateFromOAuth:
"""Tests for create_from_oauth method."""
async def test_creates_user_from_oauth_info(self, db_session):
"""Test that create_from_oauth converts OAuthUserInfo to User."""
oauth_info = OAuthUserInfo(
provider="google",
oauth_id="google-oauth-123",
email="oauthcreate@example.com",
name="OAuth Created User",
avatar_url="https://google.com/avatar.jpg",
)
result = await user_service.create_from_oauth(db_session, oauth_info)
assert result.email == "oauthcreate@example.com"
assert result.display_name == "OAuth Created User"
assert result.oauth_provider == "google"
assert result.oauth_id == "google-oauth-123"
class TestGetOrCreateFromOAuth:
"""Tests for get_or_create_from_oauth method."""
async def test_returns_existing_user_by_oauth(self, db_session):
"""Test that existing user is returned when OAuth matches.
Verifies the method returns (user, False) for existing users.
"""
# Create existing user
existing = User(
email="existing@example.com",
display_name="Existing",
oauth_provider="google",
oauth_id="existing-oauth-id",
)
db_session.add(existing)
await db_session.commit()
# Try to get or create with same OAuth
oauth_info = OAuthUserInfo(
provider="google",
oauth_id="existing-oauth-id",
email="existing@example.com",
name="Existing",
)
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
assert created is False
assert result.id == existing.id
async def test_links_existing_user_by_email(self, db_session):
"""Test that OAuth is linked when email matches existing user.
If a user exists with the same email but different OAuth,
the new OAuth should be linked to the existing account.
"""
# Create user with Google
existing = User(
email="link@example.com",
display_name="Link Me",
oauth_provider="google",
oauth_id="google-link-id",
)
db_session.add(existing)
await db_session.commit()
# Login with Discord (same email)
oauth_info = OAuthUserInfo(
provider="discord",
oauth_id="discord-link-id",
email="link@example.com",
name="Link Me",
avatar_url="https://discord.com/avatar.jpg",
)
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
assert created is False
assert result.id == existing.id
# OAuth should be updated to Discord
assert result.oauth_provider == "discord"
assert result.oauth_id == "discord-link-id"
async def test_creates_new_user_when_not_found(self, db_session):
"""Test that new user is created when no match exists.
Verifies the method returns (user, True) for new users.
"""
oauth_info = OAuthUserInfo(
provider="discord",
oauth_id="brand-new-id",
email="brandnew@example.com",
name="Brand New",
)
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
assert created is True
assert result.email == "brandnew@example.com"
class TestUpdate:
"""Tests for update method."""
async def test_updates_display_name(self, db_session):
"""Test that update changes display_name when provided."""
user = User(
email="update@example.com",
display_name="Old Name",
oauth_provider="google",
oauth_id="update-id",
)
db_session.add(user)
await db_session.commit()
update_data = UserUpdate(display_name="New Name")
result = await user_service.update(db_session, user, update_data)
assert result.display_name == "New Name"
async def test_updates_avatar_url(self, db_session):
"""Test that update changes avatar_url when provided."""
user = User(
email="avatar@example.com",
display_name="Avatar User",
oauth_provider="google",
oauth_id="avatar-id",
)
db_session.add(user)
await db_session.commit()
update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
result = await user_service.update(db_session, user, update_data)
assert result.avatar_url == "https://new-avatar.com/img.jpg"
async def test_ignores_none_values(self, db_session):
"""Test that update doesn't change fields set to None.
Only explicitly provided fields should be updated.
"""
user = User(
email="keep@example.com",
display_name="Keep Me",
avatar_url="https://keep.com/avatar.jpg",
oauth_provider="google",
oauth_id="keep-id",
)
db_session.add(user)
await db_session.commit()
# Update only display_name, leave avatar alone
update_data = UserUpdate(display_name="Changed")
result = await user_service.update(db_session, user, update_data)
assert result.display_name == "Changed"
assert result.avatar_url == "https://keep.com/avatar.jpg"
class TestUpdateLastLogin:
"""Tests for update_last_login method."""
async def test_updates_last_login_timestamp(self, db_session):
"""Test that update_last_login sets current timestamp."""
user = User(
email="login@example.com",
display_name="Login User",
oauth_provider="google",
oauth_id="login-id",
)
db_session.add(user)
await db_session.commit()
assert user.last_login is None
before = datetime.now(UTC)
result = await user_service.update_last_login(db_session, user)
after = datetime.now(UTC)
assert result.last_login is not None
# Allow 1 second tolerance
assert before - timedelta(seconds=1) <= result.last_login <= after + timedelta(seconds=1)
class TestUpdatePremium:
"""Tests for update_premium method."""
async def test_grants_premium(self, db_session):
"""Test that update_premium sets premium status and expiration."""
user = User(
email="premium@example.com",
display_name="Premium User",
oauth_provider="google",
oauth_id="premium-id",
)
db_session.add(user)
await db_session.commit()
assert user.is_premium is False
expires = datetime.now(UTC) + timedelta(days=30)
result = await user_service.update_premium(db_session, user, expires)
assert result.is_premium is True
assert result.premium_until == expires
async def test_removes_premium(self, db_session):
"""Test that update_premium with None removes premium status."""
user = User(
email="unpremium@example.com",
display_name="Unpremium User",
oauth_provider="google",
oauth_id="unpremium-id",
is_premium=True,
premium_until=datetime.now(UTC) + timedelta(days=30),
)
db_session.add(user)
await db_session.commit()
result = await user_service.update_premium(db_session, user, None)
assert result.is_premium is False
assert result.premium_until is None
class TestDelete:
"""Tests for delete method."""
async def test_deletes_user(self, db_session):
"""Test that delete removes user from database."""
user = User(
email="delete@example.com",
display_name="Delete Me",
oauth_provider="google",
oauth_id="delete-id",
)
db_session.add(user)
await db_session.commit()
user_id = user.id
await user_service.delete(db_session, user)
# Verify user is gone
from uuid import UUID
result = await user_service.get_by_id(
db_session, UUID(user_id) if isinstance(user_id, str) else user_id
)
assert result is None
class TestGetLinkedAccount:
"""Tests for get_linked_account method."""
async def test_returns_linked_account_when_found(self, db_session):
"""Test that get_linked_account returns account when it exists.
Creates a user with a linked account and verifies it can be retrieved.
"""
# Create user
user = User(
email="primary@example.com",
display_name="Primary User",
oauth_provider="google",
oauth_id="google-primary",
)
db_session.add(user)
await db_session.commit()
# Create linked account
linked = OAuthLinkedAccount(
user_id=user.id,
provider="discord",
oauth_id="discord-linked-123",
email="linked@example.com",
)
db_session.add(linked)
await db_session.commit()
# Retrieve linked account
result = await user_service.get_linked_account(db_session, "discord", "discord-linked-123")
assert result is not None
assert result.provider == "discord"
assert result.oauth_id == "discord-linked-123"
async def test_returns_none_when_not_found(self, db_session):
"""Test that get_linked_account returns None for nonexistent accounts."""
result = await user_service.get_linked_account(db_session, "discord", "nonexistent-id")
assert result is None
class TestLinkOAuthAccount:
"""Tests for link_oauth_account method."""
async def test_links_new_provider(self, db_session):
"""Test that link_oauth_account successfully links a new provider.
Creates a Google user and links Discord to them.
"""
# Create user with Google
user = User(
email="google-user@example.com",
display_name="Google User",
oauth_provider="google",
oauth_id="google-123",
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
# Link Discord
discord_info = OAuthUserInfo(
provider="discord",
oauth_id="discord-456",
email="discord@example.com",
name="Discord Name",
avatar_url="https://discord.com/avatar.png",
)
result = await user_service.link_oauth_account(db_session, user, discord_info)
assert result is not None
assert result.provider == "discord"
assert result.oauth_id == "discord-456"
assert result.email == "discord@example.com"
assert result.display_name == "Discord Name"
assert str(result.user_id) == str(user.id)
async def test_raises_error_if_already_linked_to_same_user(self, db_session):
"""Test that linking same provider twice raises error.
A user cannot have the same provider linked multiple times.
"""
user = User(
email="double-link@example.com",
display_name="Double Link",
oauth_provider="google",
oauth_id="google-double",
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
# Link Discord first time
discord_info = OAuthUserInfo(
provider="discord",
oauth_id="discord-first",
email="first@discord.com",
name="First",
)
await user_service.link_oauth_account(db_session, user, discord_info)
await db_session.refresh(user)
# Try to link same Discord account again
with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user, discord_info)
assert "already linked to your account" in str(exc_info.value)
async def test_raises_error_if_linked_to_another_user(self, db_session):
"""Test that linking account already linked to another user raises error.
The same OAuth provider+ID cannot be linked to multiple users.
"""
# Create first user and link Discord
user1 = User(
email="user1@example.com",
display_name="User 1",
oauth_provider="google",
oauth_id="google-user1",
)
db_session.add(user1)
await db_session.commit()
await db_session.refresh(user1)
discord_info = OAuthUserInfo(
provider="discord",
oauth_id="shared-discord",
email="shared@discord.com",
name="Shared",
)
await user_service.link_oauth_account(db_session, user1, discord_info)
# Create second user
user2 = User(
email="user2@example.com",
display_name="User 2",
oauth_provider="google",
oauth_id="google-user2",
)
db_session.add(user2)
await db_session.commit()
await db_session.refresh(user2)
# Try to link same Discord account to second user
with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user2, discord_info)
assert "already linked to another user" in str(exc_info.value)
async def test_raises_error_if_linking_primary_provider(self, db_session):
"""Test that linking the same provider as primary raises error.
User cannot link Google if they already signed up with Google.
"""
user = User(
email="google-primary@example.com",
display_name="Google Primary",
oauth_provider="google",
oauth_id="google-primary-id",
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
# Try to link another Google account
google_info = OAuthUserInfo(
provider="google",
oauth_id="google-different-id",
email="different@gmail.com",
name="Different",
)
with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user, google_info)
assert "primary login provider" in str(exc_info.value)
class TestUnlinkOAuthAccount:
"""Tests for unlink_oauth_account method."""
async def test_unlinks_linked_account(self, db_session):
"""Test that unlink_oauth_account removes a linked account.
Links Discord then unlinks it successfully.
"""
user = User(
email="unlink@example.com",
display_name="Unlink User",
oauth_provider="google",
oauth_id="google-unlink",
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
# Link Discord
discord_info = OAuthUserInfo(
provider="discord",
oauth_id="discord-unlink",
email="discord@unlink.com",
name="Discord Unlink",
)
await user_service.link_oauth_account(db_session, user, discord_info)
await db_session.refresh(user)
# Verify linked
assert len(user.linked_accounts) == 1
# Unlink
result = await user_service.unlink_oauth_account(db_session, user, "discord")
assert result is True
# Verify unlinked
linked = await user_service.get_linked_account(db_session, "discord", "discord-unlink")
assert linked is None
async def test_returns_false_if_not_linked(self, db_session):
"""Test that unlink returns False if provider isn't linked."""
user = User(
email="not-linked@example.com",
display_name="Not Linked",
oauth_provider="google",
oauth_id="google-notlinked",
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
result = await user_service.unlink_oauth_account(db_session, user, "discord")
assert result is False
async def test_raises_error_if_unlinking_primary(self, db_session):
"""Test that unlinking primary provider raises error.
User cannot unlink their primary OAuth provider.
"""
user = User(
email="primary-unlink@example.com",
display_name="Primary Unlink",
oauth_provider="google",
oauth_id="google-primary-unlink",
)
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
with pytest.raises(AccountLinkingError) as exc_info:
await user_service.unlink_oauth_account(db_session, user, "google")
assert "primary login provider" in str(exc_info.value)

64
backend/uv.lock generated
View File

@ -368,6 +368,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/cc/48/d9f421cb8da5afaa1a64570d9989e00fb7955e6acddc5a12979f7666ef60/coverage-7.13.1-py3-none-any.whl", hash = "sha256:2016745cb3ba554469d02819d78958b571792bb68e31302610e898f80dd3a573", size = 210722, upload-time = "2025-12-28T15:42:54.901Z" },
]
[[package]]
name = "dnspython"
version = "2.8.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" },
]
[[package]]
name = "docker"
version = "7.1.0"
@ -394,6 +403,32 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
]
[[package]]
name = "email-validator"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dnspython" },
{ name = "idna" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
]
[[package]]
name = "fakeredis"
version = "2.33.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "redis" },
{ name = "sortedcontainers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5f/f9/57464119936414d60697fcbd32f38909bb5688b616ae13de6e98384433e0/fakeredis-2.33.0.tar.gz", hash = "sha256:d7bc9a69d21df108a6451bbffee23b3eba432c21a654afc7ff2d295428ec5770", size = 175187, upload-time = "2025-12-16T19:45:52.269Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6e/78/a850fed8aeef96d4a99043c90b818b2ed5419cd5b24a4049fd7cfb9f1471/fakeredis-2.33.0-py3-none-any.whl", hash = "sha256:de535f3f9ccde1c56672ab2fdd6a8efbc4f2619fc2f1acc87b8737177d71c965", size = 119605, upload-time = "2025-12-16T19:45:51.08Z" },
]
[[package]]
name = "fastapi"
version = "0.128.0"
@ -579,7 +614,9 @@ dependencies = [
{ name = "alembic" },
{ name = "asyncpg" },
{ name = "bcrypt" },
{ name = "email-validator" },
{ name = "fastapi" },
{ name = "httpx" },
{ name = "passlib" },
{ name = "psycopg2-binary" },
{ name = "pydantic" },
@ -595,12 +632,14 @@ dependencies = [
dev = [
{ name = "beautifulsoup4" },
{ name = "black" },
{ name = "fakeredis" },
{ name = "httpx" },
{ name = "mypy" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
{ name = "requests" },
{ name = "respx" },
{ name = "ruff" },
{ name = "testcontainers", extra = ["redis"] },
]
@ -610,7 +649,9 @@ requires-dist = [
{ name = "alembic", specifier = ">=1.18.1" },
{ name = "asyncpg", specifier = ">=0.31.0" },
{ name = "bcrypt", specifier = ">=5.0.0" },
{ name = "email-validator", specifier = ">=2.3.0" },
{ name = "fastapi", specifier = ">=0.128.0" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "passlib", specifier = ">=1.7.4" },
{ name = "psycopg2-binary", specifier = ">=2.9.11" },
{ name = "pydantic", specifier = ">=2.12.5" },
@ -626,12 +667,14 @@ requires-dist = [
dev = [
{ name = "beautifulsoup4", specifier = ">=4.12.0" },
{ name = "black", specifier = ">=26.1.0" },
{ name = "fakeredis", specifier = ">=2.33.0" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "mypy", specifier = ">=1.19.1" },
{ name = "pytest", specifier = ">=9.0.2" },
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
{ name = "pytest-cov", specifier = ">=7.0.0" },
{ name = "requests", specifier = ">=2.31.0" },
{ name = "respx", specifier = ">=0.22.0" },
{ name = "ruff", specifier = ">=0.14.14" },
{ name = "testcontainers", extras = ["postgres", "redis"], specifier = ">=4.0.0" },
]
@ -1105,6 +1148,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
]
[[package]]
name = "respx"
version = "0.22.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f4/7c/96bd0bc759cf009675ad1ee1f96535edcb11e9666b985717eb8c87192a95/respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91", size = 28439, upload-time = "2024-12-19T22:33:59.374Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/67/afbb0978d5399bc9ea200f1d4489a23c9a1dad4eee6376242b8182389c79/respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0", size = 25127, upload-time = "2024-12-19T22:33:57.837Z" },
]
[[package]]
name = "rsa"
version = "4.9.1"
@ -1164,6 +1219,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
]
[[package]]
name = "sortedcontainers"
version = "2.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594, upload-time = "2021-05-16T22:03:42.897Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" },
]
[[package]]
name = "soupsieve"
version = "2.8.3"