"""Authentication routes: Discord OAuth (server-side and legacy flows), JWT refresh and session helpers."""

import logging
from typing import Any
from urllib.parse import urlencode

import httpx
from fastapi import APIRouter, HTTPException, Header, Request, Response
from fastapi.responses import RedirectResponse
from jose import JWTError
from pydantic import BaseModel

from app.config import get_settings
from app.utils.auth import create_token, verify_token
from app.utils.cookies import (
    ACCESS_TOKEN_COOKIE,
    ACCESS_TOKEN_MAX_AGE,
    REFRESH_TOKEN_COOKIE,
    set_auth_cookies,
    clear_auth_cookies,
    is_production,
)
from app.services.oauth_state import create_oauth_state, validate_and_consume_state

logger = logging.getLogger(f"{__name__}.auth")
router = APIRouter()
settings = get_settings()

_BEARER_PREFIX = "Bearer "


# ============================================================================
# Request/Response Models
# ============================================================================


class DiscordCallbackRequest(BaseModel):
    """Request model for Discord OAuth callback"""

    code: str
    state: str


class DiscordUser(BaseModel):
    """Discord user information"""

    id: str
    username: str
    discriminator: str
    avatar: str | None = None
    email: str | None = None


class AuthResponse(BaseModel):
    """Response model for successful authentication"""

    access_token: str
    refresh_token: str
    expires_in: int
    token_type: str = "bearer"
    user: DiscordUser


class RefreshRequest(BaseModel):
    """Request model for token refresh"""

    refresh_token: str


class RefreshResponse(BaseModel):
    """Response model for token refresh"""

    access_token: str
    expires_in: int
    token_type: str = "bearer"


class UserInfoResponse(BaseModel):
    """Response model for /me endpoint"""

    user: DiscordUser
    teams: list[dict[str, Any]] = []


# ============================================================================
# Internal Helpers
# ============================================================================


def _extract_bearer_token(authorization: str | None) -> str | None:
    """
    Extract the token from an ``Authorization: Bearer <token>`` header value.

    Args:
        authorization: Raw header value (may be None)

    Returns:
        The token string, or None if the header is absent, uses another
        scheme, or carries an empty token.
    """
    if not authorization or not authorization.startswith(_BEARER_PREFIX):
        return None
    return authorization[len(_BEARER_PREFIX):].strip() or None


def _safe_return_path(return_url: str | None) -> str:
    """
    Return *return_url* if it is a safe site-relative path, otherwise "/".

    The value is appended directly to ``settings.frontend_url`` when
    redirecting, so anything not starting with "/" can change the redirect's
    host (e.g. "@evil.com" -> https://frontend@evil.com, or ".evil.com").

    Args:
        return_url: Candidate path supplied by the client

    Returns:
        A path that starts with a single "/" and contains no backslashes or
        control characters; "/" when the input does not qualify.
    """
    if not return_url or not return_url.startswith("/"):
        return "/"
    # Defensive: "//host" and backslashes are interpreted as host-bearing URLs
    # by browsers, which matters if the frontend re-uses the path client-side.
    if return_url.startswith("//") or "\\" in return_url:
        return "/"
    if any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in return_url):
        return "/"
    return return_url


def _frontend_login_error(error: str) -> RedirectResponse:
    """Redirect (302) to the frontend login page with an ``error`` query parameter."""
    return RedirectResponse(
        url=f"{settings.frontend_url}/auth/login?error={error}", status_code=302
    )


def _user_payload(discord_user: DiscordUser) -> dict[str, Any]:
    """Build the JWT claims used for a Discord-authenticated user."""
    return {
        "user_id": discord_user.id,
        "username": discord_user.username,
        "discord_id": discord_user.id,
    }


def _create_token_pair(user_payload: dict[str, Any]) -> tuple[str, str]:
    """
    Create an access token and a refresh token for the same claims.

    The refresh token carries an extra ``"type": "refresh"`` claim, which is
    what ``/refresh`` checks for.

    Returns:
        (access_token, refresh_token)
    """
    access_token = create_token(user_payload)
    refresh_token = create_token({**user_payload, "type": "refresh"})
    return access_token, refresh_token


# ============================================================================
# Discord OAuth Helpers
# ============================================================================


def is_discord_id_allowed(discord_id: str) -> bool:
    """
    Check if a Discord user ID is allowed to access the system

    Reads the comma-separated ``settings.allowed_discord_ids``. An empty
    value or "*" disables the whitelist and allows everyone.

    Args:
        discord_id: Discord user ID to check

    Returns:
        True if allowed, False otherwise
    """
    allowed_ids = settings.allowed_discord_ids.strip()

    # If empty or "*", allow all (for development)
    if not allowed_ids or allowed_ids == "*":
        logger.warning("Discord whitelist disabled - allowing all users")
        return True

    # Parse comma-separated list; blank entries are ignored.
    whitelist = {entry.strip() for entry in allowed_ids.split(",") if entry.strip()}
    is_allowed = discord_id in whitelist

    if not is_allowed:
        logger.warning(f"Discord ID {discord_id} not in whitelist - access denied")
    else:
        logger.info(f"Discord ID {discord_id} verified in whitelist")

    return is_allowed


async def exchange_code_for_token(
    code: str, redirect_uri: str | None = None
) -> dict[str, Any]:
    """
    Exchange Discord OAuth code for access token

    Args:
        code: OAuth authorization code from Discord
        redirect_uri: OAuth redirect URI (defaults to settings.discord_redirect_uri,
            the legacy URI, for backwards compatibility)

    Returns:
        Discord OAuth token response (parsed JSON)

    Raises:
        HTTPException: 400 if Discord returns an error status
    """
    # Use provided redirect_uri or fall back to legacy for backwards compatibility
    uri = redirect_uri or settings.discord_redirect_uri

    data = {
        "client_id": settings.discord_client_id,
        "client_secret": settings.discord_client_secret,
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": uri,
    }

    logger.info(f"Token exchange using redirect_uri: {uri}")

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                "https://discord.com/api/oauth2/token",
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            logger.error(f"Discord token exchange failed: {e}")
            logger.error(f"Response: {e.response.text}")
            raise HTTPException(
                status_code=400, detail="Failed to exchange code for token"
            ) from e


async def get_discord_user(access_token: str) -> DiscordUser:
    """
    Get Discord user information using access token

    Args:
        access_token: Discord OAuth access token

    Returns:
        Discord user information

    Raises:
        HTTPException: 400 if Discord returns an error status
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(
                "https://discord.com/api/users/@me",
                headers={"Authorization": f"Bearer {access_token}"},
            )
            response.raise_for_status()
            user_data = response.json()
            return DiscordUser(**user_data)
        except httpx.HTTPStatusError as e:
            logger.error(f"Failed to get Discord user: {e}")
            raise HTTPException(
                status_code=400, detail="Failed to get user information"
            ) from e


# ============================================================================
# Auth Endpoints - Server-Side OAuth Flow
# ============================================================================


@router.get("/discord/login")
async def initiate_discord_login(return_url: str = "/") -> RedirectResponse:
    """
    Initiate Discord OAuth flow (server-side).

    Creates a state token via create_oauth_state and redirects to Discord
    authorization.

    Args:
        return_url: Frontend path to redirect to after successful auth. Must be
            a site-relative path (e.g. "/dashboard"); anything else is
            replaced by "/".

    Returns:
        Redirect to Discord OAuth authorization page
    """
    # Only site-relative paths are accepted to prevent open redirects.
    safe_return_url = _safe_return_path(return_url)
    if safe_return_url != return_url:
        logger.warning("Rejected unsafe return_url on Discord login; using '/'")

    # Create and store state (bound to the validated return path)
    state = await create_oauth_state(safe_return_url)

    # Build Discord OAuth URL with BACKEND redirect URI
    params = {
        "client_id": settings.discord_client_id,
        "redirect_uri": settings.discord_server_redirect_uri,
        "response_type": "code",
        "scope": "identify email",
        "state": state,
    }
    auth_url = f"https://discord.com/api/oauth2/authorize?{urlencode(params)}"

    logger.info(f"Initiating Discord OAuth, state={state[:10]}...")
    return RedirectResponse(url=auth_url, status_code=302)


@router.get("/discord/callback/server")
async def discord_callback_server(code: str, state: str) -> RedirectResponse:
    """
    Handle Discord OAuth callback (server-side flow).

    This endpoint:
    1. Validates state token (CSRF protection via validate_and_consume_state)
    2. Exchanges code for Discord access token
    3. Gets Discord user info
    4. Checks whitelist authorization
    5. Creates JWT tokens
    6. Sets auth cookies via set_auth_cookies
    7. Redirects to the frontend return path

    Errors never raise; they redirect to the frontend login page with
    ``error`` set to invalid_state, unauthorized, auth_failed or server_error.

    Args:
        code: OAuth authorization code from Discord
        state: State token for CSRF protection

    Returns:
        Redirect to frontend with cookies set
    """
    # Validate state (CSRF protection)
    return_url = await validate_and_consume_state(state)
    if not return_url:
        logger.warning("OAuth callback with invalid/expired state")
        return _frontend_login_error("invalid_state")

    # Defense in depth: re-validate the stored value before redirecting to it.
    return_path = _safe_return_path(return_url)

    try:
        logger.info("Exchanging Discord code for token")
        discord_token_data = await exchange_code_for_token(
            code, redirect_uri=settings.discord_server_redirect_uri
        )

        logger.info("Fetching Discord user information")
        discord_user = await get_discord_user(discord_token_data["access_token"])

        if not is_discord_id_allowed(discord_user.id):
            logger.warning(f"Unauthorized Discord ID: {discord_user.id}")
            return _frontend_login_error("unauthorized")

        access_token, refresh_token = _create_token_pair(_user_payload(discord_user))

        logger.info(
            f"User {discord_user.username} authenticated via server flow, "
            f"redirecting to {return_path}"
        )

        response = RedirectResponse(
            url=f"{settings.frontend_url}{return_path}", status_code=302
        )
        set_auth_cookies(response, access_token, refresh_token)
        return response

    except HTTPException as e:
        logger.error(f"OAuth callback error: {e.detail}")
        return _frontend_login_error("auth_failed")
    except Exception as e:
        logger.error(f"OAuth callback unexpected error: {e}", exc_info=True)
        return _frontend_login_error("server_error")


# ============================================================================
# Legacy Endpoints (Backwards Compatibility)
# ============================================================================


@router.post("/discord/callback", response_model=AuthResponse)
async def discord_callback(request: DiscordCallbackRequest):
    """
    Handle Discord OAuth callback (legacy, client-driven flow)

    Exchange authorization code for Discord token, get user info,
    and create our JWT tokens.

    Args:
        request: OAuth callback data (code and state)

    Returns:
        JWT tokens and user information

    Raises:
        HTTPException: 403 if the Discord ID is not whitelisted, 400 if a
            Discord API call fails, 500 on unexpected errors
    """
    # NOTE(review): request.state is not validated here; presumably the client
    # checks it before calling this endpoint — confirm, otherwise this flow
    # has no CSRF protection.
    try:
        logger.info("Exchanging Discord code for token")
        discord_token_data = await exchange_code_for_token(request.code)

        logger.info("Fetching Discord user information")
        discord_user = await get_discord_user(discord_token_data["access_token"])

        if not is_discord_id_allowed(discord_user.id):
            raise HTTPException(
                status_code=403,
                detail="Access denied. Your Discord account is not authorized to access this system.",
            )

        access_token, refresh_token = _create_token_pair(_user_payload(discord_user))

        logger.info(f"User {discord_user.username} authenticated successfully")

        return AuthResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            # 7 days in seconds. NOTE(review): hard-coded, not derived from the
            # lifetime create_token actually applies — confirm they agree.
            expires_in=604800,
            user=discord_user,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Discord OAuth callback error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Authentication failed") from e


@router.post("/refresh", response_model=RefreshResponse)
async def refresh_access_token(
    request_obj: Request,
    response: Response,
    body: RefreshRequest | None = None,
):
    """
    Refresh JWT access token.

    Supports cookie-based refresh (preferred) and body-based (legacy).
    Sets new access_token cookie on success.

    Args:
        request_obj: FastAPI request object
        response: FastAPI response object
        body: Refresh token in body (optional, for backwards compatibility)

    Returns:
        New access token

    Raises:
        HTTPException: 401 if no token is supplied or it fails verification,
            400 if the token is not a refresh token, 500 on unexpected errors
    """
    # Try cookie first, then fall back to the request body.
    refresh_token_value = request_obj.cookies.get(REFRESH_TOKEN_COOKIE)
    if not refresh_token_value and body and body.refresh_token:
        refresh_token_value = body.refresh_token

    if not refresh_token_value:
        raise HTTPException(status_code=401, detail="Missing refresh token")

    try:
        payload = verify_token(refresh_token_value)

        # Only tokens minted with "type": "refresh" may be exchanged.
        if payload.get("type") != "refresh":
            raise HTTPException(status_code=400, detail="Invalid refresh token")

        user_payload = {
            "user_id": payload["user_id"],
            "username": payload["username"],
            "discord_id": payload["discord_id"],
        }
        access_token = create_token(user_payload)

        # NOTE(review): cookie attributes are hard-coded here; confirm they match
        # what set_auth_cookies uses so this cookie replaces the original.
        response.set_cookie(
            key=ACCESS_TOKEN_COOKIE,
            value=access_token,
            max_age=ACCESS_TOKEN_MAX_AGE,
            httponly=True,
            secure=is_production(),
            samesite="lax",
            path="/api",
        )

        logger.info(f"Token refreshed for user {payload['username']}")

        return RefreshResponse(
            access_token=access_token,
            expires_in=ACCESS_TOKEN_MAX_AGE,  # same lifetime as the cookie
        )

    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400 above) unchanged
        # instead of letting the generic handler turn them into a 500.
        raise
    except JWTError:
        logger.warning("Invalid refresh token provided")
        raise HTTPException(status_code=401, detail="Invalid or expired refresh token")
    except Exception as e:
        logger.error(f"Token refresh error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to refresh token") from e


@router.get("/me", response_model=UserInfoResponse)
async def get_current_user_info(
    request: Request, authorization: str = Header(None)
):
    """
    Get current authenticated user information.

    Supports both:
    - Cookie-based auth (access token cookie) - preferred
    - Header-based auth (Authorization: Bearer token) - fallback

    Args:
        request: FastAPI request object
        authorization: Bearer token in Authorization header (optional)

    Returns:
        User information and teams

    Raises:
        HTTPException: 401 if no token is supplied or it fails verification,
            500 on unexpected errors
    """
    # Try cookie first, then fall back to the Authorization header.
    token = request.cookies.get(ACCESS_TOKEN_COOKIE) or _extract_bearer_token(
        authorization
    )

    if not token:
        raise HTTPException(status_code=401, detail="Missing authentication")

    try:
        # NOTE(review): the token "type"/"ws_only" claims are not checked, so
        # refresh and WebSocket-only tokens are accepted here too — confirm intent.
        payload = verify_token(token)

        user = DiscordUser(
            id=payload["discord_id"],
            username=payload["username"],
            discriminator="0",  # Discord removed discriminators
        )

        # TODO: Load user's teams from database
        teams = []

        logger.info(f"User info retrieved for {user.username}")

        return UserInfoResponse(user=user, teams=teams)

    except JWTError:
        logger.warning("Invalid token in /me request")
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    except Exception as e:
        logger.error(f"Get user info error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to get user information") from e


@router.get("/ws-token")
async def get_websocket_token(request: Request):
    """
    Get a short-lived token for WebSocket authentication.

    Safari/iOS sometimes fails to send cookies with WebSocket connections.
    This endpoint provides a token that can be passed as a query parameter.
    The token is created with expires_minutes=5 and a ``ws_only`` claim.

    Returns:
        {"token": "..."} - Short-lived JWT token

    Raises:
        HTTPException: 401 if the access token cookie is missing or invalid,
            500 on unexpected errors
    """
    token = request.cookies.get(ACCESS_TOKEN_COOKIE)

    # Diagnostics only: log cookie names, never token contents.
    logger.debug(f"ws-token request cookies: {list(request.cookies.keys())}")

    if not token:
        raise HTTPException(status_code=401, detail="Missing authentication")

    try:
        payload = verify_token(token)

        ws_token = create_token(
            {
                "user_id": payload["user_id"],
                "discord_id": payload["discord_id"],
                "username": payload["username"],
                "ws_only": True,  # Mark as WebSocket-only token
            },
            expires_minutes=5,
        )

        logger.debug(f"Generated WS token for user {payload['username']}")
        return {"token": ws_token}

    except JWTError:
        logger.warning("Invalid token in /ws-token request")
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    except Exception as e:
        logger.error(f"WS token generation error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to generate token") from e


@router.get("/verify")
async def verify_auth(authorization: str = Header(None)):
    """
    Verify authentication status (Authorization header only)

    Args:
        authorization: Bearer token in Authorization header

    Returns:
        {"authenticated": False} when the token is missing or invalid,
        otherwise {"authenticated": True, "user_id": ..., "username": ...}
    """
    token = _extract_bearer_token(authorization)
    if not token:
        return {"authenticated": False}

    try:
        payload = verify_token(token)
        return {
            "authenticated": True,
            "user_id": payload["user_id"],
            "username": payload["username"],
        }
    except JWTError:
        return {"authenticated": False}


@router.get("/logout")
@router.post("/logout")
async def logout(response: Response) -> dict:
    """
    Clear auth cookies (logout).

    Args:
        response: FastAPI response object

    Returns:
        Success message
    """
    clear_auth_cookies(response)
    logger.info("User logged out, cookies cleared")
    return {"message": "Logged out successfully"}


# ============================================================================
# Testing Endpoints (Development Only)
# ============================================================================


class TestTokenRequest(BaseModel):
    """Request model for test token creation"""

    user_id: str
    username: str
    discord_id: str


@router.post("/token", response_model=AuthResponse)
async def create_test_token(request: TestTokenRequest):
    """
    Create test JWT token without OAuth (for development/testing)

    Disabled (404) when is_production() is true, because it issues valid
    tokens without any proof of identity. It bypasses Discord OAuth but still
    respects the whitelist.

    Args:
        request: Test user data

    Returns:
        JWT tokens and mock user information

    Raises:
        HTTPException: 404 in production, 403 if the Discord ID is not whitelisted
    """
    if is_production():
        raise HTTPException(status_code=404, detail="Not Found")

    # Still check whitelist for test tokens
    if not is_discord_id_allowed(request.discord_id):
        raise HTTPException(
            status_code=403,
            detail="Access denied. This Discord ID is not authorized.",
        )

    access_token, refresh_token = _create_token_pair(
        {
            "user_id": request.user_id,
            "username": request.username,
            "discord_id": request.discord_id,
        }
    )

    logger.info(f"Test token created for {request.username} (discord_id: {request.discord_id})")

    mock_user = DiscordUser(
        id=request.discord_id,
        username=request.username,
        discriminator="0001",
        avatar=None,
        email=None,
    )

    return AuthResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=604800,  # 7 days (hard-coded; see legacy callback)
        user=mock_user,
    )