Implement UserRepository pattern with dependency injection

- Add UserRepository and LinkedAccountRepository protocols to protocols.py
- Add UserEntry and LinkedAccountEntry DTOs for service layer decoupling
- Implement PostgresUserRepository and PostgresLinkedAccountRepository
- Refactor UserService to use constructor-injected repositories
- Add get_user_service factory and UserServiceDep to API deps
- Update auth.py and users.py endpoints to use UserServiceDep
- Rewrite tests to use FastAPI dependency overrides (no monkey patching)

This follows the established repository pattern used by DeckService and
CollectionService, enabling future offline fork support.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-01-30 07:30:16 -06:00
parent f6e8ab5f67
commit 7fcb86ff51
12 changed files with 1316 additions and 454 deletions

View File

@ -31,7 +31,7 @@ import secrets
from fastapi import APIRouter, HTTPException, Query, status
from fastapi.responses import RedirectResponse
from app.api.deps import CurrentUser, DbSession
from app.api.deps import CurrentUser, UserServiceDep
from app.config import settings
from app.db.redis import get_redis
from app.schemas.auth import RefreshTokenRequest, TokenResponse
@ -45,7 +45,7 @@ from app.services.jwt_service import (
from app.services.oauth.discord import DiscordOAuthError, discord_oauth
from app.services.oauth.google import GoogleOAuthError, google_oauth
from app.services.token_store import token_store
from app.services.user_service import AccountLinkingError, user_service
from app.services.user_service import AccountLinkingError
router = APIRouter(prefix="/auth", tags=["auth"])
@ -150,7 +150,7 @@ async def google_auth_redirect(
@router.get("/google/callback")
async def google_auth_callback(
db: DbSession,
user_service: UserServiceDep,
code: str = Query(..., description="Authorization code from Google"),
state: str = Query(..., description="State parameter for CSRF validation"),
) -> TokenResponse:
@ -182,10 +182,10 @@ async def google_auth_callback(
user_info = await google_oauth.get_user_info(code, oauth_callback)
# Get or create user
user, created = await user_service.get_or_create_from_oauth(db, user_info)
user, _created = await user_service.get_or_create_from_oauth(user_info)
# Update last login
await user_service.update_last_login(db, user)
await user_service.update_last_login(user.id)
# Create tokens
return await _create_tokens_for_user(user.id)
@ -243,7 +243,7 @@ async def discord_auth_redirect(
@router.get("/discord/callback")
async def discord_auth_callback(
db: DbSession,
user_service: UserServiceDep,
code: str = Query(..., description="Authorization code from Discord"),
state: str = Query(..., description="State parameter for CSRF validation"),
) -> RedirectResponse:
@ -276,10 +276,10 @@ async def discord_auth_callback(
user_info = await discord_oauth.get_user_info(code, oauth_callback)
# Get or create user
user, created = await user_service.get_or_create_from_oauth(db, user_info)
user, _created = await user_service.get_or_create_from_oauth(user_info)
# Update last login
await user_service.update_last_login(db, user)
await user_service.update_last_login(user.id)
# Create tokens
tokens = await _create_tokens_for_user(user.id)
@ -307,7 +307,7 @@ async def discord_auth_callback(
@router.post("/refresh", response_model=TokenResponse)
async def refresh_tokens(
db: DbSession,
user_service: UserServiceDep,
request: RefreshTokenRequest,
) -> TokenResponse:
"""Refresh access token using refresh token.
@ -344,7 +344,7 @@ async def refresh_tokens(
)
# Verify user still exists
user = await user_service.get_by_id(db, user_id)
user = await user_service.get_by_id(user_id)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -489,7 +489,7 @@ async def google_link_redirect(
@router.get("/link/google/callback")
async def google_link_callback(
db: DbSession,
user_service: UserServiceDep,
code: str = Query(..., description="Authorization code from Google"),
state: str = Query(..., description="State parameter for CSRF validation"),
) -> RedirectResponse:
@ -523,7 +523,7 @@ async def google_link_callback(
from uuid import UUID
user_id = UUID(user_id_str)
user = await user_service.get_by_id(db, user_id)
user = await user_service.get_by_id(user_id)
if user is None:
return RedirectResponse(
url=f"{redirect_uri}?error=user_not_found",
@ -531,7 +531,7 @@ async def google_link_callback(
)
# Link the account
await user_service.link_oauth_account(db, user, oauth_info)
await user_service.link_oauth_account(user.id, user.oauth_provider, oauth_info)
return RedirectResponse(
url=f"{redirect_uri}?linked=google",
@ -600,7 +600,7 @@ async def discord_link_redirect(
@router.get("/link/discord/callback")
async def discord_link_callback(
db: DbSession,
user_service: UserServiceDep,
code: str = Query(..., description="Authorization code from Discord"),
state: str = Query(..., description="State parameter for CSRF validation"),
) -> RedirectResponse:
@ -634,7 +634,7 @@ async def discord_link_callback(
from uuid import UUID
user_id = UUID(user_id_str)
user = await user_service.get_by_id(db, user_id)
user = await user_service.get_by_id(user_id)
if user is None:
return RedirectResponse(
url=f"{redirect_uri}?error=user_not_found",
@ -642,7 +642,7 @@ async def discord_link_callback(
)
# Link the account
await user_service.link_oauth_account(db, user, oauth_info)
await user_service.link_oauth_account(user.id, user.oauth_provider, oauth_info)
return RedirectResponse(
url=f"{redirect_uri}?linked=discord",

View File

@ -28,6 +28,7 @@ from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
@ -35,13 +36,15 @@ from app.db import get_session
from app.db.models import User
from app.repositories.postgres.collection import PostgresCollectionRepository
from app.repositories.postgres.deck import PostgresDeckRepository
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
from app.repositories.postgres.user import PostgresUserRepository
from app.services.card_service import CardService, get_card_service
from app.services.collection_service import CollectionService
from app.services.deck_service import DeckService
from app.services.game_service import GameService, game_service
from app.services.game_state_manager import GameStateManager, game_state_manager
from app.services.jwt_service import verify_access_token
from app.services.user_service import user_service
from app.services.user_service import UserService
# OAuth2 scheme for extracting Bearer token from Authorization header
# tokenUrl is for OpenAPI docs - points to where tokens are obtained
@ -148,8 +151,9 @@ async def get_current_user(
if user_id is None:
raise credentials_exception
# Fetch user from database
user = await user_service.get_by_id(db, user_id)
# Fetch user from database (direct query for auth - not business logic)
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise credentials_exception
@ -187,7 +191,9 @@ async def get_optional_user(
if user_id is None:
return None
return await user_service.get_by_id(db, user_id)
# Direct query for auth - not business logic
result = await db.execute(select(User).where(User.id == user_id))
return result.scalar_one_or_none()
async def get_current_premium_user(
@ -333,6 +339,31 @@ def get_game_state_manager_dep() -> GameStateManager:
return game_state_manager
def get_user_service(
db: Annotated[AsyncSession, Depends(get_db)],
) -> UserService:
"""Get UserService with PostgreSQL repositories.
Creates a UserService instance with user and linked account repositories.
Args:
db: Database session from request.
Returns:
UserService configured for PostgreSQL.
Example:
@router.post("/auth/google/callback")
async def google_callback(
user_service: UserService = Depends(get_user_service),
):
user, created = await user_service.get_or_create_from_oauth(oauth_info)
"""
user_repo = PostgresUserRepository(db)
linked_repo = PostgresLinkedAccountRepository(db)
return UserService(user_repo, linked_repo)
# =============================================================================
# Type Aliases for Cleaner Endpoint Signatures
# =============================================================================
@ -351,6 +382,7 @@ CollectionServiceDep = Annotated[CollectionService, Depends(get_collection_servi
CardServiceDep = Annotated[CardService, Depends(get_card_service_dep)]
GameServiceDep = Annotated[GameService, Depends(get_game_service_dep)]
GameStateManagerDep = Annotated[GameStateManager, Depends(get_game_state_manager_dep)]
UserServiceDep = Annotated[UserService, Depends(get_user_service)]
# Admin authentication
AdminAuth = Annotated[None, Depends(verify_admin_token)]

View File

@ -26,12 +26,12 @@ Example:
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from app.api.deps import CurrentUser, DbSession, DeckServiceDep
from app.api.deps import CurrentUser, DeckServiceDep, UserServiceDep
from app.schemas.deck import DeckResponse, StarterDeckSelectRequest, StarterStatusResponse
from app.schemas.user import UserResponse, UserUpdate
from app.services.deck_service import DeckLimitExceededError, StarterAlreadySelectedError
from app.services.token_store import token_store
from app.services.user_service import AccountLinkingError, user_service
from app.services.user_service import AccountLinkingError
router = APIRouter(prefix="/users", tags=["users"])
@ -65,7 +65,7 @@ async def get_current_user_profile(
@router.patch("/me", response_model=UserResponse)
async def update_current_user_profile(
user: CurrentUser,
db: DbSession,
user_service: UserServiceDep,
update_data: UserUpdate,
) -> UserResponse:
"""Update the current user's profile.
@ -78,7 +78,7 @@ async def update_current_user_profile(
Returns:
Updated user profile.
"""
updated_user = await user_service.update(db, user, update_data)
updated_user = await user_service.update(user.id, update_data)
return UserResponse.model_validate(updated_user)
@ -139,7 +139,7 @@ async def get_active_sessions(
@router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT)
async def unlink_oauth_account(
user: CurrentUser,
db: DbSession,
user_service: UserServiceDep,
provider: str,
) -> None:
"""Unlink an OAuth provider from the current user's account.
@ -162,7 +162,7 @@ async def unlink_oauth_account(
)
try:
unlinked = await user_service.unlink_oauth_account(db, user, provider)
unlinked = await user_service.unlink_oauth_account(user.id, user.oauth_provider, provider)
if not unlinked:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,

View File

@ -10,11 +10,12 @@ The protocol pattern enables:
- Offline fork support without rewriting service layer
Usage:
from app.repositories import CollectionRepository, DeckRepository
from app.repositories.postgres import PostgresCollectionRepository
from app.repositories import CollectionRepository, DeckRepository, UserRepository
from app.repositories.postgres import PostgresCollectionRepository, PostgresUserRepository
# In production (dependency injection)
repo = PostgresCollectionRepository(db_session)
user_repo = PostgresUserRepository(db_session)
# In tests
repo = MockCollectionRepository()
@ -23,9 +24,13 @@ Usage:
from app.repositories.protocols import (
CollectionRepository,
DeckRepository,
LinkedAccountRepository,
UserRepository,
)
__all__ = [
"CollectionRepository",
"DeckRepository",
"LinkedAccountRepository",
"UserRepository",
]

View File

@ -8,11 +8,15 @@ Usage:
from app.repositories.postgres import (
PostgresCollectionRepository,
PostgresDeckRepository,
PostgresLinkedAccountRepository,
PostgresUserRepository,
)
# Create repository with database session
collection_repo = PostgresCollectionRepository(db_session)
deck_repo = PostgresDeckRepository(db_session)
user_repo = PostgresUserRepository(db_session)
linked_repo = PostgresLinkedAccountRepository(db_session)
# Use via service layer
service = CollectionService(collection_repo)
@ -20,8 +24,12 @@ Usage:
from app.repositories.postgres.collection import PostgresCollectionRepository
from app.repositories.postgres.deck import PostgresDeckRepository
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
from app.repositories.postgres.user import PostgresUserRepository
__all__ = [
"PostgresCollectionRepository",
"PostgresDeckRepository",
"PostgresLinkedAccountRepository",
"PostgresUserRepository",
]

View File

@ -0,0 +1,149 @@
"""PostgreSQL implementation of LinkedAccountRepository.
This module provides the PostgreSQL-specific implementation of the
LinkedAccountRepository protocol using SQLAlchemy async sessions.
Example:
async with get_db_session() as db:
repo = PostgresLinkedAccountRepository(db)
accounts = await repo.get_by_user(user_id)
"""
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.oauth_account import OAuthLinkedAccount
from app.repositories.protocols import LinkedAccountEntry
def _to_dto(model: OAuthLinkedAccount) -> LinkedAccountEntry:
"""Convert ORM model to DTO."""
return LinkedAccountEntry(
id=model.id,
user_id=UUID(model.user_id) if isinstance(model.user_id, str) else model.user_id,
provider=model.provider,
oauth_id=model.oauth_id,
email=model.email,
display_name=model.display_name,
avatar_url=model.avatar_url,
linked_at=model.linked_at,
)
class PostgresLinkedAccountRepository:
"""PostgreSQL implementation of LinkedAccountRepository.
Uses SQLAlchemy async sessions for database access.
Attributes:
_db: The async database session.
"""
def __init__(self, db: AsyncSession) -> None:
"""Initialize with database session.
Args:
db: SQLAlchemy async session.
"""
self._db = db
async def get_by_provider(
self,
provider: str,
oauth_id: str,
) -> LinkedAccountEntry | None:
"""Get a linked account by provider and OAuth ID.
Args:
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
LinkedAccountEntry if found, None otherwise.
"""
result = await self._db.execute(
select(OAuthLinkedAccount).where(
OAuthLinkedAccount.provider == provider,
OAuthLinkedAccount.oauth_id == oauth_id,
)
)
model = result.scalar_one_or_none()
return _to_dto(model) if model else None
async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
"""Get all linked accounts for a user.
Args:
user_id: The user's UUID.
Returns:
List of LinkedAccountEntry, ordered by provider.
"""
result = await self._db.execute(
select(OAuthLinkedAccount)
.where(OAuthLinkedAccount.user_id == str(user_id))
.order_by(OAuthLinkedAccount.provider)
)
return [_to_dto(model) for model in result.scalars().all()]
async def create(
self,
user_id: UUID,
provider: str,
oauth_id: str,
email: str | None = None,
display_name: str | None = None,
avatar_url: str | None = None,
) -> LinkedAccountEntry:
"""Link an OAuth provider to a user account.
Args:
user_id: The user's UUID.
provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
email: Email from the OAuth provider.
display_name: Display name from the OAuth provider.
avatar_url: Avatar URL from the OAuth provider.
Returns:
The created LinkedAccountEntry.
"""
linked_account = OAuthLinkedAccount(
user_id=str(user_id),
provider=provider,
oauth_id=oauth_id,
email=email,
display_name=display_name,
avatar_url=avatar_url,
)
self._db.add(linked_account)
await self._db.commit()
await self._db.refresh(linked_account)
return _to_dto(linked_account)
async def delete(self, user_id: UUID, provider: str) -> bool:
"""Unlink an OAuth provider from a user account.
Args:
user_id: The user's UUID.
provider: OAuth provider name to unlink.
Returns:
True if deleted, False if not found.
"""
result = await self._db.execute(
select(OAuthLinkedAccount).where(
OAuthLinkedAccount.user_id == str(user_id),
OAuthLinkedAccount.provider == provider,
)
)
linked_account = result.scalar_one_or_none()
if linked_account is None:
return False
await self._db.delete(linked_account)
await self._db.commit()
return True

View File

@ -0,0 +1,243 @@
"""PostgreSQL implementation of UserRepository.
This module provides the PostgreSQL-specific implementation of the
UserRepository protocol using SQLAlchemy async sessions.
Example:
async with get_db_session() as db:
repo = PostgresUserRepository(db)
user = await repo.get_by_id(user_id)
"""
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.user import User
from app.repositories.protocols import UNSET, UserEntry
def _to_dto(model: User) -> UserEntry:
"""Convert ORM model to DTO."""
return UserEntry(
id=model.id,
email=model.email,
display_name=model.display_name,
avatar_url=model.avatar_url,
oauth_provider=model.oauth_provider,
oauth_id=model.oauth_id,
is_premium=model.is_premium,
premium_until=model.premium_until,
last_login=model.last_login,
created_at=model.created_at,
updated_at=model.updated_at,
)
class PostgresUserRepository:
"""PostgreSQL implementation of UserRepository.
Uses SQLAlchemy async sessions for database access.
Attributes:
_db: The async database session.
"""
def __init__(self, db: AsyncSession) -> None:
"""Initialize with database session.
Args:
db: SQLAlchemy async session.
"""
self._db = db
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
"""Get a user by their ID.
Args:
user_id: The user's UUID.
Returns:
UserEntry if found, None otherwise.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
model = result.scalar_one_or_none()
return _to_dto(model) if model else None
async def get_by_email(self, email: str) -> UserEntry | None:
"""Get a user by their email address.
Args:
email: The user's email address.
Returns:
UserEntry if found, None otherwise.
"""
result = await self._db.execute(select(User).where(User.email == email))
model = result.scalar_one_or_none()
return _to_dto(model) if model else None
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
"""Get a user by their OAuth provider and ID.
Args:
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
UserEntry if found, None otherwise.
"""
result = await self._db.execute(
select(User).where(
User.oauth_provider == provider,
User.oauth_id == oauth_id,
)
)
model = result.scalar_one_or_none()
return _to_dto(model) if model else None
async def create(
self,
email: str,
display_name: str,
oauth_provider: str,
oauth_id: str,
avatar_url: str | None = None,
) -> UserEntry:
"""Create a new user.
Args:
email: User's email address.
display_name: Public display name.
oauth_provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
avatar_url: Optional avatar URL.
Returns:
The created UserEntry.
"""
user = User(
email=email,
display_name=display_name,
oauth_provider=oauth_provider,
oauth_id=oauth_id,
avatar_url=avatar_url,
)
self._db.add(user)
await self._db.commit()
await self._db.refresh(user)
return _to_dto(user)
async def update(
self,
user_id: UUID,
display_name: str | None = None,
avatar_url: str | None = UNSET, # type: ignore[assignment]
oauth_provider: str | None = None,
oauth_id: str | None = None,
) -> UserEntry | None:
"""Update user profile fields.
Only provided (non-None/non-UNSET) fields are updated.
Args:
user_id: The user's UUID.
display_name: New display name (None keeps existing).
avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
oauth_provider: New OAuth provider (None keeps existing).
oauth_id: New OAuth ID (None keeps existing).
Returns:
Updated UserEntry, or None if user not found.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return None
if display_name is not None:
user.display_name = display_name
if avatar_url is not UNSET:
user.avatar_url = avatar_url
if oauth_provider is not None:
user.oauth_provider = oauth_provider
if oauth_id is not None:
user.oauth_id = oauth_id
await self._db.commit()
await self._db.refresh(user)
return _to_dto(user)
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
"""Update the user's last login timestamp to now.
Args:
user_id: The user's UUID.
Returns:
Updated UserEntry, or None if user not found.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return None
user.last_login = datetime.now(UTC)
await self._db.commit()
await self._db.refresh(user)
return _to_dto(user)
async def update_premium(
self,
user_id: UUID,
is_premium: bool,
premium_until: datetime | None,
) -> UserEntry | None:
"""Update user's premium subscription status.
Args:
user_id: The user's UUID.
is_premium: Whether user has premium.
premium_until: When premium expires, or None if not premium.
Returns:
Updated UserEntry, or None if user not found.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return None
user.is_premium = is_premium
user.premium_until = premium_until
await self._db.commit()
await self._db.refresh(user)
return _to_dto(user)
async def delete(self, user_id: UUID) -> bool:
"""Delete a user account.
This will cascade delete all related data (decks, collection, etc.)
based on database constraints.
Args:
user_id: The user's UUID.
Returns:
True if deleted, False if not found.
"""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return False
await self._db.delete(user)
await self._db.commit()
return True

View File

@ -102,6 +102,46 @@ class DeckEntry:
updated_at: datetime
@dataclass
class UserEntry:
"""Storage-agnostic representation of a user account.
This DTO decouples the service layer from the ORM model,
enabling different storage backends (PostgreSQL, SQLite, JSON)
to be used interchangeably.
"""
id: UUID
email: str
display_name: str
avatar_url: str | None
oauth_provider: str
oauth_id: str
is_premium: bool
premium_until: datetime | None
last_login: datetime | None
created_at: datetime
updated_at: datetime
@dataclass
class LinkedAccountEntry:
"""Storage-agnostic representation of a linked OAuth account.
Users can link multiple OAuth providers (e.g., Google + Discord)
to a single account for flexible login options.
"""
id: UUID
user_id: UUID
provider: str
oauth_id: str
email: str | None
display_name: str | None
avatar_url: str | None
linked_at: datetime
# =============================================================================
# Repository Protocols
# =============================================================================
@ -342,3 +382,211 @@ class DeckRepository(Protocol):
Tuple of (has_starter, starter_type).
"""
...
class UserRepository(Protocol):
"""Protocol for user account data access.
Implementations handle storage-specific details (PostgreSQL, SQLite, JSON).
Services use this protocol for business logic without knowing storage details.
Note: Business logic like get_or_create_from_oauth belongs in the service layer,
not in the repository.
"""
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
"""Get a user by their ID.
Args:
user_id: The user's UUID.
Returns:
UserEntry if found, None otherwise.
"""
...
async def get_by_email(self, email: str) -> UserEntry | None:
"""Get a user by their email address.
Args:
email: The user's email address.
Returns:
UserEntry if found, None otherwise.
"""
...
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
"""Get a user by their OAuth provider and ID.
Args:
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
UserEntry if found, None otherwise.
"""
...
async def create(
self,
email: str,
display_name: str,
oauth_provider: str,
oauth_id: str,
avatar_url: str | None = None,
) -> UserEntry:
"""Create a new user.
Args:
email: User's email address.
display_name: Public display name.
oauth_provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
avatar_url: Optional avatar URL.
Returns:
The created UserEntry.
"""
...
async def update(
self,
user_id: UUID,
display_name: str | None = None,
avatar_url: str | None = UNSET, # type: ignore[assignment]
oauth_provider: str | None = None,
oauth_id: str | None = None,
) -> UserEntry | None:
"""Update user profile fields.
Only provided (non-None/non-UNSET) fields are updated.
Use UNSET (default) to keep existing value for nullable fields,
or None to explicitly clear them.
Args:
user_id: The user's UUID.
display_name: New display name (None keeps existing).
avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
oauth_provider: New OAuth provider (None keeps existing).
oauth_id: New OAuth ID (None keeps existing).
Returns:
Updated UserEntry, or None if user not found.
"""
...
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
"""Update the user's last login timestamp to now.
Args:
user_id: The user's UUID.
Returns:
Updated UserEntry, or None if user not found.
"""
...
async def update_premium(
self,
user_id: UUID,
is_premium: bool,
premium_until: datetime | None,
) -> UserEntry | None:
"""Update user's premium subscription status.
Args:
user_id: The user's UUID.
is_premium: Whether user has premium.
premium_until: When premium expires, or None if not premium.
Returns:
Updated UserEntry, or None if user not found.
"""
...
async def delete(self, user_id: UUID) -> bool:
"""Delete a user account.
This will cascade delete all related data (decks, collection, etc.)
based on database constraints.
Args:
user_id: The user's UUID.
Returns:
True if deleted, False if not found.
"""
...
class LinkedAccountRepository(Protocol):
"""Protocol for OAuth linked accounts data access.
Users can link multiple OAuth providers to a single account.
The primary OAuth provider is stored on the User model itself;
additional linked providers are stored as LinkedAccount records.
"""
async def get_by_provider(
self,
provider: str,
oauth_id: str,
) -> LinkedAccountEntry | None:
"""Get a linked account by provider and OAuth ID.
Args:
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
LinkedAccountEntry if found, None otherwise.
"""
...
async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
"""Get all linked accounts for a user.
Args:
user_id: The user's UUID.
Returns:
List of LinkedAccountEntry, ordered by provider.
"""
...
async def create(
self,
user_id: UUID,
provider: str,
oauth_id: str,
email: str | None = None,
display_name: str | None = None,
avatar_url: str | None = None,
) -> LinkedAccountEntry:
"""Link an OAuth provider to a user account.
Args:
user_id: The user's UUID.
provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
email: Email from the OAuth provider.
display_name: Display name from the OAuth provider.
avatar_url: Avatar URL from the OAuth provider.
Returns:
The created LinkedAccountEntry.
"""
...
async def delete(self, user_id: UUID, provider: str) -> bool:
"""Unlink an OAuth provider from a user account.
Args:
user_id: The user's UUID.
provider: OAuth provider name to unlink.
Returns:
True if deleted, False if not found.
"""
...

View File

@ -1,32 +1,37 @@
"""User service for Mantimon TCG.
This module provides async CRUD operations for user accounts,
including OAuth-based user creation and premium status management.
This module provides business logic for user account operations,
including OAuth-based user creation, account linking, and premium
status management.
All database operations use async SQLAlchemy sessions.
The service layer contains business logic while repositories handle
pure data access. This separation enables testing and different
storage backends.
Example:
from app.services.user_service import user_service
from app.services.user_service import UserService
from app.repositories.postgres import PostgresUserRepository, PostgresLinkedAccountRepository
# Get user by ID
user = await user_service.get_by_id(db, user_id)
# Create service with injected repositories
user_repo = PostgresUserRepository(db)
linked_repo = PostgresLinkedAccountRepository(db)
service = UserService(user_repo, linked_repo)
# Create from OAuth
user = await user_service.create_from_oauth(db, oauth_info)
# Update premium status
user = await user_service.update_premium(db, user_id, premium_until)
# Use service
user = await service.get_by_id(user_id)
user, created = await service.get_or_create_from_oauth(oauth_info)
"""
from datetime import UTC, datetime
from datetime import datetime
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.oauth_account import OAuthLinkedAccount
from app.db.models.user import User
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
from app.repositories.protocols import (
LinkedAccountEntry,
LinkedAccountRepository,
UserEntry,
UserRepository,
)
from app.schemas.user import OAuthUserInfo, UserUpdate
class AccountLinkingError(Exception):
@ -38,117 +43,119 @@ class AccountLinkingError(Exception):
class UserService:
"""Service for user account operations.
Provides async methods for user CRUD, OAuth-based creation,
and premium subscription management.
Provides business logic for user CRUD, OAuth-based creation,
account linking, and premium subscription management.
Attributes:
_user_repo: Repository for user data access.
_linked_repo: Repository for linked account data access.
"""
async def get_by_id(self, db: AsyncSession, user_id: UUID) -> User | None:
def __init__(
self,
user_repository: UserRepository,
linked_account_repository: LinkedAccountRepository,
) -> None:
"""Initialize with repository dependencies.
Args:
user_repository: Repository for user data access.
linked_account_repository: Repository for linked account data access.
"""
self._user_repo = user_repository
self._linked_repo = linked_account_repository
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
"""Get a user by their ID.
Args:
db: Async database session.
user_id: The user's UUID.
Returns:
User if found, None otherwise.
UserEntry if found, None otherwise.
Example:
user = await user_service.get_by_id(db, user_id)
user = await service.get_by_id(user_id)
if user:
print(f"Found user: {user.display_name}")
"""
result = await db.execute(select(User).where(User.id == user_id))
return result.scalar_one_or_none()
return await self._user_repo.get_by_id(user_id)
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
async def get_by_email(self, email: str) -> UserEntry | None:
"""Get a user by their email address.
Args:
db: Async database session.
email: The user's email address.
Returns:
User if found, None otherwise.
UserEntry if found, None otherwise.
Example:
user = await user_service.get_by_email(db, "player@example.com")
user = await service.get_by_email("player@example.com")
"""
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
return await self._user_repo.get_by_email(email)
async def get_by_oauth(
self,
db: AsyncSession,
provider: str,
oauth_id: str,
) -> User | None:
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
"""Get a user by their OAuth provider and ID.
Args:
db: Async database session.
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
User if found, None otherwise.
UserEntry if found, None otherwise.
Example:
user = await user_service.get_by_oauth(db, "google", "123456789")
user = await service.get_by_oauth("google", "123456789")
"""
result = await db.execute(
select(User).where(
User.oauth_provider == provider,
User.oauth_id == oauth_id,
)
)
return result.scalar_one_or_none()
return await self._user_repo.get_by_oauth(provider, oauth_id)
async def create(self, db: AsyncSession, user_data: UserCreate) -> User:
async def create(
self,
email: str,
display_name: str,
oauth_provider: str,
oauth_id: str,
avatar_url: str | None = None,
) -> UserEntry:
"""Create a new user.
Args:
db: Async database session.
user_data: User creation data.
email: User's email address.
display_name: Public display name.
oauth_provider: OAuth provider name.
oauth_id: Unique ID from the OAuth provider.
avatar_url: Optional avatar URL.
Returns:
The created User instance.
The created UserEntry.
Example:
user_data = UserCreate(
user = await service.create(
email="player@example.com",
display_name="Player1",
oauth_provider="google",
oauth_id="123456789"
)
user = await user_service.create(db, user_data)
"""
user = User(
email=user_data.email,
display_name=user_data.display_name,
avatar_url=user_data.avatar_url,
oauth_provider=user_data.oauth_provider,
oauth_id=user_data.oauth_id,
return await self._user_repo.create(
email=email,
display_name=display_name,
oauth_provider=oauth_provider,
oauth_id=oauth_id,
avatar_url=avatar_url,
)
db.add(user)
await db.commit()
await db.refresh(user)
return user
async def create_from_oauth(
self,
db: AsyncSession,
oauth_info: OAuthUserInfo,
) -> User:
async def create_from_oauth(self, oauth_info: OAuthUserInfo) -> UserEntry:
"""Create a new user from OAuth provider info.
Convenience method that converts OAuthUserInfo to UserCreate.
Convenience method that extracts fields from OAuthUserInfo.
Args:
db: Async database session.
oauth_info: Normalized OAuth user information.
Returns:
The created User instance.
The created UserEntry.
Example:
oauth_info = OAuthUserInfo(
@ -158,209 +165,212 @@ class UserService:
name="Player One",
avatar_url="https://..."
)
user = await user_service.create_from_oauth(db, oauth_info)
user = await service.create_from_oauth(oauth_info)
"""
user_data = oauth_info.to_user_create()
return await self.create(db, user_data)
return await self.create(
email=oauth_info.email,
display_name=oauth_info.name,
oauth_provider=oauth_info.provider,
oauth_id=oauth_info.oauth_id,
avatar_url=oauth_info.avatar_url,
)
async def get_or_create_from_oauth(
self,
db: AsyncSession,
oauth_info: OAuthUserInfo,
) -> tuple[User, bool]:
) -> tuple[UserEntry, bool]:
"""Get existing user or create new one from OAuth info.
First checks for existing user by OAuth provider+ID, then by email
(for account linking), and finally creates a new user if not found.
Args:
db: Async database session.
oauth_info: Normalized OAuth user information.
Returns:
Tuple of (User, created) where created is True if new user.
Tuple of (UserEntry, created) where created is True if new user.
Example:
user, created = await user_service.get_or_create_from_oauth(db, oauth_info)
user, created = await service.get_or_create_from_oauth(oauth_info)
if created:
print("Welcome, new user!")
else:
print("Welcome back!")
"""
# First, check by OAuth provider + ID (exact match)
user = await self.get_by_oauth(db, oauth_info.provider, oauth_info.oauth_id)
user = await self._user_repo.get_by_oauth(oauth_info.provider, oauth_info.oauth_id)
if user:
return user, False
# Check by email for potential account linking
# If user exists with same email but different OAuth, update their OAuth
user = await self.get_by_email(db, oauth_info.email)
user = await self._user_repo.get_by_email(oauth_info.email)
if user:
# Update OAuth credentials for existing user
# This links the new OAuth provider to the existing account
user.oauth_provider = oauth_info.provider
user.oauth_id = oauth_info.oauth_id
# Optionally update avatar if not set
if not user.avatar_url and oauth_info.avatar_url:
user.avatar_url = oauth_info.avatar_url
await db.commit()
await db.refresh(user)
return user, False
updated_user = await self._user_repo.update(
user_id=user.id,
oauth_provider=oauth_info.provider,
oauth_id=oauth_info.oauth_id,
avatar_url=oauth_info.avatar_url if not user.avatar_url else None,
)
return updated_user or user, False
# Create new user
user = await self.create_from_oauth(db, oauth_info)
user = await self.create_from_oauth(oauth_info)
return user, True
async def update(
self,
db: AsyncSession,
user: User,
user_id: UUID,
update_data: UserUpdate,
) -> User:
) -> UserEntry | None:
"""Update user profile fields.
Only updates fields that are provided (not None).
Only updates fields that are explicitly provided. Uses UNSET pattern
for avatar_url to distinguish "not provided" from "set to None".
Args:
db: Async database session.
user: The user to update.
user_id: The user's UUID.
update_data: Fields to update.
Returns:
The updated User instance.
The updated UserEntry, or None if not found.
Example:
update_data = UserUpdate(display_name="New Name")
user = await user_service.update(db, user, update_data)
user = await service.update(user_id, update_data)
"""
if update_data.display_name is not None:
user.display_name = update_data.display_name
if update_data.avatar_url is not None:
user.avatar_url = update_data.avatar_url
from app.repositories.protocols import UNSET
await db.commit()
await db.refresh(user)
return user
# Use UNSET for avatar_url unless explicitly provided
avatar_url = (
update_data.avatar_url if "avatar_url" in update_data.model_fields_set else UNSET
)
async def update_last_login(self, db: AsyncSession, user: User) -> User:
return await self._user_repo.update(
user_id=user_id,
display_name=update_data.display_name,
avatar_url=avatar_url,
)
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
"""Update the user's last login timestamp.
Args:
db: Async database session.
user: The user to update.
user_id: The user's UUID.
Returns:
The updated User instance.
The updated UserEntry, or None if not found.
Example:
user = await user_service.update_last_login(db, user)
user = await service.update_last_login(user_id)
"""
user.last_login = datetime.now(UTC)
await db.commit()
await db.refresh(user)
return user
return await self._user_repo.update_last_login(user_id)
async def update_premium(
self,
db: AsyncSession,
user: User,
user_id: UUID,
premium_until: datetime | None,
) -> User:
) -> UserEntry | None:
"""Update user's premium subscription status.
Args:
db: Async database session.
user: The user to update.
user_id: The user's UUID.
premium_until: When premium expires, or None to remove premium.
Returns:
The updated User instance.
The updated UserEntry, or None if not found.
Example:
# Grant 30 days of premium
expires = datetime.now(UTC) + timedelta(days=30)
user = await user_service.update_premium(db, user, expires)
user = await service.update_premium(user_id, expires)
# Remove premium
user = await user_service.update_premium(db, user, None)
user = await service.update_premium(user_id, None)
"""
if premium_until is not None:
user.is_premium = True
user.premium_until = premium_until
else:
user.is_premium = False
user.premium_until = None
is_premium = premium_until is not None
return await self._user_repo.update_premium(
user_id=user_id,
is_premium=is_premium,
premium_until=premium_until,
)
await db.commit()
await db.refresh(user)
return user
async def delete(self, db: AsyncSession, user: User) -> None:
async def delete(self, user_id: UUID) -> bool:
"""Delete a user account.
This will cascade delete all related data (decks, collection, etc.)
based on the model relationships.
based on the database constraints.
Args:
db: Async database session.
user: The user to delete.
user_id: The user's UUID.
Returns:
True if deleted, False if not found.
Example:
await user_service.delete(db, user)
success = await service.delete(user_id)
"""
await db.delete(user)
await db.commit()
return await self._user_repo.delete(user_id)
# =========================================================================
# Linked Account Operations
# =========================================================================
async def get_linked_accounts(self, user_id: UUID) -> list[LinkedAccountEntry]:
"""Get all linked OAuth accounts for a user.
Args:
user_id: The user's UUID.
Returns:
List of LinkedAccountEntry.
"""
return await self._linked_repo.get_by_user(user_id)
async def get_linked_account(
self,
db: AsyncSession,
provider: str,
oauth_id: str,
) -> OAuthLinkedAccount | None:
) -> LinkedAccountEntry | None:
"""Get a linked account by provider and OAuth ID.
Args:
db: Async database session.
provider: OAuth provider name (google, discord).
oauth_id: Unique ID from the OAuth provider.
Returns:
OAuthLinkedAccount if found, None otherwise.
LinkedAccountEntry if found, None otherwise.
"""
result = await db.execute(
select(OAuthLinkedAccount).where(
OAuthLinkedAccount.provider == provider,
OAuthLinkedAccount.oauth_id == oauth_id,
)
)
return result.scalar_one_or_none()
return await self._linked_repo.get_by_provider(provider, oauth_id)
async def link_oauth_account(
self,
db: AsyncSession,
user: User,
user_id: UUID,
user_oauth_provider: str,
oauth_info: OAuthUserInfo,
) -> OAuthLinkedAccount:
) -> LinkedAccountEntry:
"""Link an additional OAuth provider to a user account.
Args:
db: Async database session.
user: The user to link the account to.
user_id: The user's UUID.
user_oauth_provider: The user's primary OAuth provider.
oauth_info: OAuth information from the provider.
Returns:
The created OAuthLinkedAccount.
The created LinkedAccountEntry.
Raises:
AccountLinkingError: If provider is already linked to this or another user.
Example:
linked = await user_service.link_oauth_account(db, user, discord_info)
linked = await service.link_oauth_account(user.id, user.oauth_provider, discord_info)
"""
# Check if this provider+oauth_id is already linked to any user
existing = await self.get_linked_account(db, oauth_info.provider, oauth_info.oauth_id)
existing = await self._linked_repo.get_by_provider(oauth_info.provider, oauth_info.oauth_id)
if existing:
if str(existing.user_id) == str(user.id):
if existing.user_id == user_id:
raise AccountLinkingError(
f"{oauth_info.provider.title()} account is already linked to your account"
)
@ -369,36 +379,33 @@ class UserService:
)
# Check if this is the user's primary OAuth provider
if user.oauth_provider == oauth_info.provider:
if user_oauth_provider == oauth_info.provider:
raise AccountLinkingError(
f"{oauth_info.provider.title()} is your primary login provider"
)
# Check if user already has this provider linked
for linked in user.linked_accounts:
linked_accounts = await self._linked_repo.get_by_user(user_id)
for linked in linked_accounts:
if linked.provider == oauth_info.provider:
raise AccountLinkingError(
f"You already have a {oauth_info.provider.title()} account linked"
)
# Create the linked account
linked_account = OAuthLinkedAccount(
user_id=str(user.id),
return await self._linked_repo.create(
user_id=user_id,
provider=oauth_info.provider,
oauth_id=oauth_info.oauth_id,
email=oauth_info.email,
display_name=oauth_info.name,
avatar_url=oauth_info.avatar_url,
)
db.add(linked_account)
await db.commit()
await db.refresh(linked_account)
return linked_account
async def unlink_oauth_account(
self,
db: AsyncSession,
user: User,
user_id: UUID,
user_oauth_provider: str,
provider: str,
) -> bool:
"""Unlink an OAuth provider from a user account.
@ -406,8 +413,8 @@ class UserService:
Cannot unlink the primary OAuth provider.
Args:
db: Async database session.
user: The user to unlink from.
user_id: The user's UUID.
user_oauth_provider: The user's primary OAuth provider.
provider: OAuth provider name to unlink.
Returns:
@ -417,23 +424,12 @@ class UserService:
AccountLinkingError: If trying to unlink the primary provider.
Example:
success = await user_service.unlink_oauth_account(db, user, "discord")
success = await service.unlink_oauth_account(user.id, user.oauth_provider, "discord")
"""
# Cannot unlink primary provider
if user.oauth_provider == provider:
if user_oauth_provider == provider:
raise AccountLinkingError(
f"Cannot unlink {provider.title()} - it is your primary login provider"
)
# Find and delete the linked account
for linked in user.linked_accounts:
if linked.provider == provider:
await db.delete(linked)
await db.commit()
return True
return False
# Global service instance
user_service = UserService()
return await self._linked_repo.delete(user_id, provider)

View File

@ -2,14 +2,19 @@
Tests the authentication endpoints including OAuth redirects,
token refresh, and logout.
Uses FastAPI's dependency override pattern for proper dependency injection testing.
"""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from app.api.deps import get_user_service
class TestGoogleAuthRedirect:
"""Tests for GET /api/auth/google endpoint."""
@ -50,11 +55,32 @@ class TestDiscordAuthRedirect:
assert "not configured" in response.json()["detail"]
@pytest.fixture
def mock_user_service_instance():
"""Create a mock UserService for dependency injection.
Returns a MagicMock with async methods configured.
"""
mock = MagicMock()
mock.get_by_id = AsyncMock()
mock.get_by_email = AsyncMock()
mock.get_by_oauth = AsyncMock()
mock.get_or_create_from_oauth = AsyncMock()
mock.update_last_login = AsyncMock()
return mock
class TestRefreshTokens:
"""Tests for POST /api/auth/refresh endpoint."""
def test_returns_new_access_token(
self, client: TestClient, test_user, refresh_token_data, mock_get_redis
self,
app,
client: TestClient,
test_user,
refresh_token_data,
mock_get_redis,
mock_user_service_instance,
):
"""Test that refresh endpoint returns new access token for valid refresh token.
@ -70,25 +96,31 @@ class TestRefreshTokens:
asyncio.get_event_loop().run_until_complete(setup_token())
# Mock user service to return our test user
with patch("app.api.auth.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
# Configure mock to return test user
# Convert to UserEntry-like object
mock_user_entry = MagicMock()
mock_user_entry.id = test_user.id
mock_user_entry.email = test_user.email
mock_user_service_instance.get_by_id.return_value = mock_user_entry
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
# Override the dependency on the test app
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "access_token" in data
assert data["refresh_token"] == refresh_token_data["token"]
assert data["token_type"] == "bearer"
assert "expires_in" in data
try:
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "access_token" in data
assert data["refresh_token"] == refresh_token_data["token"]
assert data["token_type"] == "bearer"
assert "expires_in" in data
finally:
# Clean up override
app.dependency_overrides.pop(get_user_service, None)
def test_returns_401_for_invalid_token(self, client: TestClient):
"""Test that refresh endpoint returns 401 for invalid refresh token."""
@ -107,21 +139,23 @@ class TestRefreshTokens:
A refresh token not in Redis (revoked/expired) should be rejected.
"""
# Don't store the token in Redis - simulating revocation
# The mock_get_redis is already patched via conftest's app fixture
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert "revoked" in response.json()["detail"]
def test_returns_401_for_deleted_user(
self, client: TestClient, refresh_token_data, mock_get_redis
self,
app,
client: TestClient,
refresh_token_data,
mock_get_redis,
mock_user_service_instance,
):
"""Test that refresh endpoint returns 401 if user no longer exists."""
# Store the token
@ -134,21 +168,23 @@ class TestRefreshTokens:
asyncio.get_event_loop().run_until_complete(setup_token())
# Mock user service to return None (user deleted)
with patch("app.api.auth.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=None)
# Configure mock to return None (user deleted)
mock_user_service_instance.get_by_id.return_value = None
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
# Override the dependency on the test app
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert "User not found" in response.json()["detail"]
try:
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert "User not found" in response.json()["detail"]
finally:
# Clean up override
app.dependency_overrides.pop(get_user_service, None)
class TestLogout:
@ -171,14 +207,10 @@ class TestLogout:
key = asyncio.get_event_loop().run_until_complete(setup_and_check())
# Logout
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/logout",
json={"refresh_token": refresh_token_data["token"]},
)
response = client.post(
"/api/auth/logout",
json={"refresh_token": refresh_token_data["token"]},
)
assert response.status_code == status.HTTP_204_NO_CONTENT
@ -214,7 +246,9 @@ class TestLogoutAll:
response = client.post("/api/auth/logout-all")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_revokes_all_tokens(self, client: TestClient, test_user, access_token, mock_get_redis):
def test_revokes_all_tokens(
self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
):
"""Test that logout-all revokes all refresh tokens for user.
Should delete all tokens matching the user's ID pattern.
@ -233,18 +267,16 @@ class TestLogoutAll:
asyncio.get_event_loop().run_until_complete(setup_tokens())
# Mock dependencies
with patch("app.api.deps.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
# Set up mock db session to return test user when queried
# The get_current_user dependency now does a direct DB query
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with (
patch("app.api.auth.get_redis", mock_get_redis),
patch("app.services.token_store.get_redis", mock_get_redis),
):
response = client.post(
"/api/auth/logout-all",
headers={"Authorization": f"Bearer {access_token}"},
)
response = client.post(
"/api/auth/logout-all",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_204_NO_CONTENT

View File

@ -1,32 +1,55 @@
"""Tests for users API endpoints.
Tests the user profile management endpoints.
Uses FastAPI's dependency override pattern for proper dependency injection testing.
"""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from app.api.deps import get_user_service
from app.services.user_service import AccountLinkingError
@pytest.fixture
def mock_user_service_instance():
"""Create a mock UserService for dependency injection.
Returns a MagicMock with async methods configured.
"""
mock = MagicMock()
mock.get_by_id = AsyncMock()
mock.get_by_email = AsyncMock()
mock.get_by_oauth = AsyncMock()
mock.update = AsyncMock()
mock.unlink_oauth_account = AsyncMock()
return mock
class TestGetCurrentUser:
"""Tests for GET /api/users/me endpoint."""
def test_returns_user_profile(self, client: TestClient, test_user, access_token):
def test_returns_user_profile(
self, app, client: TestClient, test_user, access_token, mock_db_session
):
"""Test that endpoint returns user profile for authenticated user.
Should return the user's profile information.
"""
with patch("app.api.deps.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
# Set up mock db session to return test user when queried
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
response = client.get(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
)
response = client.get(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
@ -52,47 +75,93 @@ class TestGetCurrentUser:
class TestUpdateCurrentUser:
"""Tests for PATCH /api/users/me endpoint."""
def test_updates_display_name(self, client: TestClient, test_user, access_token):
def test_updates_display_name(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint updates display_name when provided."""
updated_user = test_user
# Create an updated user mock
updated_user = MagicMock()
updated_user.id = test_user.id
updated_user.email = test_user.email
updated_user.display_name = "New Name"
updated_user.avatar_url = test_user.avatar_url
updated_user.is_premium = test_user.is_premium
updated_user.premium_until = test_user.premium_until
updated_user.created_at = test_user.created_at
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
# Set up db session to return test user for authentication
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.update = AsyncMock(return_value=updated_user)
# Set up user service mock
mock_user_service_instance.update.return_value = updated_user
response = client.patch(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
json={"display_name": "New Name"},
)
# Override the dependency
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["display_name"] == "New Name"
try:
response = client.patch(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
json={"display_name": "New Name"},
)
def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["display_name"] == "New Name"
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_updates_avatar_url(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint updates avatar_url when provided."""
updated_user = test_user
# Create an updated user mock
updated_user = MagicMock()
updated_user.id = test_user.id
updated_user.email = test_user.email
updated_user.display_name = test_user.display_name
updated_user.avatar_url = "https://new-avatar.com/img.jpg"
updated_user.is_premium = test_user.is_premium
updated_user.premium_until = test_user.premium_until
updated_user.created_at = test_user.created_at
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
# Set up db session to return test user for authentication
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.update = AsyncMock(return_value=updated_user)
# Set up user service mock
mock_user_service_instance.update.return_value = updated_user
response = client.patch(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
json={"avatar_url": "https://new-avatar.com/img.jpg"},
)
# Override the dependency
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
try:
response = client.patch(
"/api/users/me",
headers={"Authorization": f"Bearer {access_token}"},
json={"avatar_url": "https://new-avatar.com/img.jpg"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_requires_authentication(self, client: TestClient):
"""Test that endpoint returns 401 without authentication."""
@ -106,18 +175,22 @@ class TestUpdateCurrentUser:
class TestGetLinkedAccounts:
"""Tests for GET /api/users/me/linked-accounts endpoint."""
def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
def test_returns_linked_accounts(
self, app, client: TestClient, test_user, access_token, mock_db_session
):
"""Test that endpoint returns list of linked OAuth accounts.
Should include the primary provider and any linked accounts.
"""
with patch("app.api.deps.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
# Set up db session to return test user for authentication
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
response = client.get(
"/api/users/me/linked-accounts",
headers={"Authorization": f"Bearer {access_token}"},
)
response = client.get(
"/api/users/me/linked-accounts",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
@ -135,7 +208,7 @@ class TestGetActiveSessions:
"""Tests for GET /api/users/me/sessions endpoint."""
def test_returns_session_count(
self, client: TestClient, test_user, access_token, mock_get_redis
self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
):
"""Test that endpoint returns count of active sessions.
@ -154,14 +227,16 @@ class TestGetActiveSessions:
asyncio.get_event_loop().run_until_complete(setup_tokens())
with patch("app.api.deps.user_service") as mock_user_service:
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
# Set up db session to return test user for authentication
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.services.token_store.get_redis", mock_get_redis):
response = client.get(
"/api/users/me/sessions",
headers={"Authorization": f"Bearer {access_token}"},
)
with patch("app.services.token_store.get_redis", mock_get_redis):
response = client.get(
"/api/users/me/sessions",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
@ -177,78 +252,128 @@ class TestGetActiveSessions:
class TestUnlinkOAuthAccount:
"""Tests for DELETE /api/users/me/link/{provider} endpoint."""
def test_unlinks_provider_successfully(self, client: TestClient, test_user, access_token):
def test_unlinks_provider_successfully(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint successfully unlinks a provider.
Should return 204 when provider is unlinked.
"""
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
# Set up db session to return test user for authentication
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.unlink_oauth_account = AsyncMock(return_value=True)
# Set up user service mock
mock_user_service_instance.unlink_oauth_account.return_value = True
response = client.delete(
"/api/users/me/link/discord",
headers={"Authorization": f"Bearer {access_token}"},
)
# Override the dependency
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
assert response.status_code == status.HTTP_204_NO_CONTENT
try:
response = client.delete(
"/api/users/me/link/discord",
headers={"Authorization": f"Bearer {access_token}"},
)
def test_returns_404_if_not_linked(self, client: TestClient, test_user, access_token):
assert response.status_code == status.HTTP_204_NO_CONTENT
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_returns_404_if_not_linked(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint returns 404 if provider isn't linked.
Should return 404 when trying to unlink a provider that isn't linked.
"""
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
# Set up db session to return test user for authentication
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.unlink_oauth_account = AsyncMock(return_value=False)
# Set up user service mock
mock_user_service_instance.unlink_oauth_account.return_value = False
response = client.delete(
"/api/users/me/link/discord",
headers={"Authorization": f"Bearer {access_token}"},
)
# Override the dependency
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not linked" in response.json()["detail"].lower()
try:
response = client.delete(
"/api/users/me/link/discord",
headers={"Authorization": f"Bearer {access_token}"},
)
def test_returns_400_for_primary_provider(self, client: TestClient, test_user, access_token):
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not linked" in response.json()["detail"].lower()
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_returns_400_for_primary_provider(
self,
app,
client: TestClient,
test_user,
access_token,
mock_db_session,
mock_user_service_instance,
):
"""Test that endpoint returns 400 when trying to unlink primary provider.
Cannot unlink the provider used to create the account.
"""
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
# Set up db session to return test user for authentication
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
with patch("app.api.users.user_service") as mock_user_service:
mock_user_service.unlink_oauth_account = AsyncMock(
side_effect=AccountLinkingError(
"Cannot unlink Google - it is your primary login provider"
)
)
# Set up user service mock to raise AccountLinkingError
mock_user_service_instance.unlink_oauth_account.side_effect = AccountLinkingError(
"Cannot unlink Google - it is your primary login provider"
)
response = client.delete(
"/api/users/me/link/google",
headers={"Authorization": f"Bearer {access_token}"},
)
# Override the dependency
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "primary" in response.json()["detail"].lower()
try:
response = client.delete(
"/api/users/me/link/google",
headers={"Authorization": f"Bearer {access_token}"},
)
def test_returns_400_for_unknown_provider(self, client: TestClient, test_user, access_token):
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "primary" in response.json()["detail"].lower()
finally:
app.dependency_overrides.pop(get_user_service, None)
def test_returns_400_for_unknown_provider(
self, app, client: TestClient, test_user, access_token, mock_db_session
):
"""Test that endpoint returns 400 for unknown provider.
Only 'google' and 'discord' are valid providers.
"""
with patch("app.api.deps.user_service") as mock_deps_service:
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
# Set up db session to return test user for authentication
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = test_user
mock_db_session.execute = AsyncMock(return_value=mock_result)
response = client.delete(
"/api/users/me/link/twitter",
headers={"Authorization": f"Bearer {access_token}"},
)
response = client.delete(
"/api/users/me/link/twitter",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "unknown provider" in response.json()["detail"].lower()

View File

@ -2,25 +2,43 @@
Tests the user service CRUD operations and OAuth-based user creation.
Uses real Postgres via the db_session fixture from conftest.
The UserService now uses dependency injection with repositories injected
via constructor, so we create a fresh service instance per test.
"""
from datetime import UTC, datetime, timedelta
from uuid import UUID
import pytest
from app.db.models import User
from app.db.models.oauth_account import OAuthLinkedAccount
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
from app.services.user_service import AccountLinkingError, user_service
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
from app.repositories.postgres.user import PostgresUserRepository
from app.schemas.user import OAuthUserInfo, UserUpdate
from app.services.user_service import AccountLinkingError, UserService
# Import db_session fixture from db conftest
pytestmark = pytest.mark.asyncio
@pytest.fixture
def user_service(db_session):
"""Create UserService with real PostgreSQL repositories.
This fixture provides a properly constructed UserService for each test,
following the dependency injection pattern used in production.
"""
user_repo = PostgresUserRepository(db_session)
linked_repo = PostgresLinkedAccountRepository(db_session)
return UserService(user_repo, linked_repo)
class TestGetById:
"""Tests for get_by_id method."""
async def test_returns_user_when_found(self, db_session):
async def test_returns_user_when_found(self, db_session, user_service):
"""Test that get_by_id returns user when it exists.
Creates a user and verifies it can be retrieved by ID.
@ -36,26 +54,24 @@ class TestGetById:
await db_session.commit()
# Retrieve by ID
from uuid import UUID
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
result = await user_service.get_by_id(db_session, user_id)
result = await user_service.get_by_id(user_id)
assert result is not None
assert result.email == "test@example.com"
async def test_returns_none_when_not_found(self, db_session):
async def test_returns_none_when_not_found(self, user_service):
"""Test that get_by_id returns None for nonexistent users."""
from uuid import uuid4
result = await user_service.get_by_id(db_session, uuid4())
result = await user_service.get_by_id(uuid4())
assert result is None
class TestGetByEmail:
"""Tests for get_by_email method."""
async def test_returns_user_when_found(self, db_session):
async def test_returns_user_when_found(self, db_session, user_service):
"""Test that get_by_email returns user when it exists."""
user = User(
email="findme@example.com",
@ -66,21 +82,21 @@ class TestGetByEmail:
db_session.add(user)
await db_session.commit()
result = await user_service.get_by_email(db_session, "findme@example.com")
result = await user_service.get_by_email("findme@example.com")
assert result is not None
assert result.display_name == "Find Me"
async def test_returns_none_when_not_found(self, db_session):
async def test_returns_none_when_not_found(self, user_service):
"""Test that get_by_email returns None for nonexistent emails."""
result = await user_service.get_by_email(db_session, "nobody@example.com")
result = await user_service.get_by_email("nobody@example.com")
assert result is None
class TestGetByOAuth:
"""Tests for get_by_oauth method."""
async def test_returns_user_when_found(self, db_session):
async def test_returns_user_when_found(self, db_session, user_service):
"""Test that get_by_oauth returns user for matching provider+id."""
user = User(
email="oauth@example.com",
@ -91,12 +107,12 @@ class TestGetByOAuth:
db_session.add(user)
await db_session.commit()
result = await user_service.get_by_oauth(db_session, "google", "google-unique-id")
result = await user_service.get_by_oauth("google", "google-unique-id")
assert result is not None
assert result.email == "oauth@example.com"
async def test_returns_none_for_wrong_provider(self, db_session):
async def test_returns_none_for_wrong_provider(self, db_session, user_service):
"""Test that get_by_oauth returns None if provider doesn't match."""
user = User(
email="oauth2@example.com",
@ -108,30 +124,28 @@ class TestGetByOAuth:
await db_session.commit()
# Same ID, different provider
result = await user_service.get_by_oauth(db_session, "discord", "google-id-2")
result = await user_service.get_by_oauth("discord", "google-id-2")
assert result is None
async def test_returns_none_when_not_found(self, db_session):
async def test_returns_none_when_not_found(self, user_service):
"""Test that get_by_oauth returns None for nonexistent OAuth."""
result = await user_service.get_by_oauth(db_session, "google", "nonexistent")
result = await user_service.get_by_oauth("google", "nonexistent")
assert result is None
class TestCreate:
"""Tests for create method."""
async def test_creates_user_with_all_fields(self, db_session):
async def test_creates_user_with_all_fields(self, user_service):
"""Test that create properly persists all user fields."""
user_data = UserCreate(
result = await user_service.create(
email="new@example.com",
display_name="New User",
avatar_url="https://example.com/avatar.jpg",
oauth_provider="discord",
oauth_id="discord-new-id",
avatar_url="https://example.com/avatar.jpg",
)
result = await user_service.create(db_session, user_data)
assert result.id is not None
assert result.email == "new@example.com"
assert result.display_name == "New User"
@ -141,24 +155,22 @@ class TestCreate:
assert result.is_premium is False
assert result.premium_until is None
async def test_creates_user_without_avatar(self, db_session):
async def test_creates_user_without_avatar(self, user_service):
"""Test that create works without optional avatar_url."""
user_data = UserCreate(
result = await user_service.create(
email="noavatar@example.com",
display_name="No Avatar",
oauth_provider="google",
oauth_id="google-no-avatar",
)
result = await user_service.create(db_session, user_data)
assert result.avatar_url is None
class TestCreateFromOAuth:
"""Tests for create_from_oauth method."""
async def test_creates_user_from_oauth_info(self, db_session):
async def test_creates_user_from_oauth_info(self, user_service):
"""Test that create_from_oauth converts OAuthUserInfo to User."""
oauth_info = OAuthUserInfo(
provider="google",
@ -168,7 +180,7 @@ class TestCreateFromOAuth:
avatar_url="https://google.com/avatar.jpg",
)
result = await user_service.create_from_oauth(db_session, oauth_info)
result = await user_service.create_from_oauth(oauth_info)
assert result.email == "oauthcreate@example.com"
assert result.display_name == "OAuth Created User"
@ -179,7 +191,7 @@ class TestCreateFromOAuth:
class TestGetOrCreateFromOAuth:
"""Tests for get_or_create_from_oauth method."""
async def test_returns_existing_user_by_oauth(self, db_session):
async def test_returns_existing_user_by_oauth(self, db_session, user_service):
"""Test that existing user is returned when OAuth matches.
Verifies the method returns (user, False) for existing users.
@ -202,12 +214,12 @@ class TestGetOrCreateFromOAuth:
name="Existing",
)
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
result, created = await user_service.get_or_create_from_oauth(oauth_info)
assert created is False
assert result.id == existing.id
assert str(result.id) == str(existing.id)
async def test_links_existing_user_by_email(self, db_session):
async def test_links_existing_user_by_email(self, db_session, user_service):
"""Test that OAuth is linked when email matches existing user.
If a user exists with the same email but different OAuth,
@ -232,15 +244,15 @@ class TestGetOrCreateFromOAuth:
avatar_url="https://discord.com/avatar.jpg",
)
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
result, created = await user_service.get_or_create_from_oauth(oauth_info)
assert created is False
assert result.id == existing.id
assert str(result.id) == str(existing.id)
# OAuth should be updated to Discord
assert result.oauth_provider == "discord"
assert result.oauth_id == "discord-link-id"
async def test_creates_new_user_when_not_found(self, db_session):
async def test_creates_new_user_when_not_found(self, user_service):
"""Test that new user is created when no match exists.
Verifies the method returns (user, True) for new users.
@ -252,7 +264,7 @@ class TestGetOrCreateFromOAuth:
name="Brand New",
)
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
result, created = await user_service.get_or_create_from_oauth(oauth_info)
assert created is True
assert result.email == "brandnew@example.com"
@ -261,7 +273,7 @@ class TestGetOrCreateFromOAuth:
class TestUpdate:
"""Tests for update method."""
async def test_updates_display_name(self, db_session):
async def test_updates_display_name(self, db_session, user_service):
"""Test that update changes display_name when provided."""
user = User(
email="update@example.com",
@ -272,12 +284,13 @@ class TestUpdate:
db_session.add(user)
await db_session.commit()
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
update_data = UserUpdate(display_name="New Name")
result = await user_service.update(db_session, user, update_data)
result = await user_service.update(user_id, update_data)
assert result.display_name == "New Name"
async def test_updates_avatar_url(self, db_session):
async def test_updates_avatar_url(self, db_session, user_service):
"""Test that update changes avatar_url when provided."""
user = User(
email="avatar@example.com",
@ -288,12 +301,13 @@ class TestUpdate:
db_session.add(user)
await db_session.commit()
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
result = await user_service.update(db_session, user, update_data)
result = await user_service.update(user_id, update_data)
assert result.avatar_url == "https://new-avatar.com/img.jpg"
async def test_ignores_none_values(self, db_session):
async def test_ignores_none_values(self, db_session, user_service):
"""Test that update doesn't change fields set to None.
Only explicitly provided fields should be updated.
@ -309,8 +323,9 @@ class TestUpdate:
await db_session.commit()
# Update only display_name, leave avatar alone
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
update_data = UserUpdate(display_name="Changed")
result = await user_service.update(db_session, user, update_data)
result = await user_service.update(user_id, update_data)
assert result.display_name == "Changed"
assert result.avatar_url == "https://keep.com/avatar.jpg"
@ -319,7 +334,7 @@ class TestUpdate:
class TestUpdateLastLogin:
"""Tests for update_last_login method."""
async def test_updates_last_login_timestamp(self, db_session):
async def test_updates_last_login_timestamp(self, db_session, user_service):
"""Test that update_last_login sets current timestamp."""
user = User(
email="login@example.com",
@ -332,8 +347,9 @@ class TestUpdateLastLogin:
assert user.last_login is None
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
before = datetime.now(UTC)
result = await user_service.update_last_login(db_session, user)
result = await user_service.update_last_login(user_id)
after = datetime.now(UTC)
assert result.last_login is not None
@ -344,7 +360,7 @@ class TestUpdateLastLogin:
class TestUpdatePremium:
"""Tests for update_premium method."""
async def test_grants_premium(self, db_session):
async def test_grants_premium(self, db_session, user_service):
"""Test that update_premium sets premium status and expiration."""
user = User(
email="premium@example.com",
@ -357,13 +373,14 @@ class TestUpdatePremium:
assert user.is_premium is False
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
expires = datetime.now(UTC) + timedelta(days=30)
result = await user_service.update_premium(db_session, user, expires)
result = await user_service.update_premium(user_id, expires)
assert result.is_premium is True
assert result.premium_until == expires
async def test_removes_premium(self, db_session):
async def test_removes_premium(self, db_session, user_service):
"""Test that update_premium with None removes premium status."""
user = User(
email="unpremium@example.com",
@ -376,7 +393,8 @@ class TestUpdatePremium:
db_session.add(user)
await db_session.commit()
result = await user_service.update_premium(db_session, user, None)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
result = await user_service.update_premium(user_id, None)
assert result.is_premium is False
assert result.premium_until is None
@ -385,7 +403,7 @@ class TestUpdatePremium:
class TestDelete:
"""Tests for delete method."""
async def test_deletes_user(self, db_session):
async def test_deletes_user(self, db_session, user_service):
"""Test that delete removes user from database."""
user = User(
email="delete@example.com",
@ -396,22 +414,18 @@ class TestDelete:
db_session.add(user)
await db_session.commit()
user_id = user.id
await user_service.delete(db_session, user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
await user_service.delete(user_id)
# Verify user is gone
from uuid import UUID
result = await user_service.get_by_id(
db_session, UUID(user_id) if isinstance(user_id, str) else user_id
)
result = await user_service.get_by_id(user_id)
assert result is None
class TestGetLinkedAccount:
"""Tests for get_linked_account method."""
async def test_returns_linked_account_when_found(self, db_session):
async def test_returns_linked_account_when_found(self, db_session, user_service):
"""Test that get_linked_account returns account when it exists.
Creates a user with a linked account and verifies it can be retrieved.
@ -437,22 +451,22 @@ class TestGetLinkedAccount:
await db_session.commit()
# Retrieve linked account
result = await user_service.get_linked_account(db_session, "discord", "discord-linked-123")
result = await user_service.get_linked_account("discord", "discord-linked-123")
assert result is not None
assert result.provider == "discord"
assert result.oauth_id == "discord-linked-123"
async def test_returns_none_when_not_found(self, db_session):
async def test_returns_none_when_not_found(self, user_service):
"""Test that get_linked_account returns None for nonexistent accounts."""
result = await user_service.get_linked_account(db_session, "discord", "nonexistent-id")
result = await user_service.get_linked_account("discord", "nonexistent-id")
assert result is None
class TestLinkOAuthAccount:
"""Tests for link_oauth_account method."""
async def test_links_new_provider(self, db_session):
async def test_links_new_provider(self, db_session, user_service):
"""Test that link_oauth_account successfully links a new provider.
Creates a Google user and links Discord to them.
@ -468,6 +482,8 @@ class TestLinkOAuthAccount:
await db_session.commit()
await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
# Link Discord
discord_info = OAuthUserInfo(
provider="discord",
@ -477,16 +493,16 @@ class TestLinkOAuthAccount:
avatar_url="https://discord.com/avatar.png",
)
result = await user_service.link_oauth_account(db_session, user, discord_info)
result = await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
assert result is not None
assert result.provider == "discord"
assert result.oauth_id == "discord-456"
assert result.email == "discord@example.com"
assert result.display_name == "Discord Name"
assert str(result.user_id) == str(user.id)
assert str(result.user_id) == str(user_id)
async def test_raises_error_if_already_linked_to_same_user(self, db_session):
async def test_raises_error_if_already_linked_to_same_user(self, db_session, user_service):
"""Test that linking same provider twice raises error.
A user cannot have the same provider linked multiple times.
@ -501,6 +517,8 @@ class TestLinkOAuthAccount:
await db_session.commit()
await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
# Link Discord first time
discord_info = OAuthUserInfo(
provider="discord",
@ -508,16 +526,15 @@ class TestLinkOAuthAccount:
email="first@discord.com",
name="First",
)
await user_service.link_oauth_account(db_session, user, discord_info)
await db_session.refresh(user)
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
# Try to link same Discord account again
with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user, discord_info)
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
assert "already linked to your account" in str(exc_info.value)
async def test_raises_error_if_linked_to_another_user(self, db_session):
async def test_raises_error_if_linked_to_another_user(self, db_session, user_service):
"""Test that linking account already linked to another user raises error.
The same OAuth provider+ID cannot be linked to multiple users.
@ -533,13 +550,15 @@ class TestLinkOAuthAccount:
await db_session.commit()
await db_session.refresh(user1)
user1_id = UUID(user1.id) if isinstance(user1.id, str) else user1.id
discord_info = OAuthUserInfo(
provider="discord",
oauth_id="shared-discord",
email="shared@discord.com",
name="Shared",
)
await user_service.link_oauth_account(db_session, user1, discord_info)
await user_service.link_oauth_account(user1_id, user1.oauth_provider, discord_info)
# Create second user
user2 = User(
@ -552,13 +571,15 @@ class TestLinkOAuthAccount:
await db_session.commit()
await db_session.refresh(user2)
user2_id = UUID(user2.id) if isinstance(user2.id, str) else user2.id
# Try to link same Discord account to second user
with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user2, discord_info)
await user_service.link_oauth_account(user2_id, user2.oauth_provider, discord_info)
assert "already linked to another user" in str(exc_info.value)
async def test_raises_error_if_linking_primary_provider(self, db_session):
async def test_raises_error_if_linking_primary_provider(self, db_session, user_service):
"""Test that linking the same provider as primary raises error.
User cannot link Google if they already signed up with Google.
@ -573,6 +594,8 @@ class TestLinkOAuthAccount:
await db_session.commit()
await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
# Try to link another Google account
google_info = OAuthUserInfo(
provider="google",
@ -582,7 +605,7 @@ class TestLinkOAuthAccount:
)
with pytest.raises(AccountLinkingError) as exc_info:
await user_service.link_oauth_account(db_session, user, google_info)
await user_service.link_oauth_account(user_id, user.oauth_provider, google_info)
assert "primary login provider" in str(exc_info.value)
@ -590,7 +613,7 @@ class TestLinkOAuthAccount:
class TestUnlinkOAuthAccount:
"""Tests for unlink_oauth_account method."""
async def test_unlinks_linked_account(self, db_session):
async def test_unlinks_linked_account(self, db_session, user_service):
"""Test that unlink_oauth_account removes a linked account.
Links Discord then unlinks it successfully.
@ -605,6 +628,8 @@ class TestUnlinkOAuthAccount:
await db_session.commit()
await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
# Link Discord
discord_info = OAuthUserInfo(
provider="discord",
@ -612,22 +637,18 @@ class TestUnlinkOAuthAccount:
email="discord@unlink.com",
name="Discord Unlink",
)
await user_service.link_oauth_account(db_session, user, discord_info)
await db_session.refresh(user)
# Verify linked
assert len(user.linked_accounts) == 1
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
# Unlink
result = await user_service.unlink_oauth_account(db_session, user, "discord")
result = await user_service.unlink_oauth_account(user_id, user.oauth_provider, "discord")
assert result is True
# Verify unlinked
linked = await user_service.get_linked_account(db_session, "discord", "discord-unlink")
linked = await user_service.get_linked_account("discord", "discord-unlink")
assert linked is None
async def test_returns_false_if_not_linked(self, db_session):
async def test_returns_false_if_not_linked(self, db_session, user_service):
"""Test that unlink returns False if provider isn't linked."""
user = User(
email="not-linked@example.com",
@ -639,11 +660,12 @@ class TestUnlinkOAuthAccount:
await db_session.commit()
await db_session.refresh(user)
result = await user_service.unlink_oauth_account(db_session, user, "discord")
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
result = await user_service.unlink_oauth_account(user_id, user.oauth_provider, "discord")
assert result is False
async def test_raises_error_if_unlinking_primary(self, db_session):
async def test_raises_error_if_unlinking_primary(self, db_session, user_service):
"""Test that unlinking primary provider raises error.
User cannot unlink their primary OAuth provider.
@ -658,7 +680,9 @@ class TestUnlinkOAuthAccount:
await db_session.commit()
await db_session.refresh(user)
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
with pytest.raises(AccountLinkingError) as exc_info:
await user_service.unlink_oauth_account(db_session, user, "google")
await user_service.unlink_oauth_account(user_id, user.oauth_provider, "google")
assert "primary login provider" in str(exc_info.value)