Implement UserRepository pattern with dependency injection

- Add UserRepository and LinkedAccountRepository protocols to protocols.py
- Add UserEntry and LinkedAccountEntry DTOs for service layer decoupling
- Implement PostgresUserRepository and PostgresLinkedAccountRepository
- Refactor UserService to use constructor-injected repositories
- Add get_user_service factory and UserServiceDep to API deps
- Update auth.py and users.py endpoints to use UserServiceDep
- Rewrite tests to use FastAPI dependency overrides (no monkey patching)

This follows the established repository pattern used by DeckService and
CollectionService, enabling future offline fork support.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-01-30 07:30:16 -06:00
parent f6e8ab5f67
commit 7fcb86ff51
12 changed files with 1316 additions and 454 deletions

View File

@ -31,7 +31,7 @@ import secrets
from fastapi import APIRouter, HTTPException, Query, status from fastapi import APIRouter, HTTPException, Query, status
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from app.api.deps import CurrentUser, DbSession from app.api.deps import CurrentUser, UserServiceDep
from app.config import settings from app.config import settings
from app.db.redis import get_redis from app.db.redis import get_redis
from app.schemas.auth import RefreshTokenRequest, TokenResponse from app.schemas.auth import RefreshTokenRequest, TokenResponse
@ -45,7 +45,7 @@ from app.services.jwt_service import (
from app.services.oauth.discord import DiscordOAuthError, discord_oauth from app.services.oauth.discord import DiscordOAuthError, discord_oauth
from app.services.oauth.google import GoogleOAuthError, google_oauth from app.services.oauth.google import GoogleOAuthError, google_oauth
from app.services.token_store import token_store from app.services.token_store import token_store
from app.services.user_service import AccountLinkingError, user_service from app.services.user_service import AccountLinkingError
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
@ -150,7 +150,7 @@ async def google_auth_redirect(
@router.get("/google/callback") @router.get("/google/callback")
async def google_auth_callback( async def google_auth_callback(
db: DbSession, user_service: UserServiceDep,
code: str = Query(..., description="Authorization code from Google"), code: str = Query(..., description="Authorization code from Google"),
state: str = Query(..., description="State parameter for CSRF validation"), state: str = Query(..., description="State parameter for CSRF validation"),
) -> TokenResponse: ) -> TokenResponse:
@ -182,10 +182,10 @@ async def google_auth_callback(
user_info = await google_oauth.get_user_info(code, oauth_callback) user_info = await google_oauth.get_user_info(code, oauth_callback)
# Get or create user # Get or create user
user, created = await user_service.get_or_create_from_oauth(db, user_info) user, _created = await user_service.get_or_create_from_oauth(user_info)
# Update last login # Update last login
await user_service.update_last_login(db, user) await user_service.update_last_login(user.id)
# Create tokens # Create tokens
return await _create_tokens_for_user(user.id) return await _create_tokens_for_user(user.id)
@ -243,7 +243,7 @@ async def discord_auth_redirect(
@router.get("/discord/callback") @router.get("/discord/callback")
async def discord_auth_callback( async def discord_auth_callback(
db: DbSession, user_service: UserServiceDep,
code: str = Query(..., description="Authorization code from Discord"), code: str = Query(..., description="Authorization code from Discord"),
state: str = Query(..., description="State parameter for CSRF validation"), state: str = Query(..., description="State parameter for CSRF validation"),
) -> RedirectResponse: ) -> RedirectResponse:
@ -276,10 +276,10 @@ async def discord_auth_callback(
user_info = await discord_oauth.get_user_info(code, oauth_callback) user_info = await discord_oauth.get_user_info(code, oauth_callback)
# Get or create user # Get or create user
user, created = await user_service.get_or_create_from_oauth(db, user_info) user, _created = await user_service.get_or_create_from_oauth(user_info)
# Update last login # Update last login
await user_service.update_last_login(db, user) await user_service.update_last_login(user.id)
# Create tokens # Create tokens
tokens = await _create_tokens_for_user(user.id) tokens = await _create_tokens_for_user(user.id)
@ -307,7 +307,7 @@ async def discord_auth_callback(
@router.post("/refresh", response_model=TokenResponse) @router.post("/refresh", response_model=TokenResponse)
async def refresh_tokens( async def refresh_tokens(
db: DbSession, user_service: UserServiceDep,
request: RefreshTokenRequest, request: RefreshTokenRequest,
) -> TokenResponse: ) -> TokenResponse:
"""Refresh access token using refresh token. """Refresh access token using refresh token.
@ -344,7 +344,7 @@ async def refresh_tokens(
) )
# Verify user still exists # Verify user still exists
user = await user_service.get_by_id(db, user_id) user = await user_service.get_by_id(user_id)
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -489,7 +489,7 @@ async def google_link_redirect(
@router.get("/link/google/callback") @router.get("/link/google/callback")
async def google_link_callback( async def google_link_callback(
db: DbSession, user_service: UserServiceDep,
code: str = Query(..., description="Authorization code from Google"), code: str = Query(..., description="Authorization code from Google"),
state: str = Query(..., description="State parameter for CSRF validation"), state: str = Query(..., description="State parameter for CSRF validation"),
) -> RedirectResponse: ) -> RedirectResponse:
@ -523,7 +523,7 @@ async def google_link_callback(
from uuid import UUID from uuid import UUID
user_id = UUID(user_id_str) user_id = UUID(user_id_str)
user = await user_service.get_by_id(db, user_id) user = await user_service.get_by_id(user_id)
if user is None: if user is None:
return RedirectResponse( return RedirectResponse(
url=f"{redirect_uri}?error=user_not_found", url=f"{redirect_uri}?error=user_not_found",
@ -531,7 +531,7 @@ async def google_link_callback(
) )
# Link the account # Link the account
await user_service.link_oauth_account(db, user, oauth_info) await user_service.link_oauth_account(user.id, user.oauth_provider, oauth_info)
return RedirectResponse( return RedirectResponse(
url=f"{redirect_uri}?linked=google", url=f"{redirect_uri}?linked=google",
@ -600,7 +600,7 @@ async def discord_link_redirect(
@router.get("/link/discord/callback") @router.get("/link/discord/callback")
async def discord_link_callback( async def discord_link_callback(
db: DbSession, user_service: UserServiceDep,
code: str = Query(..., description="Authorization code from Discord"), code: str = Query(..., description="Authorization code from Discord"),
state: str = Query(..., description="State parameter for CSRF validation"), state: str = Query(..., description="State parameter for CSRF validation"),
) -> RedirectResponse: ) -> RedirectResponse:
@ -634,7 +634,7 @@ async def discord_link_callback(
from uuid import UUID from uuid import UUID
user_id = UUID(user_id_str) user_id = UUID(user_id_str)
user = await user_service.get_by_id(db, user_id) user = await user_service.get_by_id(user_id)
if user is None: if user is None:
return RedirectResponse( return RedirectResponse(
url=f"{redirect_uri}?error=user_not_found", url=f"{redirect_uri}?error=user_not_found",
@ -642,7 +642,7 @@ async def discord_link_callback(
) )
# Link the account # Link the account
await user_service.link_oauth_account(db, user, oauth_info) await user_service.link_oauth_account(user.id, user.oauth_provider, oauth_info)
return RedirectResponse( return RedirectResponse(
url=f"{redirect_uri}?linked=discord", url=f"{redirect_uri}?linked=discord",

View File

@ -28,6 +28,7 @@ from typing import Annotated
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings from app.config import settings
@ -35,13 +36,15 @@ from app.db import get_session
from app.db.models import User from app.db.models import User
from app.repositories.postgres.collection import PostgresCollectionRepository from app.repositories.postgres.collection import PostgresCollectionRepository
from app.repositories.postgres.deck import PostgresDeckRepository from app.repositories.postgres.deck import PostgresDeckRepository
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
from app.repositories.postgres.user import PostgresUserRepository
from app.services.card_service import CardService, get_card_service from app.services.card_service import CardService, get_card_service
from app.services.collection_service import CollectionService from app.services.collection_service import CollectionService
from app.services.deck_service import DeckService from app.services.deck_service import DeckService
from app.services.game_service import GameService, game_service from app.services.game_service import GameService, game_service
from app.services.game_state_manager import GameStateManager, game_state_manager from app.services.game_state_manager import GameStateManager, game_state_manager
from app.services.jwt_service import verify_access_token from app.services.jwt_service import verify_access_token
from app.services.user_service import user_service from app.services.user_service import UserService
# OAuth2 scheme for extracting Bearer token from Authorization header # OAuth2 scheme for extracting Bearer token from Authorization header
# tokenUrl is for OpenAPI docs - points to where tokens are obtained # tokenUrl is for OpenAPI docs - points to where tokens are obtained
@ -148,8 +151,9 @@ async def get_current_user(
if user_id is None: if user_id is None:
raise credentials_exception raise credentials_exception
# Fetch user from database # Fetch user from database (direct query for auth - not business logic)
user = await user_service.get_by_id(db, user_id) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None: if user is None:
raise credentials_exception raise credentials_exception
@ -187,7 +191,9 @@ async def get_optional_user(
if user_id is None: if user_id is None:
return None return None
return await user_service.get_by_id(db, user_id) # Direct query for auth - not business logic
result = await db.execute(select(User).where(User.id == user_id))
return result.scalar_one_or_none()
async def get_current_premium_user( async def get_current_premium_user(
@ -333,6 +339,31 @@ def get_game_state_manager_dep() -> GameStateManager:
return game_state_manager return game_state_manager
def get_user_service(
db: Annotated[AsyncSession, Depends(get_db)],
) -> UserService:
"""Get UserService with PostgreSQL repositories.
Creates a UserService instance with user and linked account repositories.
Args:
db: Database session from request.
Returns:
UserService configured for PostgreSQL.
Example:
@router.post("/auth/google/callback")
async def google_callback(
user_service: UserService = Depends(get_user_service),
):
user, created = await user_service.get_or_create_from_oauth(oauth_info)
"""
user_repo = PostgresUserRepository(db)
linked_repo = PostgresLinkedAccountRepository(db)
return UserService(user_repo, linked_repo)
# ============================================================================= # =============================================================================
# Type Aliases for Cleaner Endpoint Signatures # Type Aliases for Cleaner Endpoint Signatures
# ============================================================================= # =============================================================================
@ -351,6 +382,7 @@ CollectionServiceDep = Annotated[CollectionService, Depends(get_collection_servi
CardServiceDep = Annotated[CardService, Depends(get_card_service_dep)] CardServiceDep = Annotated[CardService, Depends(get_card_service_dep)]
GameServiceDep = Annotated[GameService, Depends(get_game_service_dep)] GameServiceDep = Annotated[GameService, Depends(get_game_service_dep)]
GameStateManagerDep = Annotated[GameStateManager, Depends(get_game_state_manager_dep)] GameStateManagerDep = Annotated[GameStateManager, Depends(get_game_state_manager_dep)]
UserServiceDep = Annotated[UserService, Depends(get_user_service)]
# Admin authentication # Admin authentication
AdminAuth = Annotated[None, Depends(verify_admin_token)] AdminAuth = Annotated[None, Depends(verify_admin_token)]

View File

@ -26,12 +26,12 @@ Example:
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.api.deps import CurrentUser, DbSession, DeckServiceDep from app.api.deps import CurrentUser, DeckServiceDep, UserServiceDep
from app.schemas.deck import DeckResponse, StarterDeckSelectRequest, StarterStatusResponse from app.schemas.deck import DeckResponse, StarterDeckSelectRequest, StarterStatusResponse
from app.schemas.user import UserResponse, UserUpdate from app.schemas.user import UserResponse, UserUpdate
from app.services.deck_service import DeckLimitExceededError, StarterAlreadySelectedError from app.services.deck_service import DeckLimitExceededError, StarterAlreadySelectedError
from app.services.token_store import token_store from app.services.token_store import token_store
from app.services.user_service import AccountLinkingError, user_service from app.services.user_service import AccountLinkingError
router = APIRouter(prefix="/users", tags=["users"]) router = APIRouter(prefix="/users", tags=["users"])
@ -65,7 +65,7 @@ async def get_current_user_profile(
@router.patch("/me", response_model=UserResponse) @router.patch("/me", response_model=UserResponse)
async def update_current_user_profile( async def update_current_user_profile(
user: CurrentUser, user: CurrentUser,
db: DbSession, user_service: UserServiceDep,
update_data: UserUpdate, update_data: UserUpdate,
) -> UserResponse: ) -> UserResponse:
"""Update the current user's profile. """Update the current user's profile.
@ -78,7 +78,7 @@ async def update_current_user_profile(
Returns: Returns:
Updated user profile. Updated user profile.
""" """
updated_user = await user_service.update(db, user, update_data) updated_user = await user_service.update(user.id, update_data)
return UserResponse.model_validate(updated_user) return UserResponse.model_validate(updated_user)
@ -139,7 +139,7 @@ async def get_active_sessions(
@router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT)
async def unlink_oauth_account( async def unlink_oauth_account(
user: CurrentUser, user: CurrentUser,
db: DbSession, user_service: UserServiceDep,
provider: str, provider: str,
) -> None: ) -> None:
"""Unlink an OAuth provider from the current user's account. """Unlink an OAuth provider from the current user's account.
@ -162,7 +162,7 @@ async def unlink_oauth_account(
) )
try: try:
unlinked = await user_service.unlink_oauth_account(db, user, provider) unlinked = await user_service.unlink_oauth_account(user.id, user.oauth_provider, provider)
if not unlinked: if not unlinked:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,

View File

@ -10,11 +10,12 @@ The protocol pattern enables:
- Offline fork support without rewriting service layer - Offline fork support without rewriting service layer
Usage: Usage:
from app.repositories import CollectionRepository, DeckRepository from app.repositories import CollectionRepository, DeckRepository, UserRepository
from app.repositories.postgres import PostgresCollectionRepository from app.repositories.postgres import PostgresCollectionRepository, PostgresUserRepository
# In production (dependency injection) # In production (dependency injection)
repo = PostgresCollectionRepository(db_session) repo = PostgresCollectionRepository(db_session)
user_repo = PostgresUserRepository(db_session)
# In tests # In tests
repo = MockCollectionRepository() repo = MockCollectionRepository()
@ -23,9 +24,13 @@ Usage:
from app.repositories.protocols import ( from app.repositories.protocols import (
CollectionRepository, CollectionRepository,
DeckRepository, DeckRepository,
LinkedAccountRepository,
UserRepository,
) )
__all__ = [ __all__ = [
"CollectionRepository", "CollectionRepository",
"DeckRepository", "DeckRepository",
"LinkedAccountRepository",
"UserRepository",
] ]

View File

@ -8,11 +8,15 @@ Usage:
from app.repositories.postgres import ( from app.repositories.postgres import (
PostgresCollectionRepository, PostgresCollectionRepository,
PostgresDeckRepository, PostgresDeckRepository,
PostgresLinkedAccountRepository,
PostgresUserRepository,
) )
# Create repository with database session # Create repository with database session
collection_repo = PostgresCollectionRepository(db_session) collection_repo = PostgresCollectionRepository(db_session)
deck_repo = PostgresDeckRepository(db_session) deck_repo = PostgresDeckRepository(db_session)
user_repo = PostgresUserRepository(db_session)
linked_repo = PostgresLinkedAccountRepository(db_session)
# Use via service layer # Use via service layer
service = CollectionService(collection_repo) service = CollectionService(collection_repo)
@ -20,8 +24,12 @@ Usage:
from app.repositories.postgres.collection import PostgresCollectionRepository from app.repositories.postgres.collection import PostgresCollectionRepository
from app.repositories.postgres.deck import PostgresDeckRepository from app.repositories.postgres.deck import PostgresDeckRepository
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
from app.repositories.postgres.user import PostgresUserRepository
__all__ = [ __all__ = [
"PostgresCollectionRepository", "PostgresCollectionRepository",
"PostgresDeckRepository", "PostgresDeckRepository",
"PostgresLinkedAccountRepository",
"PostgresUserRepository",
] ]

View File

@ -0,0 +1,149 @@
"""PostgreSQL implementation of LinkedAccountRepository.
This module provides the PostgreSQL-specific implementation of the
LinkedAccountRepository protocol using SQLAlchemy async sessions.
Example:
async with get_db_session() as db:
repo = PostgresLinkedAccountRepository(db)
accounts = await repo.get_by_user(user_id)
"""
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.oauth_account import OAuthLinkedAccount
from app.repositories.protocols import LinkedAccountEntry
def _to_dto(model: OAuthLinkedAccount) -> LinkedAccountEntry:
"""Convert ORM model to DTO."""
return LinkedAccountEntry(
id=model.id,
user_id=UUID(model.user_id) if isinstance(model.user_id, str) else model.user_id,
provider=model.provider,
oauth_id=model.oauth_id,
email=model.email,
display_name=model.display_name,
avatar_url=model.avatar_url,
linked_at=model.linked_at,
)
class PostgresLinkedAccountRepository:
"""PostgreSQL implementation of LinkedAccountRepository.
Uses SQLAlchemy async sessions for database access.
Attributes:
_db: The async database session.
"""
def __init__(self, db: AsyncSession) -> None:
"""Initialize with database session.
Args:
db: SQLAlchemy async session.
"""
self._db = db
async def get_by_provider(
self,
provider: str,
oauth_id: str,
) -> LinkedAccountEntry | None:
"""Get a linked account by provider and OAuth ID.
Args:
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
LinkedAccountEntry if found, None otherwise.
"""
result = await self._db.execute(
select(OAuthLinkedAccount).where(
OAuthLinkedAccount.provider == provider,
OAuthLinkedAccount.oauth_id == oauth_id,
)
)
model = result.scalar_one_or_none()
return _to_dto(model) if model else None
async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
"""Get all linked accounts for a user.
Args:
user_id: The user's UUID.
Returns:
List of LinkedAccountEntry, ordered by provider.
"""
result = await self._db.execute(
select(OAuthLinkedAccount)
.where(OAuthLinkedAccount.user_id == str(user_id))
.order_by(OAuthLinkedAccount.provider)
)
return [_to_dto(model) for model in result.scalars().all()]
async def create(
self,
user_id: UUID,
provider: str,
oauth_id: str,
email: str | None = None,
display_name: str | None = None,
avatar_url: str | None = None,
) -> LinkedAccountEntry:
"""Link an OAuth provider to a user account.
Args:
user_id: The user's UUID.
provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
email: Email from the OAuth provider.
display_name: Display name from the OAuth provider.
avatar_url: Avatar URL from the OAuth provider.
Returns:
The created LinkedAccountEntry.
"""
linked_account = OAuthLinkedAccount(
user_id=str(user_id),
provider=provider,
oauth_id=oauth_id,
email=email,
display_name=display_name,
avatar_url=avatar_url,
)
self._db.add(linked_account)
await self._db.commit()
await self._db.refresh(linked_account)
return _to_dto(linked_account)
async def delete(self, user_id: UUID, provider: str) -> bool:
"""Unlink an OAuth provider from a user account.
Args:
user_id: The user's UUID.
provider: OAuth provider name to unlink.
Returns:
True if deleted, False if not found.
"""
result = await self._db.execute(
select(OAuthLinkedAccount).where(
OAuthLinkedAccount.user_id == str(user_id),
OAuthLinkedAccount.provider == provider,
)
)
linked_account = result.scalar_one_or_none()
if linked_account is None:
return False
await self._db.delete(linked_account)
await self._db.commit()
return True

View File

@ -0,0 +1,243 @@
"""PostgreSQL implementation of UserRepository.
This module provides the PostgreSQL-specific implementation of the
UserRepository protocol using SQLAlchemy async sessions.
Example:
async with get_db_session() as db:
repo = PostgresUserRepository(db)
user = await repo.get_by_id(user_id)
"""
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.user import User
from app.repositories.protocols import UNSET, UserEntry
def _to_dto(model: User) -> UserEntry:
"""Convert ORM model to DTO."""
return UserEntry(
id=model.id,
email=model.email,
display_name=model.display_name,
avatar_url=model.avatar_url,
oauth_provider=model.oauth_provider,
oauth_id=model.oauth_id,
is_premium=model.is_premium,
premium_until=model.premium_until,
last_login=model.last_login,
created_at=model.created_at,
updated_at=model.updated_at,
)
class PostgresUserRepository:
"""PostgreSQL implementation of UserRepository.
Uses SQLAlchemy async sessions for database access.
Attributes:
_db: The async database session.
"""
def __init__(self, db: AsyncSession) -> None:
"""Initialize with database session.
Args:
db: SQLAlchemy async session.
"""
self._db = db
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
"""Get a user by their ID.
Args:
user_id: The user's UUID.
Returns:
UserEntry if found, None otherwise.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
model = result.scalar_one_or_none()
return _to_dto(model) if model else None
async def get_by_email(self, email: str) -> UserEntry | None:
"""Get a user by their email address.
Args:
email: The user's email address.
Returns:
UserEntry if found, None otherwise.
"""
result = await self._db.execute(select(User).where(User.email == email))
model = result.scalar_one_or_none()
return _to_dto(model) if model else None
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
"""Get a user by their OAuth provider and ID.
Args:
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
UserEntry if found, None otherwise.
"""
result = await self._db.execute(
select(User).where(
User.oauth_provider == provider,
User.oauth_id == oauth_id,
)
)
model = result.scalar_one_or_none()
return _to_dto(model) if model else None
async def create(
self,
email: str,
display_name: str,
oauth_provider: str,
oauth_id: str,
avatar_url: str | None = None,
) -> UserEntry:
"""Create a new user.
Args:
email: User's email address.
display_name: Public display name.
oauth_provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
avatar_url: Optional avatar URL.
Returns:
The created UserEntry.
"""
user = User(
email=email,
display_name=display_name,
oauth_provider=oauth_provider,
oauth_id=oauth_id,
avatar_url=avatar_url,
)
self._db.add(user)
await self._db.commit()
await self._db.refresh(user)
return _to_dto(user)
async def update(
self,
user_id: UUID,
display_name: str | None = None,
avatar_url: str | None = UNSET, # type: ignore[assignment]
oauth_provider: str | None = None,
oauth_id: str | None = None,
) -> UserEntry | None:
"""Update user profile fields.
Only provided (non-None/non-UNSET) fields are updated.
Args:
user_id: The user's UUID.
display_name: New display name (None keeps existing).
avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
oauth_provider: New OAuth provider (None keeps existing).
oauth_id: New OAuth ID (None keeps existing).
Returns:
Updated UserEntry, or None if user not found.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return None
if display_name is not None:
user.display_name = display_name
if avatar_url is not UNSET:
user.avatar_url = avatar_url
if oauth_provider is not None:
user.oauth_provider = oauth_provider
if oauth_id is not None:
user.oauth_id = oauth_id
await self._db.commit()
await self._db.refresh(user)
return _to_dto(user)
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
"""Update the user's last login timestamp to now.
Args:
user_id: The user's UUID.
Returns:
Updated UserEntry, or None if user not found.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return None
user.last_login = datetime.now(UTC)
await self._db.commit()
await self._db.refresh(user)
return _to_dto(user)
async def update_premium(
self,
user_id: UUID,
is_premium: bool,
premium_until: datetime | None,
) -> UserEntry | None:
"""Update user's premium subscription status.
Args:
user_id: The user's UUID.
is_premium: Whether user has premium.
premium_until: When premium expires, or None if not premium.
Returns:
Updated UserEntry, or None if user not found.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return None
user.is_premium = is_premium
user.premium_until = premium_until
await self._db.commit()
await self._db.refresh(user)
return _to_dto(user)
async def delete(self, user_id: UUID) -> bool:
"""Delete a user account.
This will cascade delete all related data (decks, collection, etc.)
based on database constraints.
Args:
user_id: The user's UUID.
Returns:
True if deleted, False if not found.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return False
await self._db.delete(user)
await self._db.commit()
return True

View File

@ -102,6 +102,46 @@ class DeckEntry:
updated_at: datetime updated_at: datetime
@dataclass
class UserEntry:
"""Storage-agnostic representation of a user account.
This DTO decouples the service layer from the ORM model,
enabling different storage backends (PostgreSQL, SQLite, JSON)
to be used interchangeably.
"""
id: UUID
email: str
display_name: str
avatar_url: str | None
oauth_provider: str
oauth_id: str
is_premium: bool
premium_until: datetime | None
last_login: datetime | None
created_at: datetime
updated_at: datetime
@dataclass
class LinkedAccountEntry:
"""Storage-agnostic representation of a linked OAuth account.
Users can link multiple OAuth providers (e.g., Google + Discord)
to a single account for flexible login options.
"""
id: UUID
user_id: UUID
provider: str
oauth_id: str
email: str | None
display_name: str | None
avatar_url: str | None
linked_at: datetime
# ============================================================================= # =============================================================================
# Repository Protocols # Repository Protocols
# ============================================================================= # =============================================================================
@ -342,3 +382,211 @@ class DeckRepository(Protocol):
Tuple of (has_starter, starter_type). Tuple of (has_starter, starter_type).
""" """
... ...
class UserRepository(Protocol):
"""Protocol for user account data access.
Implementations handle storage-specific details (PostgreSQL, SQLite, JSON).
Services use this protocol for business logic without knowing storage details.
Note: Business logic like get_or_create_from_oauth belongs in the service layer,
not in the repository.
"""
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
"""Get a user by their ID.
Args:
user_id: The user's UUID.
Returns:
UserEntry if found, None otherwise.
"""
...
async def get_by_email(self, email: str) -> UserEntry | None:
"""Get a user by their email address.
Args:
email: The user's email address.
Returns:
UserEntry if found, None otherwise.
"""
...
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
"""Get a user by their OAuth provider and ID.
Args:
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
UserEntry if found, None otherwise.
"""
...
async def create(
self,
email: str,
display_name: str,
oauth_provider: str,
oauth_id: str,
avatar_url: str | None = None,
) -> UserEntry:
"""Create a new user.
Args:
email: User's email address.
display_name: Public display name.
oauth_provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
avatar_url: Optional avatar URL.
Returns:
The created UserEntry.
"""
...
async def update(
self,
user_id: UUID,
display_name: str | None = None,
avatar_url: str | None = UNSET, # type: ignore[assignment]
oauth_provider: str | None = None,
oauth_id: str | None = None,
) -> UserEntry | None:
"""Update user profile fields.
Only provided (non-None/non-UNSET) fields are updated.
Use UNSET (default) to keep existing value for nullable fields,
or None to explicitly clear them.
Args:
user_id: The user's UUID.
display_name: New display name (None keeps existing).
avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
oauth_provider: New OAuth provider (None keeps existing).
oauth_id: New OAuth ID (None keeps existing).
Returns:
Updated UserEntry, or None if user not found.
"""
...
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
"""Update the user's last login timestamp to now.
Args:
user_id: The user's UUID.
Returns:
Updated UserEntry, or None if user not found.
"""
...
async def update_premium(
self,
user_id: UUID,
is_premium: bool,
premium_until: datetime | None,
) -> UserEntry | None:
"""Update user's premium subscription status.
Args:
user_id: The user's UUID.
is_premium: Whether user has premium.
premium_until: When premium expires, or None if not premium.
Returns:
Updated UserEntry, or None if user not found.
"""
...
async def delete(self, user_id: UUID) -> bool:
"""Delete a user account.
This will cascade delete all related data (decks, collection, etc.)
based on database constraints.
Args:
user_id: The user's UUID.
Returns:
True if deleted, False if not found.
"""
...
class LinkedAccountRepository(Protocol):
"""Protocol for OAuth linked accounts data access.
Users can link multiple OAuth providers to a single account.
The primary OAuth provider is stored on the User model itself;
additional linked providers are stored as LinkedAccount records.
"""
async def get_by_provider(
self,
provider: str,
oauth_id: str,
) -> LinkedAccountEntry | None:
"""Get a linked account by provider and OAuth ID.
Args:
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
LinkedAccountEntry if found, None otherwise.
"""
...
async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
"""Get all linked accounts for a user.
Args:
user_id: The user's UUID.
Returns:
List of LinkedAccountEntry, ordered by provider.
"""
...
async def create(
self,
user_id: UUID,
provider: str,
oauth_id: str,
email: str | None = None,
display_name: str | None = None,
avatar_url: str | None = None,
) -> LinkedAccountEntry:
"""Link an OAuth provider to a user account.
Args:
user_id: The user's UUID.
provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
email: Email from the OAuth provider.
display_name: Display name from the OAuth provider.
avatar_url: Avatar URL from the OAuth provider.
Returns:
The created LinkedAccountEntry.
"""
...
async def delete(self, user_id: UUID, provider: str) -> bool:
"""Unlink an OAuth provider from a user account.
Args:
user_id: The user's UUID.
provider: OAuth provider name to unlink.
Returns:
True if deleted, False if not found.
"""
...

View File

@ -1,32 +1,37 @@
"""User service for Mantimon TCG. """User service for Mantimon TCG.
This module provides async CRUD operations for user accounts, This module provides business logic for user account operations,
including OAuth-based user creation and premium status management. including OAuth-based user creation, account linking, and premium
status management.
All database operations use async SQLAlchemy sessions. The service layer contains business logic while repositories handle
pure data access. This separation enables testing and different
storage backends.
Example: Example:
from app.services.user_service import user_service from app.services.user_service import UserService
from app.repositories.postgres import PostgresUserRepository, PostgresLinkedAccountRepository
# Get user by ID # Create service with injected repositories
user = await user_service.get_by_id(db, user_id) user_repo = PostgresUserRepository(db)
linked_repo = PostgresLinkedAccountRepository(db)
service = UserService(user_repo, linked_repo)
# Create from OAuth # Use service
user = await user_service.create_from_oauth(db, oauth_info) user = await service.get_by_id(user_id)
user, created = await service.get_or_create_from_oauth(oauth_info)
# Update premium status
user = await user_service.update_premium(db, user_id, premium_until)
""" """
from datetime import UTC, datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from sqlalchemy import select from app.repositories.protocols import (
from sqlalchemy.ext.asyncio import AsyncSession LinkedAccountEntry,
LinkedAccountRepository,
from app.db.models.oauth_account import OAuthLinkedAccount UserEntry,
from app.db.models.user import User UserRepository,
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate )
from app.schemas.user import OAuthUserInfo, UserUpdate
class AccountLinkingError(Exception): class AccountLinkingError(Exception):
@ -38,117 +43,119 @@ class AccountLinkingError(Exception):
class UserService: class UserService:
"""Service for user account operations. """Service for user account operations.
Provides async methods for user CRUD, OAuth-based creation, Provides business logic for user CRUD, OAuth-based creation,
and premium subscription management. account linking, and premium subscription management.
Attributes:
_user_repo: Repository for user data access.
_linked_repo: Repository for linked account data access.
""" """
async def get_by_id(self, db: AsyncSession, user_id: UUID) -> User | None: def __init__(
self,
user_repository: UserRepository,
linked_account_repository: LinkedAccountRepository,
) -> None:
"""Initialize with repository dependencies.
Args:
user_repository: Repository for user data access.
linked_account_repository: Repository for linked account data access.
"""
self._user_repo = user_repository
self._linked_repo = linked_account_repository
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
"""Get a user by their ID. """Get a user by their ID.
Args: Args:
db: Async database session.
user_id: The user's UUID. user_id: The user's UUID.
Returns: Returns:
User if found, None otherwise. UserEntry if found, None otherwise.
Example: Example:
user = await user_service.get_by_id(db, user_id) user = await service.get_by_id(user_id)
if user: if user:
print(f"Found user: {user.display_name}") print(f"Found user: {user.display_name}")
""" """
result = await db.execute(select(User).where(User.id == user_id)) return await self._user_repo.get_by_id(user_id)
return result.scalar_one_or_none()
async def get_by_email(self, db: AsyncSession, email: str) -> User | None: async def get_by_email(self, email: str) -> UserEntry | None:
"""Get a user by their email address. """Get a user by their email address.
Args: Args:
db: Async database session.
email: The user's email address. email: The user's email address.
Returns: Returns:
User if found, None otherwise. UserEntry if found, None otherwise.
Example: Example:
user = await user_service.get_by_email(db, "player@example.com") user = await service.get_by_email("player@example.com")
""" """
result = await db.execute(select(User).where(User.email == email)) return await self._user_repo.get_by_email(email)
return result.scalar_one_or_none()
async def get_by_oauth( async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
self,
db: AsyncSession,
provider: str,
oauth_id: str,
) -> User | None:
"""Get a user by their OAuth provider and ID. """Get a user by their OAuth provider and ID.
Args: Args:
db: Async database session.
provider: OAuth provider name (google, discord). provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider. oauth_id: Unique ID from the OAuth provider.
Returns: Returns:
User if found, None otherwise. UserEntry if found, None otherwise.
Example: Example:
user = await user_service.get_by_oauth(db, "google", "123456789") user = await service.get_by_oauth("google", "123456789")
""" """
result = await db.execute( return await self._user_repo.get_by_oauth(provider, oauth_id)
select(User).where(
User.oauth_provider == provider,
User.oauth_id == oauth_id,
)
)
return result.scalar_one_or_none()
async def create(self, db: AsyncSession, user_data: UserCreate) -> User: async def create(
self,
email: str,
display_name: str,
oauth_provider: str,
oauth_id: str,
avatar_url: str | None = None,
) -> UserEntry:
"""Create a new user. """Create a new user.
Args: Args:
db: Async database session. email: User's email address.
user_data: User creation data. display_name: Public display name.
oauth_provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
avatar_url: Optional avatar URL.
Returns: Returns:
The created User instance. The created UserEntry.
Example: Example:
user_data = UserCreate( user = await service.create(
email="player@example.com", email="player@example.com",
display_name="Player1", display_name="Player1",
oauth_provider="google", oauth_provider="google",
oauth_id="123456789" oauth_id="123456789"
) )
user = await user_service.create(db, user_data)
""" """
user = User( return await self._user_repo.create(
email=user_data.email, email=email,
display_name=user_data.display_name, display_name=display_name,
avatar_url=user_data.avatar_url, oauth_provider=oauth_provider,
oauth_provider=user_data.oauth_provider, oauth_id=oauth_id,
oauth_id=user_data.oauth_id, avatar_url=avatar_url,
) )
db.add(user)
await db.commit()
await db.refresh(user)
return user
async def create_from_oauth( async def create_from_oauth(self, oauth_info: OAuthUserInfo) -> UserEntry:
self,
db: AsyncSession,
oauth_info: OAuthUserInfo,
) -> User:
"""Create a new user from OAuth provider info. """Create a new user from OAuth provider info.
Convenience method that converts OAuthUserInfo to UserCreate. Convenience method that extracts fields from OAuthUserInfo.
Args: Args:
db: Async database session.
oauth_info: Normalized OAuth user information. oauth_info: Normalized OAuth user information.
Returns: Returns:
The created User instance. The created UserEntry.
Example: Example:
oauth_info = OAuthUserInfo( oauth_info = OAuthUserInfo(
@ -158,209 +165,212 @@ class UserService:
name="Player One", name="Player One",
avatar_url="https://..." avatar_url="https://..."
) )
user = await user_service.create_from_oauth(db, oauth_info) user = await service.create_from_oauth(oauth_info)
""" """
user_data = oauth_info.to_user_create() return await self.create(
return await self.create(db, user_data) email=oauth_info.email,
display_name=oauth_info.name,
oauth_provider=oauth_info.provider,
oauth_id=oauth_info.oauth_id,
avatar_url=oauth_info.avatar_url,
)
async def get_or_create_from_oauth( async def get_or_create_from_oauth(
self, self,
db: AsyncSession,
oauth_info: OAuthUserInfo, oauth_info: OAuthUserInfo,
) -> tuple[User, bool]: ) -> tuple[UserEntry, bool]:
"""Get existing user or create new one from OAuth info. """Get existing user or create new one from OAuth info.
First checks for existing user by OAuth provider+ID, then by email First checks for existing user by OAuth provider+ID, then by email
(for account linking), and finally creates a new user if not found. (for account linking), and finally creates a new user if not found.
Args: Args:
db: Async database session.
oauth_info: Normalized OAuth user information. oauth_info: Normalized OAuth user information.
Returns: Returns:
Tuple of (User, created) where created is True if new user. Tuple of (UserEntry, created) where created is True if new user.
Example: Example:
user, created = await user_service.get_or_create_from_oauth(db, oauth_info) user, created = await service.get_or_create_from_oauth(oauth_info)
if created: if created:
print("Welcome, new user!") print("Welcome, new user!")
else: else:
print("Welcome back!") print("Welcome back!")
""" """
# First, check by OAuth provider + ID (exact match) # First, check by OAuth provider + ID (exact match)
user = await self.get_by_oauth(db, oauth_info.provider, oauth_info.oauth_id) user = await self._user_repo.get_by_oauth(oauth_info.provider, oauth_info.oauth_id)
if user: if user:
return user, False return user, False
# Check by email for potential account linking # Check by email for potential account linking
# If user exists with same email but different OAuth, update their OAuth # If user exists with same email but different OAuth, update their OAuth
user = await self.get_by_email(db, oauth_info.email) user = await self._user_repo.get_by_email(oauth_info.email)
if user: if user:
# Update OAuth credentials for existing user # Update OAuth credentials for existing user
# This links the new OAuth provider to the existing account # This links the new OAuth provider to the existing account
user.oauth_provider = oauth_info.provider updated_user = await self._user_repo.update(
user.oauth_id = oauth_info.oauth_id user_id=user.id,
# Optionally update avatar if not set oauth_provider=oauth_info.provider,
if not user.avatar_url and oauth_info.avatar_url: oauth_id=oauth_info.oauth_id,
user.avatar_url = oauth_info.avatar_url avatar_url=oauth_info.avatar_url if not user.avatar_url else None,
await db.commit() )
await db.refresh(user) return updated_user or user, False
return user, False
# Create new user # Create new user
user = await self.create_from_oauth(db, oauth_info) user = await self.create_from_oauth(oauth_info)
return user, True return user, True
async def update( async def update(
self, self,
db: AsyncSession, user_id: UUID,
user: User,
update_data: UserUpdate, update_data: UserUpdate,
) -> User: ) -> UserEntry | None:
"""Update user profile fields. """Update user profile fields.
Only updates fields that are provided (not None). Only updates fields that are explicitly provided. Uses UNSET pattern
for avatar_url to distinguish "not provided" from "set to None".
Args: Args:
db: Async database session. user_id: The user's UUID.
user: The user to update.
update_data: Fields to update. update_data: Fields to update.
Returns: Returns:
The updated User instance. The updated UserEntry, or None if not found.
Example: Example:
update_data = UserUpdate(display_name="New Name") update_data = UserUpdate(display_name="New Name")
user = await user_service.update(db, user, update_data) user = await service.update(user_id, update_data)
""" """
if update_data.display_name is not None: from app.repositories.protocols import UNSET
user.display_name = update_data.display_name
if update_data.avatar_url is not None:
user.avatar_url = update_data.avatar_url
await db.commit() # Use UNSET for avatar_url unless explicitly provided
await db.refresh(user) avatar_url = (
return user update_data.avatar_url if "avatar_url" in update_data.model_fields_set else UNSET
)
async def update_last_login(self, db: AsyncSession, user: User) -> User: return await self._user_repo.update(
user_id=user_id,
display_name=update_data.display_name,
avatar_url=avatar_url,
)
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
"""Update the user's last login timestamp. """Update the user's last login timestamp.
Args: Args:
db: Async database session. user_id: The user's UUID.
user: The user to update.
Returns: Returns:
The updated User instance. The updated UserEntry, or None if not found.
Example: Example:
user = await user_service.update_last_login(db, user) user = await service.update_last_login(user_id)
""" """
user.last_login = datetime.now(UTC) return await self._user_repo.update_last_login(user_id)
await db.commit()
await db.refresh(user)
return user
async def update_premium( async def update_premium(
self, self,
db: AsyncSession, user_id: UUID,
user: User,
premium_until: datetime | None, premium_until: datetime | None,
) -> User: ) -> UserEntry | None:
"""Update user's premium subscription status. """Update user's premium subscription status.
Args: Args:
db: Async database session. user_id: The user's UUID.
user: The user to update.
premium_until: When premium expires, or None to remove premium. premium_until: When premium expires, or None to remove premium.
Returns: Returns:
The updated User instance. The updated UserEntry, or None if not found.
Example: Example:
# Grant 30 days of premium # Grant 30 days of premium
expires = datetime.now(UTC) + timedelta(days=30) expires = datetime.now(UTC) + timedelta(days=30)
user = await user_service.update_premium(db, user, expires) user = await service.update_premium(user_id, expires)
# Remove premium # Remove premium
user = await user_service.update_premium(db, user, None) user = await service.update_premium(user_id, None)
""" """
if premium_until is not None: is_premium = premium_until is not None
user.is_premium = True return await self._user_repo.update_premium(
user.premium_until = premium_until user_id=user_id,
else: is_premium=is_premium,
user.is_premium = False premium_until=premium_until,
user.premium_until = None )
await db.commit() async def delete(self, user_id: UUID) -> bool:
await db.refresh(user)
return user
async def delete(self, db: AsyncSession, user: User) -> None:
"""Delete a user account. """Delete a user account.
This will cascade delete all related data (decks, collection, etc.) This will cascade delete all related data (decks, collection, etc.)
based on the model relationships. based on the database constraints.
Args: Args:
db: Async database session. user_id: The user's UUID.
user: The user to delete.
Returns:
True if deleted, False if not found.
Example: Example:
await user_service.delete(db, user) success = await service.delete(user_id)
""" """
await db.delete(user) return await self._user_repo.delete(user_id)
await db.commit()
# =========================================================================
# Linked Account Operations
# =========================================================================
async def get_linked_accounts(self, user_id: UUID) -> list[LinkedAccountEntry]:
"""Get all linked OAuth accounts for a user.
Args:
user_id: The user's UUID.
Returns:
List of LinkedAccountEntry.
"""
return await self._linked_repo.get_by_user(user_id)
async def get_linked_account( async def get_linked_account(
self, self,
db: AsyncSession,
provider: str, provider: str,
oauth_id: str, oauth_id: str,
) -> OAuthLinkedAccount | None: ) -> LinkedAccountEntry | None:
"""Get a linked account by provider and OAuth ID. """Get a linked account by provider and OAuth ID.
Args: Args:
db: Async database session.
provider: OAuth provider name (google, discord). provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider. oauth_id: Unique ID from the OAuth provider.
Returns: Returns:
OAuthLinkedAccount if found, None otherwise. LinkedAccountEntry if found, None otherwise.
""" """
result = await db.execute( return await self._linked_repo.get_by_provider(provider, oauth_id)
select(OAuthLinkedAccount).where(
OAuthLinkedAccount.provider == provider,
OAuthLinkedAccount.oauth_id == oauth_id,
)
)
return result.scalar_one_or_none()
async def link_oauth_account( async def link_oauth_account(
self, self,
db: AsyncSession, user_id: UUID,
user: User, user_oauth_provider: str,
oauth_info: OAuthUserInfo, oauth_info: OAuthUserInfo,
) -> OAuthLinkedAccount: ) -> LinkedAccountEntry:
"""Link an additional OAuth provider to a user account. """Link an additional OAuth provider to a user account.
Args: Args:
db: Async database session. user_id: The user's UUID.
user: The user to link the account to. user_oauth_provider: The user's primary OAuth provider.
oauth_info: OAuth information from the provider. oauth_info: OAuth information from the provider.
Returns: Returns:
The created OAuthLinkedAccount. The created LinkedAccountEntry.
Raises: Raises:
AccountLinkingError: If provider is already linked to this or another user. AccountLinkingError: If provider is already linked to this or another user.
Example: Example:
linked = await user_service.link_oauth_account(db, user, discord_info) linked = await service.link_oauth_account(user.id, user.oauth_provider, discord_info)
""" """
# Check if this provider+oauth_id is already linked to any user # Check if this provider+oauth_id is already linked to any user
existing = await self.get_linked_account(db, oauth_info.provider, oauth_info.oauth_id) existing = await self._linked_repo.get_by_provider(oauth_info.provider, oauth_info.oauth_id)
if existing: if existing:
if str(existing.user_id) == str(user.id): if existing.user_id == user_id:
raise AccountLinkingError( raise AccountLinkingError(
f"{oauth_info.provider.title()} account is already linked to your account" f"{oauth_info.provider.title()} account is already linked to your account"
) )
@ -369,36 +379,33 @@ class UserService:
) )
# Check if this is the user's primary OAuth provider # Check if this is the user's primary OAuth provider
if user.oauth_provider == oauth_info.provider: if user_oauth_provider == oauth_info.provider:
raise AccountLinkingError( raise AccountLinkingError(
f"{oauth_info.provider.title()} is your primary login provider" f"{oauth_info.provider.title()} is your primary login provider"
) )
# Check if user already has this provider linked # Check if user already has this provider linked
for linked in user.linked_accounts: linked_accounts = await self._linked_repo.get_by_user(user_id)
for linked in linked_accounts:
if linked.provider == oauth_info.provider: if linked.provider == oauth_info.provider:
raise AccountLinkingError( raise AccountLinkingError(
f"You already have a {oauth_info.provider.title()} account linked" f"You already have a {oauth_info.provider.title()} account linked"
) )
# Create the linked account # Create the linked account
linked_account = OAuthLinkedAccount( return await self._linked_repo.create(
user_id=str(user.id), user_id=user_id,
provider=oauth_info.provider, provider=oauth_info.provider,
oauth_id=oauth_info.oauth_id, oauth_id=oauth_info.oauth_id,
email=oauth_info.email, email=oauth_info.email,
display_name=oauth_info.name, display_name=oauth_info.name,
avatar_url=oauth_info.avatar_url, avatar_url=oauth_info.avatar_url,
) )
db.add(linked_account)
await db.commit()
await db.refresh(linked_account)
return linked_account
async def unlink_oauth_account( async def unlink_oauth_account(
self, self,
db: AsyncSession, user_id: UUID,
user: User, user_oauth_provider: str,
provider: str, provider: str,
) -> bool: ) -> bool:
"""Unlink an OAuth provider from a user account. """Unlink an OAuth provider from a user account.
@ -406,8 +413,8 @@ class UserService:
Cannot unlink the primary OAuth provider. Cannot unlink the primary OAuth provider.
Args: Args:
db: Async database session. user_id: The user's UUID.
user: The user to unlink from. user_oauth_provider: The user's primary OAuth provider.
provider: OAuth provider name to unlink. provider: OAuth provider name to unlink.
Returns: Returns:
@ -417,23 +424,12 @@ class UserService:
AccountLinkingError: If trying to unlink the primary provider. AccountLinkingError: If trying to unlink the primary provider.
Example: Example:
success = await user_service.unlink_oauth_account(db, user, "discord") success = await service.unlink_oauth_account(user.id, user.oauth_provider, "discord")
""" """
# Cannot unlink primary provider # Cannot unlink primary provider
if user.oauth_provider == provider: if user_oauth_provider == provider:
raise AccountLinkingError( raise AccountLinkingError(
f"Cannot unlink {provider.title()} - it is your primary login provider" f"Cannot unlink {provider.title()} - it is your primary login provider"
) )
# Find and delete the linked account return await self._linked_repo.delete(user_id, provider)
for linked in user.linked_accounts:
if linked.provider == provider:
await db.delete(linked)
await db.commit()
return True
return False
# Global service instance
user_service = UserService()

View File

@ -2,14 +2,19 @@
Tests the authentication endpoints including OAuth redirects, Tests the authentication endpoints including OAuth redirects,
token refresh, and logout. token refresh, and logout.
Uses FastAPI's dependency override pattern for proper dependency injection testing.
""" """
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID from uuid import UUID
import pytest
from fastapi import status from fastapi import status
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app.api.deps import get_user_service
class TestGoogleAuthRedirect: class TestGoogleAuthRedirect:
"""Tests for GET /api/auth/google endpoint.""" """Tests for GET /api/auth/google endpoint."""
@ -50,11 +55,32 @@ class TestDiscordAuthRedirect:
assert "not configured" in response.json()["detail"] assert "not configured" in response.json()["detail"]
@pytest.fixture
def mock_user_service_instance():
"""Create a mock UserService for dependency injection.
Returns a MagicMock with async methods configured.
"""
mock = MagicMock()
mock.get_by_id = AsyncMock()
mock.get_by_email = AsyncMock()
mock.get_by_oauth = AsyncMock()
mock.get_or_create_from_oauth = AsyncMock()
mock.update_last_login = AsyncMock()
return mock
class TestRefreshTokens: class TestRefreshTokens:
"""Tests for POST /api/auth/refresh endpoint.""" """Tests for POST /api/auth/refresh endpoint."""
def test_returns_new_access_token( def test_returns_new_access_token(
self, client: TestClient, test_user, refresh_token_data, mock_get_redis self,
app,
client: TestClient,
test_user,
refresh_token_data,
mock_get_redis,
mock_user_service_instance,
): ):
"""Test that refresh endpoint returns new access token for valid refresh token. """Test that refresh endpoint returns new access token for valid refresh token.
@ -70,25 +96,31 @@ class TestRefreshTokens:
asyncio.get_event_loop().run_until_complete(setup_token()) asyncio.get_event_loop().run_until_complete(setup_token())
# Mock user service to return our test user # Configure mock to return test user
with patch("app.api.auth.user_service") as mock_user_service: # Convert to UserEntry-like object
mock_user_service.get_by_id = AsyncMock(return_value=test_user) mock_user_entry = MagicMock()
mock_user_entry.id = test_user.id
mock_user_entry.email = test_user.email
mock_user_service_instance.get_by_id.return_value = mock_user_entry
with ( # Override the dependency on the test app
patch("app.api.auth.get_redis", mock_get_redis), app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_200_OK try:
data = response.json() response = client.post(
assert "access_token" in data "/api/auth/refresh",
assert data["refresh_token"] == refresh_token_data["token"] json={"refresh_token": refresh_token_data["token"]},
assert data["token_type"] == "bearer" )
assert "expires_in" in data
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "access_token" in data
assert data["refresh_token"] == refresh_token_data["token"]
assert data["token_type"] == "bearer"
assert "expires_in" in data
finally:
# Clean up override
app.dependency_overrides.pop(get_user_service, None)
def test_returns_401_for_invalid_token(self, client: TestClient): def test_returns_401_for_invalid_token(self, client: TestClient):
"""Test that refresh endpoint returns 401 for invalid refresh token.""" """Test that refresh endpoint returns 401 for invalid refresh token."""
@ -107,21 +139,23 @@ class TestRefreshTokens:
A refresh token not in Redis (revoked/expired) should be rejected. A refresh token not in Redis (revoked/expired) should be rejected.
""" """
# Don't store the token in Redis - simulating revocation # Don't store the token in Redis - simulating revocation
# The mock_get_redis is already patched via conftest's app fixture
with ( response = client.post(
patch("app.api.auth.get_redis", mock_get_redis), "/api/auth/refresh",
patch("app.services.token_store.get_redis", mock_get_redis), json={"refresh_token": refresh_token_data["token"]},
): )
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert "revoked" in response.json()["detail"] assert "revoked" in response.json()["detail"]
def test_returns_401_for_deleted_user( def test_returns_401_for_deleted_user(
self, client: TestClient, refresh_token_data, mock_get_redis self,
app,
client: TestClient,
refresh_token_data,
mock_get_redis,
mock_user_service_instance,
): ):
"""Test that refresh endpoint returns 401 if user no longer exists.""" """Test that refresh endpoint returns 401 if user no longer exists."""
# Store the token # Store the token
@ -134,21 +168,23 @@ class TestRefreshTokens:
asyncio.get_event_loop().run_until_complete(setup_token()) asyncio.get_event_loop().run_until_complete(setup_token())
# Mock user service to return None (user deleted) # Configure mock to return None (user deleted)
with patch("app.api.auth.user_service") as mock_user_service: mock_user_service_instance.get_by_id.return_value = None
mock_user_service.get_by_id = AsyncMock(return_value=None)
with ( # Override the dependency on the test app
patch("app.api.auth.get_redis", mock_get_redis), app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED try:
assert "User not found" in response.json()["detail"] response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert "User not found" in response.json()["detail"]
finally:
# Clean up override
app.dependency_overrides.pop(get_user_service, None)
class TestLogout: class TestLogout:
@ -171,14 +207,10 @@ class TestLogout:
key = asyncio.get_event_loop().run_until_complete(setup_and_check()) key = asyncio.get_event_loop().run_until_complete(setup_and_check())
# Logout # Logout
with ( response = client.post(
patch("app.api.auth.get_redis", mock_get_redis), "/api/auth/logout",
patch("app.services.token_store.get_redis", mock_get_redis), json={"refresh_token": refresh_token_data["token"]},
): )
response = client.post(
"/api/auth/logout",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
@ -214,7 +246,9 @@ class TestLogoutAll:
response = client.post("/api/auth/logout-all") response = client.post("/api/auth/logout-all")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_revokes_all_tokens(self, client: TestClient, test_user, access_token, mock_get_redis): def test_revokes_all_tokens(
self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
):
"""Test that logout-all revokes all refresh tokens for user. """Test that logout-all revokes all refresh tokens for user.
Should delete all tokens matching the user's ID pattern. Should delete all tokens matching the user's ID pattern.
@ -233,18 +267,16 @@ class TestLogoutAll:
asyncio.get_event_loop().run_until_complete(setup_tokens()) asyncio.get_event_loop().run_until_complete(setup_tokens())
# Mock dependencies # Set up mock db session to return test user when queried
with patch("app.api.deps.user_service") as mock_user_service: # The get_current_user dependency now does a direct DB query
mock_user_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with ( response = client.post(
patch("app.api.auth.get_redis", mock_get_redis), "/api/auth/logout-all",
patch("app.services.token_store.get_redis", mock_get_redis), headers={"Authorization": f"Bearer {access_token}"},
): )
response = client.post(
"/api/auth/logout-all",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT

View File

@ -1,32 +1,55 @@
"""Tests for users API endpoints. """Tests for users API endpoints.
Tests the user profile management endpoints. Tests the user profile management endpoints.
Uses FastAPI's dependency override pattern for proper dependency injection testing.
""" """
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID from uuid import UUID
import pytest
from fastapi import status from fastapi import status
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app.api.deps import get_user_service
from app.services.user_service import AccountLinkingError from app.services.user_service import AccountLinkingError
@pytest.fixture
def mock_user_service_instance():
"""Create a mock UserService for dependency injection.
Returns a MagicMock with async methods configured.
"""
mock = MagicMock()
mock.get_by_id = AsyncMock()
mock.get_by_email = AsyncMock()
mock.get_by_oauth = AsyncMock()
mock.update = AsyncMock()
mock.unlink_oauth_account = AsyncMock()
return mock
class TestGetCurrentUser: class TestGetCurrentUser:
"""Tests for GET /api/users/me endpoint.""" """Tests for GET /api/users/me endpoint."""
def test_returns_user_profile(self, client: TestClient, test_user, access_token): def test_returns_user_profile(
self, app, client: TestClient, test_user, access_token, mock_db_session
):
"""Test that endpoint returns user profile for authenticated user. """Test that endpoint returns user profile for authenticated user.
Should return the user's profile information. Should return the user's profile information.
""" """
with patch("app.api.deps.user_service") as mock_user_service: # Set up mock db session to return test user when queried
mock_user_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
response = client.get( response = client.get(
"/api/users/me", "/api/users/me",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@ -52,47 +75,93 @@ class TestGetCurrentUser:
class TestUpdateCurrentUser: class TestUpdateCurrentUser:
"""Tests for PATCH /api/users/me endpoint.""" """Tests for PATCH /api/users/me endpoint."""
def test_updates_display_name(self, client: TestClient, test_user, access_token): def test_updates_display_name(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint updates display_name when provided.""" """Test that endpoint updates display_name when provided."""
updated_user = test_user # Create an updated user mock
updated_user = MagicMock()
updated_user.id = test_user.id
updated_user.email = test_user.email
updated_user.display_name = "New Name" updated_user.display_name = "New Name"
updated_user.avatar_url = test_user.avatar_url
updated_user.is_premium = test_user.is_premium
updated_user.premium_until = test_user.premium_until
updated_user.created_at = test_user.created_at
with patch("app.api.deps.user_service") as mock_deps_service: # Set up db session to return test user for authentication
mock_deps_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service: # Set up user service mock
mock_user_service.update = AsyncMock(return_value=updated_user) mock_user_service_instance.update.return_value = updated_user
response = client.patch( # Override the dependency
"/api/users/me", app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
headers={"Authorization": f"Bearer {access_token}"},
json={"display_name": "New Name"},
)
assert response.status_code == status.HTTP_200_OK try:
data = response.json() response = client.patch(
assert data["display_name"] == "New Name" "/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
json={"display_name": "New Name"},
)
def test_updates_avatar_url(self, client: TestClient, test_user, access_token): assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["display_name"] == "New Name"
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_updates_avatar_url(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint updates avatar_url when provided.""" """Test that endpoint updates avatar_url when provided."""
updated_user = test_user # Create an updated user mock
updated_user = MagicMock()
updated_user.id = test_user.id
updated_user.email = test_user.email
updated_user.display_name = test_user.display_name
updated_user.avatar_url = "https://new-avatar.com/img.jpg" updated_user.avatar_url = "https://new-avatar.com/img.jpg"
updated_user.is_premium = test_user.is_premium
updated_user.premium_until = test_user.premium_until
updated_user.created_at = test_user.created_at
with patch("app.api.deps.user_service") as mock_deps_service: # Set up db session to return test user for authentication
mock_deps_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service: # Set up user service mock
mock_user_service.update = AsyncMock(return_value=updated_user) mock_user_service_instance.update.return_value = updated_user
response = client.patch( # Override the dependency
"/api/users/me", app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
headers={"Authorization": f"Bearer {access_token}"},
json={"avatar_url": "https://new-avatar.com/img.jpg"},
)
assert response.status_code == status.HTTP_200_OK try:
data = response.json() response = client.patch(
assert data["avatar_url"] == "https://new-avatar.com/img.jpg" "/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
json={"avatar_url": "https://new-avatar.com/img.jpg"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_requires_authentication(self, client: TestClient): def test_requires_authentication(self, client: TestClient):
"""Test that endpoint returns 401 without authentication.""" """Test that endpoint returns 401 without authentication."""
@ -106,18 +175,22 @@ class TestUpdateCurrentUser:
class TestGetLinkedAccounts: class TestGetLinkedAccounts:
"""Tests for GET /api/users/me/linked-accounts endpoint.""" """Tests for GET /api/users/me/linked-accounts endpoint."""
def test_returns_linked_accounts(self, client: TestClient, test_user, access_token): def test_returns_linked_accounts(
self, app, client: TestClient, test_user, access_token, mock_db_session
):
"""Test that endpoint returns list of linked OAuth accounts. """Test that endpoint returns list of linked OAuth accounts.
Should include the primary provider and any linked accounts. Should include the primary provider and any linked accounts.
""" """
with patch("app.api.deps.user_service") as mock_user_service: # Set up db session to return test user for authentication
mock_user_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
response = client.get( response = client.get(
"/api/users/me/linked-accounts", "/api/users/me/linked-accounts",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@ -135,7 +208,7 @@ class TestGetActiveSessions:
"""Tests for GET /api/users/me/sessions endpoint.""" """Tests for GET /api/users/me/sessions endpoint."""
def test_returns_session_count( def test_returns_session_count(
self, client: TestClient, test_user, access_token, mock_get_redis self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
): ):
"""Test that endpoint returns count of active sessions. """Test that endpoint returns count of active sessions.
@ -154,14 +227,16 @@ class TestGetActiveSessions:
asyncio.get_event_loop().run_until_complete(setup_tokens()) asyncio.get_event_loop().run_until_complete(setup_tokens())
with patch("app.api.deps.user_service") as mock_user_service: # Set up db session to return test user for authentication
mock_user_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.services.token_store.get_redis", mock_get_redis): with patch("app.services.token_store.get_redis", mock_get_redis):
response = client.get( response = client.get(
"/api/users/me/sessions", "/api/users/me/sessions",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@ -177,78 +252,128 @@ class TestGetActiveSessions:
class TestUnlinkOAuthAccount: class TestUnlinkOAuthAccount:
"""Tests for DELETE /api/users/me/link/{provider} endpoint.""" """Tests for DELETE /api/users/me/link/{provider} endpoint."""
def test_unlinks_provider_successfully(self, client: TestClient, test_user, access_token): def test_unlinks_provider_successfully(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint successfully unlinks a provider. """Test that endpoint successfully unlinks a provider.
Should return 204 when provider is unlinked. Should return 204 when provider is unlinked.
""" """
with patch("app.api.deps.user_service") as mock_deps_service: # Set up db session to return test user for authentication
mock_deps_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service: # Set up user service mock
mock_user_service.unlink_oauth_account = AsyncMock(return_value=True) mock_user_service_instance.unlink_oauth_account.return_value = True
response = client.delete( # Override the dependency
"/api/users/me/link/discord", app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_204_NO_CONTENT try:
response = client.delete(
"/api/users/me/link/discord",
headers={"Authorization": f"Bearer {access_token}"},
)
def test_returns_404_if_not_linked(self, client: TestClient, test_user, access_token): assert response.status_code == status.HTTP_204_NO_CONTENT
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_returns_404_if_not_linked(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint returns 404 if provider isn't linked. """Test that endpoint returns 404 if provider isn't linked.
Should return 404 when trying to unlink a provider that isn't linked. Should return 404 when trying to unlink a provider that isn't linked.
""" """
with patch("app.api.deps.user_service") as mock_deps_service: # Set up db session to return test user for authentication
mock_deps_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service: # Set up user service mock
mock_user_service.unlink_oauth_account = AsyncMock(return_value=False) mock_user_service_instance.unlink_oauth_account.return_value = False
response = client.delete( # Override the dependency
"/api/users/me/link/discord", app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND try:
assert "not linked" in response.json()["detail"].lower() response = client.delete(
"/api/users/me/link/discord",
headers={"Authorization": f"Bearer {access_token}"},
)
def test_returns_400_for_primary_provider(self, client: TestClient, test_user, access_token): assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not linked" in response.json()["detail"].lower()
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_returns_400_for_primary_provider(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint returns 400 when trying to unlink primary provider. """Test that endpoint returns 400 when trying to unlink primary provider.
Cannot unlink the provider used to create the account. Cannot unlink the provider used to create the account.
""" """
with patch("app.api.deps.user_service") as mock_deps_service: # Set up db session to return test user for authentication
mock_deps_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service: # Set up user service mock to raise AccountLinkingError
mock_user_service.unlink_oauth_account = AsyncMock( mock_user_service_instance.unlink_oauth_account.side_effect = AccountLinkingError(
side_effect=AccountLinkingError( "Cannot unlink Google - it is your primary login provider"
"Cannot unlink Google - it is your primary login provider" )
)
)
response = client.delete( # Override the dependency
"/api/users/me/link/google", app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST try:
assert "primary" in response.json()["detail"].lower() response = client.delete(
"/api/users/me/link/google",
headers={"Authorization": f"Bearer {access_token}"},
)
def test_returns_400_for_unknown_provider(self, client: TestClient, test_user, access_token): assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "primary" in response.json()["detail"].lower()
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_returns_400_for_unknown_provider(
self, app, client: TestClient, test_user, access_token, mock_db_session
):
"""Test that endpoint returns 400 for unknown provider. """Test that endpoint returns 400 for unknown provider.
Only 'google' and 'discord' are valid providers. Only 'google' and 'discord' are valid providers.
""" """
with patch("app.api.deps.user_service") as mock_deps_service: # Set up db session to return test user for authentication
mock_deps_service.get_by_id = AsyncMock(return_value=test_user) mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
response = client.delete( response = client.delete(
"/api/users/me/link/twitter", "/api/users/me/link/twitter",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
) )
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "unknown provider" in response.json()["detail"].lower() assert "unknown provider" in response.json()["detail"].lower()

View File

@ -2,25 +2,43 @@
Tests the user service CRUD operations and OAuth-based user creation. Tests the user service CRUD operations and OAuth-based user creation.
Uses real Postgres via the db_session fixture from conftest. Uses real Postgres via the db_session fixture from conftest.
The UserService now uses dependency injection with repositories injected
via constructor, so we create a fresh service instance per test.
""" """
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from uuid import UUID
import pytest import pytest
from app.db.models import User from app.db.models import User
from app.db.models.oauth_account import OAuthLinkedAccount from app.db.models.oauth_account import OAuthLinkedAccount
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
from app.services.user_service import AccountLinkingError, user_service from app.repositories.postgres.user import PostgresUserRepository
from app.schemas.user import OAuthUserInfo, UserUpdate
from app.services.user_service import AccountLinkingError, UserService
# Import db_session fixture from db conftest # Import db_session fixture from db conftest
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@pytest.fixture
def user_service(db_session):
"""Create UserService with real PostgreSQL repositories.
This fixture provides a properly constructed UserService for each test,
following the dependency injection pattern used in production.
"""
user_repo = PostgresUserRepository(db_session)
linked_repo = PostgresLinkedAccountRepository(db_session)
return UserService(user_repo, linked_repo)
class TestGetById: class TestGetById:
"""Tests for get_by_id method.""" """Tests for get_by_id method."""
async def test_returns_user_when_found(self, db_session): async def test_returns_user_when_found(self, db_session, user_service):
"""Test that get_by_id returns user when it exists. """Test that get_by_id returns user when it exists.
Creates a user and verifies it can be retrieved by ID. Creates a user and verifies it can be retrieved by ID.
@ -36,26 +54,24 @@ class TestGetById:
await db_session.commit() await db_session.commit()
# Retrieve by ID # Retrieve by ID
from uuid import UUID
user_id = UUID(user.id) if isinstance(user.id, str) else user.id user_id = UUID(user.id) if isinstance(user.id, str) else user.id
result = await user_service.get_by_id(db_session, user_id) result = await user_service.get_by_id(user_id)
assert result is not None assert result is not None
assert result.email == "test@example.com" assert result.email == "test@example.com"
async def test_returns_none_when_not_found(self, db_session): async def test_returns_none_when_not_found(self, user_service):
"""Test that get_by_id returns None for nonexistent users.""" """Test that get_by_id returns None for nonexistent users."""
from uuid import uuid4 from uuid import uuid4
result = await user_service.get_by_id(db_session, uuid4()) result = await user_service.get_by_id(uuid4())
assert result is None assert result is None
class TestGetByEmail: class TestGetByEmail:
"""Tests for get_by_email method.""" """Tests for get_by_email method."""
async def test_returns_user_when_found(self, db_session): async def test_returns_user_when_found(self, db_session, user_service):
"""Test that get_by_email returns user when it exists.""" """Test that get_by_email returns user when it exists."""
user = User( user = User(
email="findme@example.com", email="findme@example.com",
@ -66,21 +82,21 @@ class TestGetByEmail:
db_session.add(user) db_session.add(user)
await db_session.commit() await db_session.commit()
result = await user_service.get_by_email(db_session, "findme@example.com") result = await user_service.get_by_email("findme@example.com")
assert result is not None assert result is not None
assert result.display_name == "Find Me" assert result.display_name == "Find Me"
async def test_returns_none_when_not_found(self, db_session): async def test_returns_none_when_not_found(self, user_service):
"""Test that get_by_email returns None for nonexistent emails.""" """Test that get_by_email returns None for nonexistent emails."""
result = await user_service.get_by_email(db_session, "nobody@example.com") result = await user_service.get_by_email("nobody@example.com")
assert result is None assert result is None
class TestGetByOAuth: class TestGetByOAuth:
"""Tests for get_by_oauth method.""" """Tests for get_by_oauth method."""
async def test_returns_user_when_found(self, db_session): async def test_returns_user_when_found(self, db_session, user_service):
"""Test that get_by_oauth returns user for matching provider+id.""" """Test that get_by_oauth returns user for matching provider+id."""
user = User( user = User(
email="oauth@example.com", email="oauth@example.com",
@ -91,12 +107,12 @@ class TestGetByOAuth:
db_session.add(user) db_session.add(user)
await db_session.commit() await db_session.commit()
result = await user_service.get_by_oauth(db_session, "google", "google-unique-id") result = await user_service.get_by_oauth("google", "google-unique-id")
assert result is not None assert result is not None
assert result.email == "oauth@example.com" assert result.email == "oauth@example.com"
async def test_returns_none_for_wrong_provider(self, db_session): async def test_returns_none_for_wrong_provider(self, db_session, user_service):
"""Test that get_by_oauth returns None if provider doesn't match.""" """Test that get_by_oauth returns None if provider doesn't match."""
user = User( user = User(
email="oauth2@example.com", email="oauth2@example.com",
@ -108,30 +124,28 @@ class TestGetByOAuth:
await db_session.commit() await db_session.commit()
# Same ID, different provider # Same ID, different provider
result = await user_service.get_by_oauth(db_session, "discord", "google-id-2") result = await user_service.get_by_oauth("discord", "google-id-2")
assert result is None assert result is None
async def test_returns_none_when_not_found(self, db_session): async def test_returns_none_when_not_found(self, user_service):
"""Test that get_by_oauth returns None for nonexistent OAuth.""" """Test that get_by_oauth returns None for nonexistent OAuth."""
result = await user_service.get_by_oauth(db_session, "google", "nonexistent") result = await user_service.get_by_oauth("google", "nonexistent")
assert result is None assert result is None
class TestCreate: class TestCreate:
"""Tests for create method.""" """Tests for create method."""
async def test_creates_user_with_all_fields(self, db_session): async def test_creates_user_with_all_fields(self, user_service):
"""Test that create properly persists all user fields.""" """Test that create properly persists all user fields."""
user_data = UserCreate( result = await user_service.create(
email="new@example.com", email="new@example.com",
display_name="New User", display_name="New User",
avatar_url="https://example.com/avatar.jpg",
oauth_provider="discord", oauth_provider="discord",
oauth_id="discord-new-id", oauth_id="discord-new-id",
avatar_url="https://example.com/avatar.jpg",
) )
result = await user_service.create(db_session, user_data)
assert result.id is not None assert result.id is not None
assert result.email == "new@example.com" assert result.email == "new@example.com"
assert result.display_name == "New User" assert result.display_name == "New User"
@ -141,24 +155,22 @@ class TestCreate:
assert result.is_premium is False assert result.is_premium is False
assert result.premium_until is None assert result.premium_until is None
async def test_creates_user_without_avatar(self, db_session): async def test_creates_user_without_avatar(self, user_service):
"""Test that create works without optional avatar_url.""" """Test that create works without optional avatar_url."""
user_data = UserCreate( result = await user_service.create(
email="noavatar@example.com", email="noavatar@example.com",
display_name="No Avatar", display_name="No Avatar",
oauth_provider="google", oauth_provider="google",
oauth_id="google-no-avatar", oauth_id="google-no-avatar",
) )
result = await user_service.create(db_session, user_data)
assert result.avatar_url is None assert result.avatar_url is None
class TestCreateFromOAuth: class TestCreateFromOAuth:
"""Tests for create_from_oauth method.""" """Tests for create_from_oauth method."""
async def test_creates_user_from_oauth_info(self, db_session): async def test_creates_user_from_oauth_info(self, user_service):
"""Test that create_from_oauth converts OAuthUserInfo to User.""" """Test that create_from_oauth converts OAuthUserInfo to User."""
oauth_info = OAuthUserInfo( oauth_info = OAuthUserInfo(
provider="google", provider="google",
@ -168,7 +180,7 @@ class TestCreateFromOAuth:
avatar_url="https://google.com/avatar.jpg", avatar_url="https://google.com/avatar.jpg",
) )
result = await user_service.create_from_oauth(db_session, oauth_info) result = await user_service.create_from_oauth(oauth_info)
assert result.email == "oauthcreate@example.com" assert result.email == "oauthcreate@example.com"
assert result.display_name == "OAuth Created User" assert result.display_name == "OAuth Created User"
@ -179,7 +191,7 @@ class TestCreateFromOAuth:
class TestGetOrCreateFromOAuth: class TestGetOrCreateFromOAuth:
"""Tests for get_or_create_from_oauth method.""" """Tests for get_or_create_from_oauth method."""
async def test_returns_existing_user_by_oauth(self, db_session): async def test_returns_existing_user_by_oauth(self, db_session, user_service):
"""Test that existing user is returned when OAuth matches. """Test that existing user is returned when OAuth matches.
Verifies the method returns (user, False) for existing users. Verifies the method returns (user, False) for existing users.
@ -202,12 +214,12 @@ class TestGetOrCreateFromOAuth:
name="Existing", name="Existing",
) )
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info) result, created = await user_service.get_or_create_from_oauth(oauth_info)
assert created is False assert created is False
assert result.id == existing.id assert str(result.id) == str(existing.id)
async def test_links_existing_user_by_email(self, db_session): async def test_links_existing_user_by_email(self, db_session, user_service):
"""Test that OAuth is linked when email matches existing user. """Test that OAuth is linked when email matches existing user.
If a user exists with the same email but different OAuth, If a user exists with the same email but different OAuth,
@ -232,15 +244,15 @@ class TestGetOrCreateFromOAuth:
avatar_url="https://discord.com/avatar.jpg", avatar_url="https://discord.com/avatar.jpg",
) )
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info) result, created = await user_service.get_or_create_from_oauth(oauth_info)
assert created is False assert created is False
assert result.id == existing.id assert str(result.id) == str(existing.id)
# OAuth should be updated to Discord # OAuth should be updated to Discord
assert result.oauth_provider == "discord" assert result.oauth_provider == "discord"
assert result.oauth_id == "discord-link-id" assert result.oauth_id == "discord-link-id"
async def test_creates_new_user_when_not_found(self, db_session): async def test_creates_new_user_when_not_found(self, user_service):
"""Test that new user is created when no match exists. """Test that new user is created when no match exists.
Verifies the method returns (user, True) for new users. Verifies the method returns (user, True) for new users.
@ -252,7 +264,7 @@ class TestGetOrCreateFromOAuth:
name="Brand New", name="Brand New",
) )
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info) result, created = await user_service.get_or_create_from_oauth(oauth_info)
assert created is True assert created is True
assert result.email == "brandnew@example.com" assert result.email == "brandnew@example.com"
@ -261,7 +273,7 @@ class TestGetOrCreateFromOAuth:
class TestUpdate: class TestUpdate:
"""Tests for update method.""" """Tests for update method."""
async def test_updates_display_name(self, db_session): async def test_updates_display_name(self, db_session, user_service):
"""Test that update changes display_name when provided.""" """Test that update changes display_name when provided."""
user = User( user = User(
email="update@example.com", email="update@example.com",
@ -272,12 +284,13 @@ class TestUpdate:
db_session.add(user) db_session.add(user)
await db_session.commit() await db_session.commit()
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
update_data = UserUpdate(display_name="New Name") update_data = UserUpdate(display_name="New Name")
result = await user_service.update(db_session, user, update_data) result = await user_service.update(user_id, update_data)
assert result.display_name == "New Name" assert result.display_name == "New Name"
async def test_updates_avatar_url(self, db_session): async def test_updates_avatar_url(self, db_session, user_service):
"""Test that update changes avatar_url when provided.""" """Test that update changes avatar_url when provided."""
user = User( user = User(
email="avatar@example.com", email="avatar@example.com",
@ -288,12 +301,13 @@ class TestUpdate:
db_session.add(user) db_session.add(user)
await db_session.commit() await db_session.commit()
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg") update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
result = await user_service.update(db_session, user, update_data) result = await user_service.update(user_id, update_data)
assert result.avatar_url == "https://new-avatar.com/img.jpg" assert result.avatar_url == "https://new-avatar.com/img.jpg"
async def test_ignores_none_values(self, db_session): async def test_ignores_none_values(self, db_session, user_service):
"""Test that update doesn't change fields set to None. """Test that update doesn't change fields set to None.
Only explicitly provided fields should be updated. Only explicitly provided fields should be updated.
@ -309,8 +323,9 @@ class TestUpdate:
await db_session.commit() await db_session.commit()
# Update only display_name, leave avatar alone # Update only display_name, leave avatar alone
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
update_data = UserUpdate(display_name="Changed") update_data = UserUpdate(display_name="Changed")
result = await user_service.update(db_session, user, update_data) result = await user_service.update(user_id, update_data)
assert result.display_name == "Changed" assert result.display_name == "Changed"
assert result.avatar_url == "https://keep.com/avatar.jpg" assert result.avatar_url == "https://keep.com/avatar.jpg"
@ -319,7 +334,7 @@ class TestUpdate:
class TestUpdateLastLogin: class TestUpdateLastLogin:
"""Tests for update_last_login method.""" """Tests for update_last_login method."""
async def test_updates_last_login_timestamp(self, db_session): async def test_updates_last_login_timestamp(self, db_session, user_service):
"""Test that update_last_login sets current timestamp.""" """Test that update_last_login sets current timestamp."""
user = User( user = User(
email="login@example.com", email="login@example.com",
@ -332,8 +347,9 @@ class TestUpdateLastLogin:
assert user.last_login is None assert user.last_login is None
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
before = datetime.now(UTC) before = datetime.now(UTC)
result = await user_service.update_last_login(db_session, user) result = await user_service.update_last_login(user_id)
after = datetime.now(UTC) after = datetime.now(UTC)
assert result.last_login is not None assert result.last_login is not None
@ -344,7 +360,7 @@ class TestUpdateLastLogin:
class TestUpdatePremium: class TestUpdatePremium:
"""Tests for update_premium method.""" """Tests for update_premium method."""
async def test_grants_premium(self, db_session): async def test_grants_premium(self, db_session, user_service):
"""Test that update_premium sets premium status and expiration.""" """Test that update_premium sets premium status and expiration."""
user = User( user = User(
email="premium@example.com", email="premium@example.com",
@ -357,13 +373,14 @@ class TestUpdatePremium:
assert user.is_premium is False assert user.is_premium is False
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
expires = datetime.now(UTC) + timedelta(days=30) expires = datetime.now(UTC) + timedelta(days=30)
result = await user_service.update_premium(db_session, user, expires) result = await user_service.update_premium(user_id, expires)
assert result.is_premium is True assert result.is_premium is True
assert result.premium_until == expires assert result.premium_until == expires
async def test_removes_premium(self, db_session): async def test_removes_premium(self, db_session, user_service):
"""Test that update_premium with None removes premium status.""" """Test that update_premium with None removes premium status."""
user = User( user = User(
email="unpremium@example.com", email="unpremium@example.com",
@ -376,7 +393,8 @@ class TestUpdatePremium:
db_session.add(user) db_session.add(user)
await db_session.commit() await db_session.commit()
result = await user_service.update_premium(db_session, user, None) user_id = UUID(user.id) if isinstance(user.id, str) else user.id
result = await user_service.update_premium(user_id, None)
assert result.is_premium is False assert result.is_premium is False
assert result.premium_until is None assert result.premium_until is None
@ -385,7 +403,7 @@ class TestUpdatePremium:
class TestDelete: class TestDelete:
"""Tests for delete method.""" """Tests for delete method."""
async def test_deletes_user(self, db_session): async def test_deletes_user(self, db_session, user_service):
"""Test that delete removes user from database.""" """Test that delete removes user from database."""
user = User( user = User(
email="delete@example.com", email="delete@example.com",
@ -396,22 +414,18 @@ class TestDelete:
db_session.add(user) db_session.add(user)
await db_session.commit() await db_session.commit()
user_id = user.id user_id = UUID(user.id) if isinstance(user.id, str) else user.id
await user_service.delete(db_session, user) await user_service.delete(user_id)
# Verify user is gone # Verify user is gone
from uuid import UUID result = await user_service.get_by_id(user_id)
result = await user_service.get_by_id(
db_session, UUID(user_id) if isinstance(user_id, str) else user_id
)
assert result is None assert result is None
class TestGetLinkedAccount: class TestGetLinkedAccount:
"""Tests for get_linked_account method.""" """Tests for get_linked_account method."""
async def test_returns_linked_account_when_found(self, db_session): async def test_returns_linked_account_when_found(self, db_session, user_service):
"""Test that get_linked_account returns account when it exists. """Test that get_linked_account returns account when it exists.
Creates a user with a linked account and verifies it can be retrieved. Creates a user with a linked account and verifies it can be retrieved.
@ -437,22 +451,22 @@ class TestGetLinkedAccount:
await db_session.commit() await db_session.commit()
# Retrieve linked account # Retrieve linked account
result = await user_service.get_linked_account(db_session, "discord", "discord-linked-123") result = await user_service.get_linked_account("discord", "discord-linked-123")
assert result is not None assert result is not None
assert result.provider == "discord" assert result.provider == "discord"
assert result.oauth_id == "discord-linked-123" assert result.oauth_id == "discord-linked-123"
async def test_returns_none_when_not_found(self, db_session): async def test_returns_none_when_not_found(self, user_service):
"""Test that get_linked_account returns None for nonexistent accounts.""" """Test that get_linked_account returns None for nonexistent accounts."""
result = await user_service.get_linked_account(db_session, "discord", "nonexistent-id") result = await user_service.get_linked_account("discord", "nonexistent-id")
assert result is None assert result is None
class TestLinkOAuthAccount: class TestLinkOAuthAccount:
"""Tests for link_oauth_account method.""" """Tests for link_oauth_account method."""
async def test_links_new_provider(self, db_session): async def test_links_new_provider(self, db_session, user_service):
"""Test that link_oauth_account successfully links a new provider. """Test that link_oauth_account successfully links a new provider.
Creates a Google user and links Discord to them. Creates a Google user and links Discord to them.
@ -468,6 +482,8 @@ class TestLinkOAuthAccount:
await db_session.commit() await db_session.commit()
await db_session.refresh(user) await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
# Link Discord # Link Discord
discord_info = OAuthUserInfo( discord_info = OAuthUserInfo(
provider="discord", provider="discord",
@ -477,16 +493,16 @@ class TestLinkOAuthAccount:
avatar_url="https://discord.com/avatar.png", avatar_url="https://discord.com/avatar.png",
) )
result = await user_service.link_oauth_account(db_session, user, discord_info) result = await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
assert result is not None assert result is not None
assert result.provider == "discord" assert result.provider == "discord"
assert result.oauth_id == "discord-456" assert result.oauth_id == "discord-456"
assert result.email == "discord@example.com" assert result.email == "discord@example.com"
assert result.display_name == "Discord Name" assert result.display_name == "Discord Name"
assert str(result.user_id) == str(user.id) assert str(result.user_id) == str(user_id)
async def test_raises_error_if_already_linked_to_same_user(self, db_session): async def test_raises_error_if_already_linked_to_same_user(self, db_session, user_service):
"""Test that linking same provider twice raises error. """Test that linking same provider twice raises error.
A user cannot have the same provider linked multiple times. A user cannot have the same provider linked multiple times.
@ -501,6 +517,8 @@ class TestLinkOAuthAccount:
await db_session.commit() await db_session.commit()
await db_session.refresh(user) await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
# Link Discord first time # Link Discord first time
discord_info = OAuthUserInfo( discord_info = OAuthUserInfo(
provider="discord", provider="discord",
@ -508,16 +526,15 @@ class TestLinkOAuthAccount:
email="first@discord.com", email="first@discord.com",
name="First", name="First",
) )
await user_service.link_oauth_account(db_session, user, discord_info) await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
await db_session.refresh(user)
# Try to link same Discord account again # Try to link same Discord account again
with pytest.raises(AccountLinkingError) as exc_info: with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user, discord_info) await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
assert "already linked to your account" in str(exc_info.value) assert "already linked to your account" in str(exc_info.value)
async def test_raises_error_if_linked_to_another_user(self, db_session): async def test_raises_error_if_linked_to_another_user(self, db_session, user_service):
"""Test that linking account already linked to another user raises error. """Test that linking account already linked to another user raises error.
The same OAuth provider+ID cannot be linked to multiple users. The same OAuth provider+ID cannot be linked to multiple users.
@ -533,13 +550,15 @@ class TestLinkOAuthAccount:
await db_session.commit() await db_session.commit()
await db_session.refresh(user1) await db_session.refresh(user1)
user1_id = UUID(user1.id) if isinstance(user1.id, str) else user1.id
discord_info = OAuthUserInfo( discord_info = OAuthUserInfo(
provider="discord", provider="discord",
oauth_id="shared-discord", oauth_id="shared-discord",
email="shared@discord.com", email="shared@discord.com",
name="Shared", name="Shared",
) )
await user_service.link_oauth_account(db_session, user1, discord_info) await user_service.link_oauth_account(user1_id, user1.oauth_provider, discord_info)
# Create second user # Create second user
user2 = User( user2 = User(
@ -552,13 +571,15 @@ class TestLinkOAuthAccount:
await db_session.commit() await db_session.commit()
await db_session.refresh(user2) await db_session.refresh(user2)
user2_id = UUID(user2.id) if isinstance(user2.id, str) else user2.id
# Try to link same Discord account to second user # Try to link same Discord account to second user
with pytest.raises(AccountLinkingError) as exc_info: with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user2, discord_info) await user_service.link_oauth_account(user2_id, user2.oauth_provider, discord_info)
assert "already linked to another user" in str(exc_info.value) assert "already linked to another user" in str(exc_info.value)
async def test_raises_error_if_linking_primary_provider(self, db_session): async def test_raises_error_if_linking_primary_provider(self, db_session, user_service):
"""Test that linking the same provider as primary raises error. """Test that linking the same provider as primary raises error.
User cannot link Google if they already signed up with Google. User cannot link Google if they already signed up with Google.
@ -573,6 +594,8 @@ class TestLinkOAuthAccount:
await db_session.commit() await db_session.commit()
await db_session.refresh(user) await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
# Try to link another Google account # Try to link another Google account
google_info = OAuthUserInfo( google_info = OAuthUserInfo(
provider="google", provider="google",
@ -582,7 +605,7 @@ class TestLinkOAuthAccount:
) )
with pytest.raises(AccountLinkingError) as exc_info: with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user, google_info) await user_service.link_oauth_account(user_id, user.oauth_provider, google_info)
assert "primary login provider" in str(exc_info.value) assert "primary login provider" in str(exc_info.value)
@ -590,7 +613,7 @@ class TestLinkOAuthAccount:
class TestUnlinkOAuthAccount: class TestUnlinkOAuthAccount:
"""Tests for unlink_oauth_account method.""" """Tests for unlink_oauth_account method."""
async def test_unlinks_linked_account(self, db_session): async def test_unlinks_linked_account(self, db_session, user_service):
"""Test that unlink_oauth_account removes a linked account. """Test that unlink_oauth_account removes a linked account.
Links Discord then unlinks it successfully. Links Discord then unlinks it successfully.
@ -605,6 +628,8 @@ class TestUnlinkOAuthAccount:
await db_session.commit() await db_session.commit()
await db_session.refresh(user) await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
# Link Discord # Link Discord
discord_info = OAuthUserInfo( discord_info = OAuthUserInfo(
provider="discord", provider="discord",
@ -612,22 +637,18 @@ class TestUnlinkOAuthAccount:
email="discord@unlink.com", email="discord@unlink.com",
name="Discord Unlink", name="Discord Unlink",
) )
await user_service.link_oauth_account(db_session, user, discord_info) await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
await db_session.refresh(user)
# Verify linked
assert len(user.linked_accounts) == 1
# Unlink # Unlink
result = await user_service.unlink_oauth_account(db_session, user, "discord") result = await user_service.unlink_oauth_account(user_id, user.oauth_provider, "discord")
assert result is True assert result is True
# Verify unlinked # Verify unlinked
linked = await user_service.get_linked_account(db_session, "discord", "discord-unlink") linked = await user_service.get_linked_account("discord", "discord-unlink")
assert linked is None assert linked is None
async def test_returns_false_if_not_linked(self, db_session): async def test_returns_false_if_not_linked(self, db_session, user_service):
"""Test that unlink returns False if provider isn't linked.""" """Test that unlink returns False if provider isn't linked."""
user = User( user = User(
email="not-linked@example.com", email="not-linked@example.com",
@ -639,11 +660,12 @@ class TestUnlinkOAuthAccount:
await db_session.commit() await db_session.commit()
await db_session.refresh(user) await db_session.refresh(user)
result = await user_service.unlink_oauth_account(db_session, user, "discord") user_id = UUID(user.id) if isinstance(user.id, str) else user.id
result = await user_service.unlink_oauth_account(user_id, user.oauth_provider, "discord")
assert result is False assert result is False
async def test_raises_error_if_unlinking_primary(self, db_session): async def test_raises_error_if_unlinking_primary(self, db_session, user_service):
"""Test that unlinking primary provider raises error. """Test that unlinking primary provider raises error.
User cannot unlink their primary OAuth provider. User cannot unlink their primary OAuth provider.
@ -658,7 +680,9 @@ class TestUnlinkOAuthAccount:
await db_session.commit() await db_session.commit()
await db_session.refresh(user) await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
with pytest.raises(AccountLinkingError) as exc_info: with pytest.raises(AccountLinkingError) as exc_info:
await user_service.unlink_oauth_account(db_session, user, "google") await user_service.unlink_oauth_account(user_id, user.oauth_provider, "google")
assert "primary login provider" in str(exc_info.value) assert "primary login provider" in str(exc_info.value)