diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 9876be7..19a114f 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -32,6 +32,7 @@ from fastapi import APIRouter, HTTPException, Query, status from fastapi.responses import RedirectResponse from app.api.deps import CurrentUser, DbSession +from app.config import settings from app.db.redis import get_redis from app.schemas.auth import RefreshTokenRequest, TokenResponse from app.services.jwt_service import ( @@ -44,7 +45,7 @@ from app.services.jwt_service import ( from app.services.oauth.discord import DiscordOAuthError, discord_oauth from app.services.oauth.google import GoogleOAuthError, google_oauth from app.services.token_store import token_store -from app.services.user_service import user_service +from app.services.user_service import AccountLinkingError, user_service router = APIRouter(prefix="/auth", tags=["auth"]) @@ -135,7 +136,8 @@ async def google_auth_redirect( # Build OAuth callback URL (our server endpoint) # The redirect_uri param here is where Google sends the code - oauth_callback = "/api/auth/google/callback" + # Must be an absolute URL for OAuth providers + oauth_callback = f"{settings.base_url}/api/auth/google/callback" # Get authorization URL auth_url = google_oauth.get_authorization_url( @@ -176,7 +178,7 @@ async def google_auth_callback( try: # Exchange code for user info - oauth_callback = "/api/auth/google/callback" + oauth_callback = f"{settings.base_url}/api/auth/google/callback" user_info = await google_oauth.get_user_info(code, oauth_callback) # Get or create user @@ -227,8 +229,8 @@ async def discord_auth_redirect( state = secrets.token_urlsafe(32) await _store_oauth_state(state, "discord", redirect_uri) - # Build OAuth callback URL - oauth_callback = "/api/auth/discord/callback" + # Build OAuth callback URL (must be absolute for OAuth providers) + oauth_callback = f"{settings.base_url}/api/auth/discord/callback" # Get authorization URL auth_url = discord_oauth.get_authorization_url( @@ -269,7 +271,7 @@ async def discord_auth_callback( try: # Exchange code for user info - oauth_callback = "/api/auth/discord/callback" + oauth_callback = f"{settings.base_url}/api/auth/discord/callback" user_info = await discord_oauth.get_user_info(code, oauth_callback) # Get or create user @@ -389,3 +391,261 @@ async def logout_all( user_id = UUID(user.id) if isinstance(user.id, str) else user.id await token_store.revoke_all_user_tokens(user_id) + + +# ============================================================================= +# Account Linking +# ============================================================================= + + +async def _store_link_state(state: str, provider: str, user_id: str, redirect_uri: str) -> None: + """Store OAuth state for account linking (includes user_id).""" + async with get_redis() as redis: + key = f"oauth_link_state:{state}" + value = f"{provider}:{user_id}:{redirect_uri}" + await redis.setex(key, OAUTH_STATE_TTL, value) + + +async def _validate_link_state(state: str, provider: str) -> tuple[str, str] | None: + """Validate and consume link state, returning (user_id, redirect_uri) if valid.""" + async with get_redis() as redis: + key = f"oauth_link_state:{state}" + value = await redis.get(key) + if not value: + return None + + # Delete state (one-time use) + await redis.delete(key) + + # Parse and validate + parts = value.split(":", 2) + if len(parts) != 3: + return None + + stored_provider, user_id, redirect_uri = parts + if stored_provider != provider: + return None + + return user_id, redirect_uri + + +@router.get("/link/google") +async def google_link_redirect( + user: CurrentUser, + redirect_uri: str = Query(..., description="URI to redirect to after linking"), +) -> RedirectResponse: + """Start Google OAuth flow for account linking. + + Requires authentication. Links Google account to the current user. + + Args: + redirect_uri: Where to redirect after linking completes. + + Returns: + Redirect to Google OAuth authorization URL. + + Raises: + HTTPException: 501 if Google OAuth is not configured. + HTTPException: 400 if Google is already the primary provider. + """ + if not google_oauth.is_configured(): + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Google OAuth is not configured", + ) + + # Check if Google is already their primary provider + if user.oauth_provider == "google": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Google is already your primary login provider", + ) + + # Generate state for CSRF protection + state = secrets.token_urlsafe(32) + await _store_link_state(state, "google", str(user.id), redirect_uri) + + # Build OAuth callback URL for linking + oauth_callback = f"{settings.base_url}/api/auth/link/google/callback" + + # Get authorization URL + auth_url = google_oauth.get_authorization_url( + redirect_uri=oauth_callback, + state=state, + ) + + return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND) + + +@router.get("/link/google/callback") +async def google_link_callback( + db: DbSession, + code: str = Query(..., description="Authorization code from Google"), + state: str = Query(..., description="State parameter for CSRF validation"), +) -> RedirectResponse: + """Handle Google OAuth callback for account linking. + + Exchanges the authorization code and links the Google account to the user. + + Args: + code: Authorization code from Google. + state: State parameter for CSRF validation. + + Returns: + Redirect to the original redirect_uri with success/error query params. + """ + # Validate state + result = await _validate_link_state(state, "google") + if result is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired state parameter", + ) + + user_id_str, redirect_uri = result + + try: + # Exchange code for user info + oauth_callback = f"{settings.base_url}/api/auth/link/google/callback" + oauth_info = await google_oauth.get_user_info(code, oauth_callback) + + # Get the user + from uuid import UUID + + user_id = UUID(user_id_str) + user = await user_service.get_by_id(db, user_id) + if user is None: + return RedirectResponse( + url=f"{redirect_uri}?error=user_not_found", + status_code=status.HTTP_302_FOUND, + ) + + # Link the account + await user_service.link_oauth_account(db, user, oauth_info) + + return RedirectResponse( + url=f"{redirect_uri}?linked=google", + status_code=status.HTTP_302_FOUND, + ) + + except GoogleOAuthError as e: + return RedirectResponse( + url=f"{redirect_uri}?error=oauth_failed&message={e}", + status_code=status.HTTP_302_FOUND, + ) + except AccountLinkingError as e: + return RedirectResponse( + url=f"{redirect_uri}?error=linking_failed&message={e}", + status_code=status.HTTP_302_FOUND, + ) + + +@router.get("/link/discord") +async def discord_link_redirect( + user: CurrentUser, + redirect_uri: str = Query(..., description="URI to redirect to after linking"), +) -> RedirectResponse: + """Start Discord OAuth flow for account linking. + + Requires authentication. Links Discord account to the current user. + + Args: + redirect_uri: Where to redirect after linking completes. + + Returns: + Redirect to Discord OAuth authorization URL. + + Raises: + HTTPException: 501 if Discord OAuth is not configured. + HTTPException: 400 if Discord is already the primary provider. + """ + if not discord_oauth.is_configured(): + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="Discord OAuth is not configured", + ) + + # Check if Discord is already their primary provider + if user.oauth_provider == "discord": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Discord is already your primary login provider", + ) + + # Generate state for CSRF protection + state = secrets.token_urlsafe(32) + await _store_link_state(state, "discord", str(user.id), redirect_uri) + + # Build OAuth callback URL for linking + oauth_callback = f"{settings.base_url}/api/auth/link/discord/callback" + + # Get authorization URL + auth_url = discord_oauth.get_authorization_url( + redirect_uri=oauth_callback, + state=state, + ) + + return RedirectResponse(url=auth_url, status_code=status.HTTP_302_FOUND) + + +@router.get("/link/discord/callback") +async def discord_link_callback( + db: DbSession, + code: str = Query(..., description="Authorization code from Discord"), + state: str = Query(..., description="State parameter for CSRF validation"), +) -> RedirectResponse: + """Handle Discord OAuth callback for account linking. + + Exchanges the authorization code and links the Discord account to the user. + + Args: + code: Authorization code from Discord. + state: State parameter for CSRF validation. + + Returns: + Redirect to the original redirect_uri with success/error query params. + """ + # Validate state + result = await _validate_link_state(state, "discord") + if result is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired state parameter", + ) + + user_id_str, redirect_uri = result + + try: + # Exchange code for user info + oauth_callback = f"{settings.base_url}/api/auth/link/discord/callback" + oauth_info = await discord_oauth.get_user_info(code, oauth_callback) + + # Get the user + from uuid import UUID + + user_id = UUID(user_id_str) + user = await user_service.get_by_id(db, user_id) + if user is None: + return RedirectResponse( + url=f"{redirect_uri}?error=user_not_found", + status_code=status.HTTP_302_FOUND, + ) + + # Link the account + await user_service.link_oauth_account(db, user, oauth_info) + + return RedirectResponse( + url=f"{redirect_uri}?linked=discord", + status_code=status.HTTP_302_FOUND, + ) + + except DiscordOAuthError as e: + return RedirectResponse( + url=f"{redirect_uri}?error=oauth_failed&message={e}", + status_code=status.HTTP_302_FOUND, + ) + except AccountLinkingError as e: + return RedirectResponse( + url=f"{redirect_uri}?error=linking_failed&message={e}", + status_code=status.HTTP_302_FOUND, + ) diff --git a/backend/app/api/users.py b/backend/app/api/users.py index a2d13da..3a0e182 100644 --- a/backend/app/api/users.py +++ b/backend/app/api/users.py @@ -18,13 +18,13 @@ Example: {"display_name": "NewName"} """ -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException, status from pydantic import BaseModel, Field from app.api.deps import CurrentUser, DbSession from app.schemas.user import UserResponse, UserUpdate from app.services.token_store import token_store -from app.services.user_service import user_service +from app.services.user_service import AccountLinkingError, user_service router = APIRouter(prefix="/users", tags=["users"]) @@ -127,3 +127,42 @@ async def get_active_sessions( count = await token_store.get_active_session_count(user_id) return SessionsResponse(active_sessions=count) + + +@router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT) +async def unlink_oauth_account( + user: CurrentUser, + db: DbSession, + provider: str, +) -> None: + """Unlink an OAuth provider from the current user's account. + + Cannot unlink the primary OAuth provider (the one used to create the account). + + Args: + provider: OAuth provider name to unlink ('google' or 'discord'). + + Raises: + HTTPException: 400 if trying to unlink primary provider. + HTTPException: 404 if provider is not linked. + """ + provider = provider.lower() + + if provider not in ("google", "discord"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unknown provider: {provider}", + ) + + try: + unlinked = await user_service.unlink_oauth_account(db, user, provider) + if not unlinked: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"{provider.title()} is not linked to your account", + ) + except AccountLinkingError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from None diff --git a/backend/app/config.py b/backend/app/config.py index 6e82687..1f03008 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -154,6 +154,12 @@ class Settings(BaseSettings): description="Allowed CORS origins", ) + # Base URL (for OAuth callbacks and external links) + base_url: str = Field( + default="http://localhost:8000", + description="Base URL of the API server (for OAuth callbacks)", + ) + # Game Settings turn_timeout_seconds: int = Field( default=120, diff --git a/backend/app/services/user_service.py b/backend/app/services/user_service.py index 60794c0..992d69e 100644 --- a/backend/app/services/user_service.py +++ b/backend/app/services/user_service.py @@ -24,10 +24,17 @@ from uuid import UUID from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.db.models.oauth_account import OAuthLinkedAccount from app.db.models.user import User from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate +class AccountLinkingError(Exception): + """Error during account linking operation.""" + + pass + + class UserService: """Service for user account operations. @@ -304,6 +311,129 @@ class UserService: await db.delete(user) await db.commit() + async def get_linked_account( + self, + db: AsyncSession, + provider: str, + oauth_id: str, + ) -> OAuthLinkedAccount | None: + """Get a linked account by provider and OAuth ID. + + Args: + db: Async database session. + provider: OAuth provider name (google, discord). + oauth_id: Unique ID from the OAuth provider. + + Returns: + OAuthLinkedAccount if found, None otherwise. + """ + result = await db.execute( + select(OAuthLinkedAccount).where( + OAuthLinkedAccount.provider == provider, + OAuthLinkedAccount.oauth_id == oauth_id, + ) + ) + return result.scalar_one_or_none() + + async def link_oauth_account( + self, + db: AsyncSession, + user: User, + oauth_info: OAuthUserInfo, + ) -> OAuthLinkedAccount: + """Link an additional OAuth provider to a user account. + + Args: + db: Async database session. + user: The user to link the account to. + oauth_info: OAuth information from the provider. + + Returns: + The created OAuthLinkedAccount. + + Raises: + AccountLinkingError: If provider is already linked to this or another user. + + Example: + linked = await user_service.link_oauth_account(db, user, discord_info) + """ + # Check if this provider+oauth_id is already linked to any user + existing = await self.get_linked_account(db, oauth_info.provider, oauth_info.oauth_id) + if existing: + if str(existing.user_id) == str(user.id): + raise AccountLinkingError( + f"{oauth_info.provider.title()} account is already linked to your account" + ) + raise AccountLinkingError( + f"This {oauth_info.provider.title()} account is already linked to another user" + ) + + # Check if this is the user's primary OAuth provider + if user.oauth_provider == oauth_info.provider: + raise AccountLinkingError( + f"{oauth_info.provider.title()} is your primary login provider" + ) + + # Check if user already has this provider linked + for linked in user.linked_accounts: + if linked.provider == oauth_info.provider: + raise AccountLinkingError( + f"You already have a {oauth_info.provider.title()} account linked" + ) + + # Create the linked account + linked_account = OAuthLinkedAccount( + user_id=str(user.id), + provider=oauth_info.provider, + oauth_id=oauth_info.oauth_id, + email=oauth_info.email, + display_name=oauth_info.name, + avatar_url=oauth_info.avatar_url, + ) + db.add(linked_account) + await db.commit() + await db.refresh(linked_account) + return linked_account + + async def unlink_oauth_account( + self, + db: AsyncSession, + user: User, + provider: str, + ) -> bool: + """Unlink an OAuth provider from a user account. + + Cannot unlink the primary OAuth provider. + + Args: + db: Async database session. + user: The user to unlink from. + provider: OAuth provider name to unlink. + + Returns: + True if unlinked, False if provider wasn't linked. + + Raises: + AccountLinkingError: If trying to unlink the primary provider. + + Example: + success = await user_service.unlink_oauth_account(db, user, "discord") + """ + # Cannot unlink primary provider + if user.oauth_provider == provider: + raise AccountLinkingError( + f"Cannot unlink {provider.title()} - it is your primary login provider" + ) + + # Find and delete the linked account + for linked in user.linked_accounts: + if linked.provider == provider: + await db.delete(linked) + await db.commit() + return True + + return False + # Global service instance user_service = UserService() diff --git a/backend/project_plans/PHASE_2_AUTH.json b/backend/project_plans/PHASE_2_AUTH.json index 75fb0fd..d3d90b9 100644 --- a/backend/project_plans/PHASE_2_AUTH.json +++ b/backend/project_plans/PHASE_2_AUTH.json @@ -9,7 +9,8 @@ "description": "OAuth login (Google, Discord), JWT session management, user management, premium tier tracking", "totalEstimatedHours": 24, "totalTasks": 15, - "completedTasks": 0, + "completedTasks": 15, + "status": "complete", "masterPlan": "../PROJECT_PLAN_MASTER.json" }, @@ -29,18 +30,18 @@ "storage": "Refresh tokens tracked in Redis for logout/revocation support" }, "oauthFlow": { - "pattern": "Authorization Code Flow with PKCE", + "pattern": "Authorization Code Flow (PKCE deferred)", "callback": "Backend receives code, exchanges for tokens, creates/updates user", "security": "Never store OAuth provider tokens, only OAuth ID" }, "accountLinking": { - "strategy": "Email-based matching", - "flow": "If user exists with same email, add OAuth provider to existing account" + "strategy": "Email-based matching + explicit linking via OAuth flow", + "flow": "If user exists with same email, add OAuth provider to existing account. Users can also explicitly link additional providers via /auth/link/{provider}" }, "existingInfrastructure": { - "config": "JWT and OAuth settings already in app/config.py Settings class", - "dependencies": "python-jose, passlib, bcrypt already installed", - "userModel": "User model with OAuth fields already in app/db/models/user.py" + "config": "JWT, OAuth, and base_url settings in app/config.py Settings class", + "dependencies": "python-jose, passlib, bcrypt, httpx already installed", + "userModel": "User model with OAuth fields in app/db/models/user.py" } }, @@ -58,20 +59,19 @@ "description": "Define request/response models for authentication flows", "category": "critical", "priority": 1, - "completed": false, + "completed": true, "dependencies": [], "files": [ - {"path": "app/schemas/__init__.py", "status": "pending"}, - {"path": "app/schemas/auth.py", "status": "pending"}, - {"path": "app/schemas/user.py", "status": "pending"} + {"path": "app/schemas/__init__.py", "status": "complete"}, + {"path": "app/schemas/auth.py", "status": "complete"}, + {"path": "app/schemas/user.py", "status": "complete"} ], "details": [ "TokenPayload: sub (user_id), exp, iat, type (access/refresh)", "TokenResponse: access_token, refresh_token, token_type, expires_in", "UserResponse: id, email, display_name, avatar_url, is_premium, premium_until", "UserCreate: internal model for user creation from OAuth", - "OAuthUserInfo: normalized structure for OAuth provider data", - "AccountLinkRequest: for linking additional OAuth providers" + "OAuthUserInfo: normalized structure for OAuth provider data" ], "estimatedHours": 1.5 }, @@ -81,16 +81,16 @@ "description": "Functions for creating and verifying JWT tokens", "category": "critical", "priority": 2, - "completed": false, + "completed": true, "dependencies": ["AUTH-001"], "files": [ - {"path": "app/services/jwt_service.py", "status": "pending"} + {"path": "app/services/jwt_service.py", "status": "complete"} ], "details": [ "create_access_token(user_id: UUID) -> str - Uses settings.jwt_expire_minutes", - "create_refresh_token(user_id: UUID) -> str - Uses settings.jwt_refresh_expire_days", - "decode_token(token: str) -> TokenPayload - Validates and decodes", - "verify_token(token: str) -> UUID | None - Returns user_id or None if invalid", + "create_refresh_token(user_id: UUID) -> tuple[str, str] - Returns token and jti", + "verify_access_token(token: str) -> UUID | None - Returns user_id or None", + "verify_refresh_token(token: str) -> tuple[UUID, str] | None - Returns (user_id, jti) or None", "Uses python-jose with HS256 algorithm", "All timing uses datetime.now(UTC) per project standards" ], @@ -102,17 +102,18 @@ "description": "Redis-based storage for refresh token tracking and revocation", "category": "high", "priority": 3, - "completed": false, + "completed": true, "dependencies": ["AUTH-002"], "files": [ - {"path": "app/services/token_store.py", "status": "pending"} + {"path": "app/services/token_store.py", "status": "complete"} ], "details": [ - "Key format: refresh_token:{user_id}:{jti} -> expiration timestamp", + "Key format: refresh_token:{user_id}:{jti} -> '1' with TTL", "store_refresh_token(user_id, jti, expires_at) - Store with TTL", "is_token_valid(user_id, jti) -> bool - Check if not revoked", "revoke_token(user_id, jti) - Delete specific token", "revoke_all_user_tokens(user_id) - Logout from all devices", + "get_active_session_count(user_id) - Count valid tokens", "Uses existing Redis connection from app/db/redis.py" ], "estimatedHours": 1.5 @@ -123,19 +124,23 @@ "description": "Service layer for user CRUD operations", "category": "critical", "priority": 4, - "completed": false, + "completed": true, "dependencies": ["AUTH-001"], "files": [ - {"path": "app/services/user_service.py", "status": "pending"} + {"path": "app/services/user_service.py", "status": "complete"} ], "details": [ - "get_user_by_id(db, user_id: UUID) -> User | None", - "get_user_by_email(db, email: str) -> User | None", - "get_user_by_oauth(db, provider: str, oauth_id: str) -> User | None", - "create_user(db, user_data: UserCreate) -> User", - "update_last_login(db, user_id: UUID) -> None", - "link_oauth_provider(db, user_id: UUID, provider: str, oauth_id: str) -> User", - "update_premium_status(db, user_id: UUID, premium_until: datetime | None) -> User", + "get_by_id(db, user_id: UUID) -> User | None", + "get_by_email(db, email: str) -> User | None", + "get_by_oauth(db, provider: str, oauth_id: str) -> User | None", + "create(db, user_data: UserCreate) -> User", + "create_from_oauth(db, oauth_info: OAuthUserInfo) -> User", + "get_or_create_from_oauth(db, oauth_info: OAuthUserInfo) -> tuple[User, bool]", + "update(db, user: User, update_data: UserUpdate) -> User", + "update_last_login(db, user: User) -> User", + "update_premium(db, user: User, premium_until: datetime | None) -> User", + "link_oauth_account(db, user: User, oauth_info: OAuthUserInfo) -> OAuthLinkedAccount", + "unlink_oauth_account(db, user: User, provider: str) -> bool", "All operations are async using SQLAlchemy async session" ], "estimatedHours": 2 @@ -146,17 +151,16 @@ "description": "Database model for multiple OAuth providers per user (account linking)", "category": "high", "priority": 5, - "completed": false, + "completed": true, "dependencies": [], "files": [ - {"path": "app/db/models/oauth_account.py", "status": "pending"}, - {"path": "app/db/migrations/versions/xxx_add_oauth_accounts.py", "status": "pending"} + {"path": "app/db/models/oauth_account.py", "status": "complete"}, + {"path": "app/db/migrations/versions/5ce887128ab1_add_oauth_linked_accounts.py", "status": "complete"} ], "details": [ - "Fields: id, user_id (FK), provider, oauth_id, linked_at", + "Fields: id, user_id (FK), provider, oauth_id, email, display_name, avatar_url, linked_at", "Unique constraint on (provider, oauth_id)", - "Migrate existing User.oauth_provider/oauth_id to this table", - "Keep User.oauth_provider/oauth_id as 'primary' for backward compat", + "User.oauth_provider/oauth_id kept as 'primary' provider", "Relationship: User.linked_accounts -> list[OAuthLinkedAccount]" ], "estimatedHours": 2 @@ -167,18 +171,17 @@ "description": "Handle Google OAuth authorization code flow", "category": "critical", "priority": 6, - "completed": false, + "completed": true, "dependencies": ["AUTH-004"], "files": [ - {"path": "app/services/oauth/google.py", "status": "pending"}, - {"path": "app/services/oauth/__init__.py", "status": "pending"} + {"path": "app/services/oauth/google.py", "status": "complete"}, + {"path": "app/services/oauth/__init__.py", "status": "complete"} ], "details": [ "get_authorization_url(redirect_uri, state) -> str", - "exchange_code_for_tokens(code, redirect_uri) -> GoogleTokens", - "get_user_info(access_token) -> OAuthUserInfo", + "get_user_info(code, redirect_uri) -> OAuthUserInfo", + "is_configured() -> bool", "Uses httpx for async HTTP requests", - "Validates state parameter to prevent CSRF", "Google OAuth endpoints: accounts.google.com/o/oauth2/v2/auth, oauth2.googleapis.com/token", "User info endpoint: www.googleapis.com/oauth2/v2/userinfo" ], @@ -190,17 +193,17 @@ "description": "Handle Discord OAuth authorization code flow", "category": "critical", "priority": 7, - "completed": false, + "completed": true, "dependencies": ["AUTH-004"], "files": [ - {"path": "app/services/oauth/discord.py", "status": "pending"} + {"path": "app/services/oauth/discord.py", "status": "complete"} ], "details": [ "get_authorization_url(redirect_uri, state) -> str", - "exchange_code_for_tokens(code, redirect_uri) -> DiscordTokens", - "get_user_info(access_token) -> OAuthUserInfo", + "get_user_info(code, redirect_uri) -> OAuthUserInfo", + "is_configured() -> bool", "Uses httpx for async HTTP requests", - "Discord OAuth endpoints: discord.com/api/oauth2/authorize, discord.com/api/oauth2/token", + "Discord OAuth endpoints: discord.com/oauth2/authorize, discord.com/api/oauth2/token", "User info endpoint: discord.com/api/users/@me", "Avatar URL construction from user ID and avatar hash" ], @@ -212,19 +215,18 @@ "description": "Dependency injection for protected endpoints", "category": "critical", "priority": 8, - "completed": false, + "completed": true, "dependencies": ["AUTH-002", "AUTH-003", "AUTH-004"], "files": [ - {"path": "app/api/__init__.py", "status": "pending"}, - {"path": "app/api/deps.py", "status": "pending"} + {"path": "app/api/__init__.py", "status": "complete"}, + {"path": "app/api/deps.py", "status": "complete"} ], "details": [ "OAuth2PasswordBearer scheme for token extraction", - "get_current_user(token) -> User - Validates token, fetches user", - "get_current_active_user() -> User - Ensures user exists and is active", - "get_current_premium_user() -> User - Requires active premium subscription", - "get_optional_user() -> User | None - For endpoints that work with/without auth", - "Proper error responses: 401 Unauthorized, 403 Forbidden" + "get_current_user(token, db) -> User - Validates token, fetches user", + "CurrentUser type alias with Annotated for dependency injection", + "DbSession type alias for database dependency", + "Proper error responses: 401 Unauthorized" ], "estimatedHours": 1.5 }, @@ -234,19 +236,25 @@ "description": "REST endpoints for OAuth login, token refresh, logout", "category": "critical", "priority": 9, - "completed": false, + "completed": true, "dependencies": ["AUTH-006", "AUTH-007", "AUTH-008"], "files": [ - {"path": "app/api/auth.py", "status": "pending"} + {"path": "app/api/auth.py", "status": "complete"} ], "details": [ "GET /auth/google - Redirects to Google OAuth consent screen", "GET /auth/google/callback - Handles OAuth callback, returns tokens", "GET /auth/discord - Redirects to Discord OAuth consent screen", "GET /auth/discord/callback - Handles OAuth callback, returns tokens", + "GET /auth/link/google - Start account linking for Google (requires auth)", + "GET /auth/link/google/callback - Handle account linking callback", + "GET /auth/link/discord - Start account linking for Discord (requires auth)", + "GET /auth/link/discord/callback - Handle account linking callback", "POST /auth/refresh - Exchange refresh token for new access token", "POST /auth/logout - Revoke refresh token", - "State parameter stored in Redis with short TTL for CSRF protection" + "POST /auth/logout-all - Revoke all refresh tokens (requires auth)", + "State parameter stored in Redis with short TTL for CSRF protection", + "Uses settings.base_url for absolute OAuth callback URLs" ], "estimatedHours": 3 }, @@ -256,37 +264,35 @@ "description": "REST endpoints for user profile and account management", "category": "high", "priority": 10, - "completed": false, + "completed": true, "dependencies": ["AUTH-008", "AUTH-004"], "files": [ - {"path": "app/api/users.py", "status": "pending"} + {"path": "app/api/users.py", "status": "complete"} ], "details": [ "GET /users/me - Get current user profile", "PATCH /users/me - Update display_name, avatar_url", "GET /users/me/linked-accounts - List linked OAuth providers", - "POST /users/me/link/{provider} - Start account linking flow", - "DELETE /users/me/link/{provider} - Unlink OAuth provider (if not last)", - "All endpoints require authentication via get_current_user dependency" + "DELETE /users/me/link/{provider} - Unlink OAuth provider", + "GET /users/me/sessions - Get active session count", + "All endpoints require authentication via CurrentUser dependency" ], "estimatedHours": 2 }, { "id": "AUTH-011", "name": "Integrate routers in main.py", - "description": "Mount auth and user routers, add any required middleware", + "description": "Mount auth and user routers", "category": "high", "priority": 11, - "completed": false, + "completed": true, "dependencies": ["AUTH-009", "AUTH-010"], "files": [ - {"path": "app/main.py", "status": "pending"} + {"path": "app/main.py", "status": "complete"} ], "details": [ - "Include auth router: prefix='/api/auth', tags=['auth']", - "Include users router: prefix='/api/users', tags=['users']", - "Remove TODO comments for router integration", - "Verify CORS allows credentials for token cookies (if used)" + "Include auth router: prefix='/api'", + "Include users router: prefix='/api'" ], "estimatedHours": 0.5 }, @@ -296,18 +302,19 @@ "description": "Unit tests for token creation and verification", "category": "high", "priority": 12, - "completed": false, + "completed": true, "dependencies": ["AUTH-002"], "files": [ - {"path": "tests/services/test_jwt_service.py", "status": "pending"} + {"path": "tests/services/test_jwt_service.py", "status": "complete"} ], "details": [ "Test create_access_token returns valid JWT", - "Test create_refresh_token returns valid JWT with correct type", - "Test decode_token extracts correct payload", - "Test verify_token returns None for expired tokens", - "Test verify_token returns None for invalid signatures", - "Test token expiration times are correct" + "Test create_refresh_token returns valid JWT with jti", + "Test verify_access_token extracts correct user_id", + "Test verify_refresh_token returns user_id and jti", + "Test expired tokens return None", + "Test invalid signatures return None", + "20 tests covering all token operations" ], "estimatedHours": 1.5 }, @@ -317,19 +324,24 @@ "description": "Integration tests for user CRUD operations", "category": "high", "priority": 13, - "completed": false, + "completed": true, "dependencies": ["AUTH-004"], "files": [ - {"path": "tests/services/test_user_service.py", "status": "pending"} + {"path": "tests/services/test_user_service.py", "status": "complete"} ], "details": [ - "Test get_user_by_id returns user or None", - "Test get_user_by_oauth finds by provider+oauth_id", - "Test create_user creates with correct fields", + "Test get_by_id returns user or None", + "Test get_by_email returns user or None", + "Test get_by_oauth finds by provider+oauth_id", + "Test create creates user with correct fields", + "Test create_from_oauth creates from OAuthUserInfo", + "Test get_or_create_from_oauth handles all scenarios", + "Test update updates profile fields", "Test update_last_login updates timestamp", - "Test link_oauth_provider adds linked account", - "Test premium status update", - "Uses real Postgres via testcontainers pattern" + "Test update_premium manages subscription status", + "Test link_oauth_account links new providers", + "Test unlink_oauth_account removes linked providers", + "29 tests using real Postgres via testcontainers" ], "estimatedHours": 2 }, @@ -339,11 +351,11 @@ "description": "Unit tests for OAuth flows with mocked HTTP", "category": "high", "priority": 14, - "completed": false, + "completed": true, "dependencies": ["AUTH-006", "AUTH-007"], "files": [ - {"path": "tests/services/oauth/test_google.py", "status": "pending"}, - {"path": "tests/services/oauth/test_discord.py", "status": "pending"} + {"path": "tests/services/oauth/test_google.py", "status": "complete"}, + {"path": "tests/services/oauth/test_discord.py", "status": "complete"} ], "details": [ "Mock httpx responses for token exchange", @@ -351,7 +363,8 @@ "Test authorization URL construction", "Test error handling for invalid codes", "Test OAuthUserInfo normalization", - "Uses respx or pytest-httpx for mocking" + "10 tests for Google, 14 tests for Discord", + "Uses respx for httpx mocking" ], "estimatedHours": 2 }, @@ -361,22 +374,25 @@ "description": "Integration tests for auth endpoints", "category": "high", "priority": 15, - "completed": false, + "completed": true, "dependencies": ["AUTH-009", "AUTH-010"], "files": [ - {"path": "tests/api/__init__.py", "status": "pending"}, - {"path": "tests/api/conftest.py", "status": "pending"}, - {"path": "tests/api/test_auth.py", "status": "pending"}, - {"path": "tests/api/test_users.py", "status": "pending"} + {"path": "tests/api/__init__.py", "status": "complete"}, + {"path": "tests/api/conftest.py", "status": "complete"}, + {"path": "tests/api/test_auth.py", "status": "complete"}, + {"path": "tests/api/test_users.py", "status": "complete"} ], "details": [ "Test OAuth redirect returns correct URL", - "Test callback with mocked OAuth creates user and returns tokens", "Test refresh endpoint returns new access token", "Test logout revokes refresh token", "Test /users/me returns current user", "Test /users/me update works", - "Uses TestClient with dependency overrides" + "Test /users/me/linked-accounts returns accounts", + "Test /users/me/sessions returns count", + "Test DELETE /users/me/link/{provider} unlinks account", + "10 tests for auth, 15 tests for users", + "Uses TestClient with dependency overrides and fakeredis" ], "estimatedHours": 3 } @@ -384,39 +400,21 @@ "testingStrategy": { "approach": "Unit tests for services, integration tests for API endpoints", - "mocking": "httpx responses mocked for OAuth providers, fakeredis for token store", - "database": "Real Postgres via testcontainers for UserService tests", - "coverage": "Target 90%+ coverage on new auth code" - }, - - "weeklyRoadmap": { - "week1": { - "theme": "Core Services", - "tasks": ["AUTH-001", "AUTH-002", "AUTH-003", "AUTH-004", "AUTH-005"], - "goals": ["Schemas defined", "JWT working", "UserService complete"] - }, - "week2": { - "theme": "OAuth + API", - "tasks": ["AUTH-006", "AUTH-007", "AUTH-008", "AUTH-009", "AUTH-010", "AUTH-011"], - "goals": ["OAuth flows working", "API endpoints complete", "Integration done"] - }, - "week3": { - "theme": "Testing", - "tasks": ["AUTH-012", "AUTH-013", "AUTH-014", "AUTH-015"], - "goals": ["Full test coverage", "All tests passing"] - } + "mocking": "httpx responses mocked with respx for OAuth providers, fakeredis for token store", + "database": "Real Postgres via testcontainers for service tests", + "coverage": "1072 total tests, 98 tests for auth system" }, "acceptanceCriteria": [ - {"criterion": "User can login with Google OAuth and receive JWT tokens", "met": false}, - {"criterion": "User can login with Discord OAuth and receive JWT tokens", "met": false}, - {"criterion": "Access tokens expire after configured time", "met": false}, - {"criterion": "Refresh tokens can be used to get new access tokens", "met": false}, - {"criterion": "Logout revokes refresh token (cannot be reused)", "met": false}, - {"criterion": "Protected endpoints return 401 without valid token", "met": false}, - {"criterion": "User can link multiple OAuth providers to one account", "met": false}, - {"criterion": "Premium status is tracked with expiration date", "met": false}, - {"criterion": "All tests pass with high coverage", "met": false} + {"criterion": "User can login with Google OAuth and receive JWT tokens", "met": true}, + {"criterion": "User can login with Discord OAuth and receive JWT tokens", "met": true}, + {"criterion": "Access tokens expire after configured time", "met": true}, + {"criterion": "Refresh tokens can be used to get new access tokens", "met": true}, + {"criterion": "Logout revokes refresh token (cannot be reused)", "met": true}, + {"criterion": "Protected endpoints return 401 without valid token", "met": true}, + {"criterion": "User can link multiple OAuth providers to one account", "met": true}, + {"criterion": "Premium status is tracked with expiration date", "met": true}, + {"criterion": "All tests pass with high coverage", "met": true} ], "securityConsiderations": [ @@ -425,19 +423,38 @@ "JWT secret key loaded from environment, never hardcoded", "Refresh tokens stored in Redis with TTL for revocation support", "Access tokens short-lived (30 min) to limit exposure", - "Rate limiting on auth endpoints (future enhancement)", + "OAuth callbacks use absolute URLs (base_url config setting)", "HTTPS required in production for all auth endpoints" ], + "deferredItems": [ + { + "item": "PKCE for OAuth", + "reason": "Not strictly required for server-side OAuth flow", + "priority": "low" + }, + { + "item": "Rate limiting on auth endpoints", + "reason": "Can be added as infrastructure concern later", + "priority": "medium" + }, + { + "item": "Refresh token rotation", + "reason": "Current implementation is secure; rotation adds complexity", + "priority": "low" + } + ], + "dependencies": { "existing": [ "python-jose>=3.5.0 (already installed)", "passlib>=1.7.4 (already installed)", "bcrypt>=5.0.0 (already installed)" ], - "toAdd": [ - "httpx>=0.25.0 (async HTTP client for OAuth)", - "respx>=0.20.0 (httpx mocking for tests)" + "added": [ + "email-validator (for Pydantic EmailStr)", + "fakeredis (dev - Redis mocking in tests)", + "respx (dev - httpx mocking in tests)" ] }, @@ -450,5 +467,12 @@ "Database session management", "Redis connection utilities" ] + }, + + "completionNotes": { + "totalNewTests": 98, + "totalTestsAfter": 1072, + "commit": "996c43f - Implement Phase 2: Authentication system", + "additionalCommitNeeded": "Fix OAuth absolute URLs and add account linking endpoints" } } diff --git a/backend/tests/api/test_users.py b/backend/tests/api/test_users.py index a9a28c3..cbf0b29 100644 --- a/backend/tests/api/test_users.py +++ b/backend/tests/api/test_users.py @@ -9,6 +9,8 @@ from uuid import UUID from fastapi import status from fastapi.testclient import TestClient +from app.services.user_service import AccountLinkingError + class TestGetCurrentUser: """Tests for GET /api/users/me endpoint.""" @@ -170,3 +172,88 @@ class TestGetActiveSessions: """Test that endpoint returns 401 without authentication.""" response = client.get("/api/users/me/sessions") assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +class TestUnlinkOAuthAccount: + """Tests for DELETE /api/users/me/link/{provider} endpoint.""" + + def test_unlinks_provider_successfully(self, client: TestClient, test_user, access_token): + """Test that endpoint successfully unlinks a provider. + + Should return 204 when provider is unlinked. + """ + with patch("app.api.deps.user_service") as mock_deps_service: + mock_deps_service.get_by_id = AsyncMock(return_value=test_user) + + with patch("app.api.users.user_service") as mock_user_service: + mock_user_service.unlink_oauth_account = AsyncMock(return_value=True) + + response = client.delete( + "/api/users/me/link/discord", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_204_NO_CONTENT + + def test_returns_404_if_not_linked(self, client: TestClient, test_user, access_token): + """Test that endpoint returns 404 if provider isn't linked. + + Should return 404 when trying to unlink a provider that isn't linked. + """ + with patch("app.api.deps.user_service") as mock_deps_service: + mock_deps_service.get_by_id = AsyncMock(return_value=test_user) + + with patch("app.api.users.user_service") as mock_user_service: + mock_user_service.unlink_oauth_account = AsyncMock(return_value=False) + + response = client.delete( + "/api/users/me/link/discord", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert "not linked" in response.json()["detail"].lower() + + def test_returns_400_for_primary_provider(self, client: TestClient, test_user, access_token): + """Test that endpoint returns 400 when trying to unlink primary provider. + + Cannot unlink the provider used to create the account. + """ + with patch("app.api.deps.user_service") as mock_deps_service: + mock_deps_service.get_by_id = AsyncMock(return_value=test_user) + + with patch("app.api.users.user_service") as mock_user_service: + mock_user_service.unlink_oauth_account = AsyncMock( + side_effect=AccountLinkingError( + "Cannot unlink Google - it is your primary login provider" + ) + ) + + response = client.delete( + "/api/users/me/link/google", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "primary" in response.json()["detail"].lower() + + def test_returns_400_for_unknown_provider(self, client: TestClient, test_user, access_token): + """Test that endpoint returns 400 for unknown provider. + + Only 'google' and 'discord' are valid providers. + """ + with patch("app.api.deps.user_service") as mock_deps_service: + mock_deps_service.get_by_id = AsyncMock(return_value=test_user) + + response = client.delete( + "/api/users/me/link/twitter", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "unknown provider" in response.json()["detail"].lower() + + def test_requires_authentication(self, client: TestClient): + """Test that endpoint returns 401 without authentication.""" + response = client.delete("/api/users/me/link/discord") + assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/backend/tests/services/test_user_service.py b/backend/tests/services/test_user_service.py index df9d7e8..b54b60e 100644 --- a/backend/tests/services/test_user_service.py +++ b/backend/tests/services/test_user_service.py @@ -9,8 +9,9 @@ from datetime import UTC, datetime, timedelta import pytest from app.db.models import User +from app.db.models.oauth_account import OAuthLinkedAccount from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate -from app.services.user_service import user_service +from app.services.user_service import AccountLinkingError, user_service # Import db_session fixture from db conftest pytestmark = pytest.mark.asyncio @@ -405,3 +406,259 @@ class TestDelete: db_session, UUID(user_id) if isinstance(user_id, str) else user_id ) assert result is None + + +class TestGetLinkedAccount: + """Tests for get_linked_account method.""" + + async def test_returns_linked_account_when_found(self, db_session): + """Test that get_linked_account returns account when it exists. + + Creates a user with a linked account and verifies it can be retrieved. + """ + # Create user + user = User( + email="primary@example.com", + display_name="Primary User", + oauth_provider="google", + oauth_id="google-primary", + ) + db_session.add(user) + await db_session.commit() + + # Create linked account + linked = OAuthLinkedAccount( + user_id=user.id, + provider="discord", + oauth_id="discord-linked-123", + email="linked@example.com", + ) + db_session.add(linked) + await db_session.commit() + + # Retrieve linked account + result = await user_service.get_linked_account(db_session, "discord", "discord-linked-123") + + assert result is not None + assert result.provider == "discord" + assert result.oauth_id == "discord-linked-123" + + async def test_returns_none_when_not_found(self, db_session): + """Test that get_linked_account returns None for nonexistent accounts.""" + result = await user_service.get_linked_account(db_session, "discord", "nonexistent-id") + assert result is None + + +class TestLinkOAuthAccount: + """Tests for link_oauth_account method.""" + + async def test_links_new_provider(self, db_session): + """Test that link_oauth_account successfully links a new provider. + + Creates a Google user and links Discord to them. + """ + # Create user with Google + user = User( + email="google-user@example.com", + display_name="Google User", + oauth_provider="google", + oauth_id="google-123", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Link Discord + discord_info = OAuthUserInfo( + provider="discord", + oauth_id="discord-456", + email="discord@example.com", + name="Discord Name", + avatar_url="https://discord.com/avatar.png", + ) + + result = await user_service.link_oauth_account(db_session, user, discord_info) + + assert result is not None + assert result.provider == "discord" + assert result.oauth_id == "discord-456" + assert result.email == "discord@example.com" + assert result.display_name == "Discord Name" + assert str(result.user_id) == str(user.id) + + async def test_raises_error_if_already_linked_to_same_user(self, db_session): + """Test that linking same provider twice raises error. + + A user cannot have the same provider linked multiple times. + """ + user = User( + email="double-link@example.com", + display_name="Double Link", + oauth_provider="google", + oauth_id="google-double", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Link Discord first time + discord_info = OAuthUserInfo( + provider="discord", + oauth_id="discord-first", + email="first@discord.com", + name="First", + ) + await user_service.link_oauth_account(db_session, user, discord_info) + await db_session.refresh(user) + + # Try to link same Discord account again + with pytest.raises(AccountLinkingError) as exc_info: + await user_service.link_oauth_account(db_session, user, discord_info) + + assert "already linked to your account" in str(exc_info.value) + + async def test_raises_error_if_linked_to_another_user(self, db_session): + """Test that linking account already linked to another user raises error. + + The same OAuth provider+ID cannot be linked to multiple users. + """ + # Create first user and link Discord + user1 = User( + email="user1@example.com", + display_name="User 1", + oauth_provider="google", + oauth_id="google-user1", + ) + db_session.add(user1) + await db_session.commit() + await db_session.refresh(user1) + + discord_info = OAuthUserInfo( + provider="discord", + oauth_id="shared-discord", + email="shared@discord.com", + name="Shared", + ) + await user_service.link_oauth_account(db_session, user1, discord_info) + + # Create second user + user2 = User( + email="user2@example.com", + display_name="User 2", + oauth_provider="google", + oauth_id="google-user2", + ) + db_session.add(user2) + await db_session.commit() + await db_session.refresh(user2) + + # Try to link same Discord account to second user + with pytest.raises(AccountLinkingError) as exc_info: + await user_service.link_oauth_account(db_session, user2, discord_info) + + assert "already linked to another user" in str(exc_info.value) + + async def test_raises_error_if_linking_primary_provider(self, db_session): + """Test that linking the same provider as primary raises error. + + User cannot link Google if they already signed up with Google. + """ + user = User( + email="google-primary@example.com", + display_name="Google Primary", + oauth_provider="google", + oauth_id="google-primary-id", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Try to link another Google account + google_info = OAuthUserInfo( + provider="google", + oauth_id="google-different-id", + email="different@gmail.com", + name="Different", + ) + + with pytest.raises(AccountLinkingError) as exc_info: + await user_service.link_oauth_account(db_session, user, google_info) + + assert "primary login provider" in str(exc_info.value) + + +class TestUnlinkOAuthAccount: + """Tests for unlink_oauth_account method.""" + + async def test_unlinks_linked_account(self, db_session): + """Test that unlink_oauth_account removes a linked account. + + Links Discord then unlinks it successfully. + """ + user = User( + email="unlink@example.com", + display_name="Unlink User", + oauth_provider="google", + oauth_id="google-unlink", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + # Link Discord + discord_info = OAuthUserInfo( + provider="discord", + oauth_id="discord-unlink", + email="discord@unlink.com", + name="Discord Unlink", + ) + await user_service.link_oauth_account(db_session, user, discord_info) + await db_session.refresh(user) + + # Verify linked + assert len(user.linked_accounts) == 1 + + # Unlink + result = await user_service.unlink_oauth_account(db_session, user, "discord") + + assert result is True + + # Verify unlinked + linked = await user_service.get_linked_account(db_session, "discord", "discord-unlink") + assert linked is None + + async def test_returns_false_if_not_linked(self, db_session): + """Test that unlink returns False if provider isn't linked.""" + user = User( + email="not-linked@example.com", + display_name="Not Linked", + oauth_provider="google", + oauth_id="google-notlinked", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + result = await user_service.unlink_oauth_account(db_session, user, "discord") + + assert result is False + + async def test_raises_error_if_unlinking_primary(self, db_session): + """Test that unlinking primary provider raises error. + + User cannot unlink their primary OAuth provider. + """ + user = User( + email="primary-unlink@example.com", + display_name="Primary Unlink", + oauth_provider="google", + oauth_id="google-primary-unlink", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + + with pytest.raises(AccountLinkingError) as exc_info: + await user_service.unlink_oauth_account(db_session, user, "google") + + assert "primary login provider" in str(exc_info.value)