Implement UserRepository pattern with dependency injection
- Add UserRepository and LinkedAccountRepository protocols to protocols.py - Add UserEntry and LinkedAccountEntry DTOs for service layer decoupling - Implement PostgresUserRepository and PostgresLinkedAccountRepository - Refactor UserService to use constructor-injected repositories - Add get_user_service factory and UserServiceDep to API deps - Update auth.py and users.py endpoints to use UserServiceDep - Rewrite tests to use FastAPI dependency overrides (no monkey patching) This follows the established repository pattern used by DeckService and CollectionService, enabling future offline fork support. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
f6e8ab5f67
commit
7fcb86ff51
@ -31,7 +31,7 @@ import secrets
|
||||
from fastapi import APIRouter, HTTPException, Query, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from app.api.deps import CurrentUser, DbSession
|
||||
from app.api.deps import CurrentUser, UserServiceDep
|
||||
from app.config import settings
|
||||
from app.db.redis import get_redis
|
||||
from app.schemas.auth import RefreshTokenRequest, TokenResponse
|
||||
@ -45,7 +45,7 @@ from app.services.jwt_service import (
|
||||
from app.services.oauth.discord import DiscordOAuthError, discord_oauth
|
||||
from app.services.oauth.google import GoogleOAuthError, google_oauth
|
||||
from app.services.token_store import token_store
|
||||
from app.services.user_service import AccountLinkingError, user_service
|
||||
from app.services.user_service import AccountLinkingError
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
@ -150,7 +150,7 @@ async def google_auth_redirect(
|
||||
|
||||
@router.get("/google/callback")
|
||||
async def google_auth_callback(
|
||||
db: DbSession,
|
||||
user_service: UserServiceDep,
|
||||
code: str = Query(..., description="Authorization code from Google"),
|
||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||
) -> TokenResponse:
|
||||
@ -182,10 +182,10 @@ async def google_auth_callback(
|
||||
user_info = await google_oauth.get_user_info(code, oauth_callback)
|
||||
|
||||
# Get or create user
|
||||
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||
user, _created = await user_service.get_or_create_from_oauth(user_info)
|
||||
|
||||
# Update last login
|
||||
await user_service.update_last_login(db, user)
|
||||
await user_service.update_last_login(user.id)
|
||||
|
||||
# Create tokens
|
||||
return await _create_tokens_for_user(user.id)
|
||||
@ -243,7 +243,7 @@ async def discord_auth_redirect(
|
||||
|
||||
@router.get("/discord/callback")
|
||||
async def discord_auth_callback(
|
||||
db: DbSession,
|
||||
user_service: UserServiceDep,
|
||||
code: str = Query(..., description="Authorization code from Discord"),
|
||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||
) -> RedirectResponse:
|
||||
@ -276,10 +276,10 @@ async def discord_auth_callback(
|
||||
user_info = await discord_oauth.get_user_info(code, oauth_callback)
|
||||
|
||||
# Get or create user
|
||||
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
||||
user, _created = await user_service.get_or_create_from_oauth(user_info)
|
||||
|
||||
# Update last login
|
||||
await user_service.update_last_login(db, user)
|
||||
await user_service.update_last_login(user.id)
|
||||
|
||||
# Create tokens
|
||||
tokens = await _create_tokens_for_user(user.id)
|
||||
@ -307,7 +307,7 @@ async def discord_auth_callback(
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_tokens(
|
||||
db: DbSession,
|
||||
user_service: UserServiceDep,
|
||||
request: RefreshTokenRequest,
|
||||
) -> TokenResponse:
|
||||
"""Refresh access token using refresh token.
|
||||
@ -344,7 +344,7 @@ async def refresh_tokens(
|
||||
)
|
||||
|
||||
# Verify user still exists
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
user = await user_service.get_by_id(user_id)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@ -489,7 +489,7 @@ async def google_link_redirect(
|
||||
|
||||
@router.get("/link/google/callback")
|
||||
async def google_link_callback(
|
||||
db: DbSession,
|
||||
user_service: UserServiceDep,
|
||||
code: str = Query(..., description="Authorization code from Google"),
|
||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||
) -> RedirectResponse:
|
||||
@ -523,7 +523,7 @@ async def google_link_callback(
|
||||
from uuid import UUID
|
||||
|
||||
user_id = UUID(user_id_str)
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
user = await user_service.get_by_id(user_id)
|
||||
if user is None:
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?error=user_not_found",
|
||||
@ -531,7 +531,7 @@ async def google_link_callback(
|
||||
)
|
||||
|
||||
# Link the account
|
||||
await user_service.link_oauth_account(db, user, oauth_info)
|
||||
await user_service.link_oauth_account(user.id, user.oauth_provider, oauth_info)
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?linked=google",
|
||||
@ -600,7 +600,7 @@ async def discord_link_redirect(
|
||||
|
||||
@router.get("/link/discord/callback")
|
||||
async def discord_link_callback(
|
||||
db: DbSession,
|
||||
user_service: UserServiceDep,
|
||||
code: str = Query(..., description="Authorization code from Discord"),
|
||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||
) -> RedirectResponse:
|
||||
@ -634,7 +634,7 @@ async def discord_link_callback(
|
||||
from uuid import UUID
|
||||
|
||||
user_id = UUID(user_id_str)
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
user = await user_service.get_by_id(user_id)
|
||||
if user is None:
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?error=user_not_found",
|
||||
@ -642,7 +642,7 @@ async def discord_link_callback(
|
||||
)
|
||||
|
||||
# Link the account
|
||||
await user_service.link_oauth_account(db, user, oauth_info)
|
||||
await user_service.link_oauth_account(user.id, user.oauth_provider, oauth_info)
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?linked=discord",
|
||||
|
||||
@ -28,6 +28,7 @@ from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
@ -35,13 +36,15 @@ from app.db import get_session
|
||||
from app.db.models import User
|
||||
from app.repositories.postgres.collection import PostgresCollectionRepository
|
||||
from app.repositories.postgres.deck import PostgresDeckRepository
|
||||
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
|
||||
from app.repositories.postgres.user import PostgresUserRepository
|
||||
from app.services.card_service import CardService, get_card_service
|
||||
from app.services.collection_service import CollectionService
|
||||
from app.services.deck_service import DeckService
|
||||
from app.services.game_service import GameService, game_service
|
||||
from app.services.game_state_manager import GameStateManager, game_state_manager
|
||||
from app.services.jwt_service import verify_access_token
|
||||
from app.services.user_service import user_service
|
||||
from app.services.user_service import UserService
|
||||
|
||||
# OAuth2 scheme for extracting Bearer token from Authorization header
|
||||
# tokenUrl is for OpenAPI docs - points to where tokens are obtained
|
||||
@ -148,8 +151,9 @@ async def get_current_user(
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
|
||||
# Fetch user from database
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
# Fetch user from database (direct query for auth - not business logic)
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
@ -187,7 +191,9 @@ async def get_optional_user(
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
return await user_service.get_by_id(db, user_id)
|
||||
# Direct query for auth - not business logic
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_current_premium_user(
|
||||
@ -333,6 +339,31 @@ def get_game_state_manager_dep() -> GameStateManager:
|
||||
return game_state_manager
|
||||
|
||||
|
||||
def get_user_service(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> UserService:
|
||||
"""Get UserService with PostgreSQL repositories.
|
||||
|
||||
Creates a UserService instance with user and linked account repositories.
|
||||
|
||||
Args:
|
||||
db: Database session from request.
|
||||
|
||||
Returns:
|
||||
UserService configured for PostgreSQL.
|
||||
|
||||
Example:
|
||||
@router.post("/auth/google/callback")
|
||||
async def google_callback(
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
):
|
||||
user, created = await user_service.get_or_create_from_oauth(oauth_info)
|
||||
"""
|
||||
user_repo = PostgresUserRepository(db)
|
||||
linked_repo = PostgresLinkedAccountRepository(db)
|
||||
return UserService(user_repo, linked_repo)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Type Aliases for Cleaner Endpoint Signatures
|
||||
# =============================================================================
|
||||
@ -351,6 +382,7 @@ CollectionServiceDep = Annotated[CollectionService, Depends(get_collection_servi
|
||||
CardServiceDep = Annotated[CardService, Depends(get_card_service_dep)]
|
||||
GameServiceDep = Annotated[GameService, Depends(get_game_service_dep)]
|
||||
GameStateManagerDep = Annotated[GameStateManager, Depends(get_game_state_manager_dep)]
|
||||
UserServiceDep = Annotated[UserService, Depends(get_user_service)]
|
||||
|
||||
# Admin authentication
|
||||
AdminAuth = Annotated[None, Depends(verify_admin_token)]
|
||||
|
||||
@ -26,12 +26,12 @@ Example:
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.deps import CurrentUser, DbSession, DeckServiceDep
|
||||
from app.api.deps import CurrentUser, DeckServiceDep, UserServiceDep
|
||||
from app.schemas.deck import DeckResponse, StarterDeckSelectRequest, StarterStatusResponse
|
||||
from app.schemas.user import UserResponse, UserUpdate
|
||||
from app.services.deck_service import DeckLimitExceededError, StarterAlreadySelectedError
|
||||
from app.services.token_store import token_store
|
||||
from app.services.user_service import AccountLinkingError, user_service
|
||||
from app.services.user_service import AccountLinkingError
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
@ -65,7 +65,7 @@ async def get_current_user_profile(
|
||||
@router.patch("/me", response_model=UserResponse)
|
||||
async def update_current_user_profile(
|
||||
user: CurrentUser,
|
||||
db: DbSession,
|
||||
user_service: UserServiceDep,
|
||||
update_data: UserUpdate,
|
||||
) -> UserResponse:
|
||||
"""Update the current user's profile.
|
||||
@ -78,7 +78,7 @@ async def update_current_user_profile(
|
||||
Returns:
|
||||
Updated user profile.
|
||||
"""
|
||||
updated_user = await user_service.update(db, user, update_data)
|
||||
updated_user = await user_service.update(user.id, update_data)
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
|
||||
@ -139,7 +139,7 @@ async def get_active_sessions(
|
||||
@router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def unlink_oauth_account(
|
||||
user: CurrentUser,
|
||||
db: DbSession,
|
||||
user_service: UserServiceDep,
|
||||
provider: str,
|
||||
) -> None:
|
||||
"""Unlink an OAuth provider from the current user's account.
|
||||
@ -162,7 +162,7 @@ async def unlink_oauth_account(
|
||||
)
|
||||
|
||||
try:
|
||||
unlinked = await user_service.unlink_oauth_account(db, user, provider)
|
||||
unlinked = await user_service.unlink_oauth_account(user.id, user.oauth_provider, provider)
|
||||
if not unlinked:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
||||
@ -10,11 +10,12 @@ The protocol pattern enables:
|
||||
- Offline fork support without rewriting service layer
|
||||
|
||||
Usage:
|
||||
from app.repositories import CollectionRepository, DeckRepository
|
||||
from app.repositories.postgres import PostgresCollectionRepository
|
||||
from app.repositories import CollectionRepository, DeckRepository, UserRepository
|
||||
from app.repositories.postgres import PostgresCollectionRepository, PostgresUserRepository
|
||||
|
||||
# In production (dependency injection)
|
||||
repo = PostgresCollectionRepository(db_session)
|
||||
user_repo = PostgresUserRepository(db_session)
|
||||
|
||||
# In tests
|
||||
repo = MockCollectionRepository()
|
||||
@ -23,9 +24,13 @@ Usage:
|
||||
from app.repositories.protocols import (
|
||||
CollectionRepository,
|
||||
DeckRepository,
|
||||
LinkedAccountRepository,
|
||||
UserRepository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CollectionRepository",
|
||||
"DeckRepository",
|
||||
"LinkedAccountRepository",
|
||||
"UserRepository",
|
||||
]
|
||||
|
||||
@ -8,11 +8,15 @@ Usage:
|
||||
from app.repositories.postgres import (
|
||||
PostgresCollectionRepository,
|
||||
PostgresDeckRepository,
|
||||
PostgresLinkedAccountRepository,
|
||||
PostgresUserRepository,
|
||||
)
|
||||
|
||||
# Create repository with database session
|
||||
collection_repo = PostgresCollectionRepository(db_session)
|
||||
deck_repo = PostgresDeckRepository(db_session)
|
||||
user_repo = PostgresUserRepository(db_session)
|
||||
linked_repo = PostgresLinkedAccountRepository(db_session)
|
||||
|
||||
# Use via service layer
|
||||
service = CollectionService(collection_repo)
|
||||
@ -20,8 +24,12 @@ Usage:
|
||||
|
||||
from app.repositories.postgres.collection import PostgresCollectionRepository
|
||||
from app.repositories.postgres.deck import PostgresDeckRepository
|
||||
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
|
||||
from app.repositories.postgres.user import PostgresUserRepository
|
||||
|
||||
__all__ = [
|
||||
"PostgresCollectionRepository",
|
||||
"PostgresDeckRepository",
|
||||
"PostgresLinkedAccountRepository",
|
||||
"PostgresUserRepository",
|
||||
]
|
||||
|
||||
149
backend/app/repositories/postgres/linked_account.py
Normal file
149
backend/app/repositories/postgres/linked_account.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""PostgreSQL implementation of LinkedAccountRepository.
|
||||
|
||||
This module provides the PostgreSQL-specific implementation of the
|
||||
LinkedAccountRepository protocol using SQLAlchemy async sessions.
|
||||
|
||||
Example:
|
||||
async with get_db_session() as db:
|
||||
repo = PostgresLinkedAccountRepository(db)
|
||||
accounts = await repo.get_by_user(user_id)
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||
from app.repositories.protocols import LinkedAccountEntry
|
||||
|
||||
|
||||
def _to_dto(model: OAuthLinkedAccount) -> LinkedAccountEntry:
|
||||
"""Convert ORM model to DTO."""
|
||||
return LinkedAccountEntry(
|
||||
id=model.id,
|
||||
user_id=UUID(model.user_id) if isinstance(model.user_id, str) else model.user_id,
|
||||
provider=model.provider,
|
||||
oauth_id=model.oauth_id,
|
||||
email=model.email,
|
||||
display_name=model.display_name,
|
||||
avatar_url=model.avatar_url,
|
||||
linked_at=model.linked_at,
|
||||
)
|
||||
|
||||
|
||||
class PostgresLinkedAccountRepository:
|
||||
"""PostgreSQL implementation of LinkedAccountRepository.
|
||||
|
||||
Uses SQLAlchemy async sessions for database access.
|
||||
|
||||
Attributes:
|
||||
_db: The async database session.
|
||||
"""
|
||||
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
"""Initialize with database session.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy async session.
|
||||
"""
|
||||
self._db = db
|
||||
|
||||
async def get_by_provider(
|
||||
self,
|
||||
provider: str,
|
||||
oauth_id: str,
|
||||
) -> LinkedAccountEntry | None:
|
||||
"""Get a linked account by provider and OAuth ID.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider name (google, discord).
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
LinkedAccountEntry if found, None otherwise.
|
||||
"""
|
||||
result = await self._db.execute(
|
||||
select(OAuthLinkedAccount).where(
|
||||
OAuthLinkedAccount.provider == provider,
|
||||
OAuthLinkedAccount.oauth_id == oauth_id,
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_dto(model) if model else None
|
||||
|
||||
async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
|
||||
"""Get all linked accounts for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
List of LinkedAccountEntry, ordered by provider.
|
||||
"""
|
||||
result = await self._db.execute(
|
||||
select(OAuthLinkedAccount)
|
||||
.where(OAuthLinkedAccount.user_id == str(user_id))
|
||||
.order_by(OAuthLinkedAccount.provider)
|
||||
)
|
||||
return [_to_dto(model) for model in result.scalars().all()]
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_id: UUID,
|
||||
provider: str,
|
||||
oauth_id: str,
|
||||
email: str | None = None,
|
||||
display_name: str | None = None,
|
||||
avatar_url: str | None = None,
|
||||
) -> LinkedAccountEntry:
|
||||
"""Link an OAuth provider to a user account.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
provider: OAuth provider name.
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
email: Email from the OAuth provider.
|
||||
display_name: Display name from the OAuth provider.
|
||||
avatar_url: Avatar URL from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
The created LinkedAccountEntry.
|
||||
"""
|
||||
linked_account = OAuthLinkedAccount(
|
||||
user_id=str(user_id),
|
||||
provider=provider,
|
||||
oauth_id=oauth_id,
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
avatar_url=avatar_url,
|
||||
)
|
||||
self._db.add(linked_account)
|
||||
await self._db.commit()
|
||||
await self._db.refresh(linked_account)
|
||||
return _to_dto(linked_account)
|
||||
|
||||
async def delete(self, user_id: UUID, provider: str) -> bool:
|
||||
"""Unlink an OAuth provider from a user account.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
provider: OAuth provider name to unlink.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
result = await self._db.execute(
|
||||
select(OAuthLinkedAccount).where(
|
||||
OAuthLinkedAccount.user_id == str(user_id),
|
||||
OAuthLinkedAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
linked_account = result.scalar_one_or_none()
|
||||
|
||||
if linked_account is None:
|
||||
return False
|
||||
|
||||
await self._db.delete(linked_account)
|
||||
await self._db.commit()
|
||||
return True
|
||||
243
backend/app/repositories/postgres/user.py
Normal file
243
backend/app/repositories/postgres/user.py
Normal file
@ -0,0 +1,243 @@
|
||||
"""PostgreSQL implementation of UserRepository.
|
||||
|
||||
This module provides the PostgreSQL-specific implementation of the
|
||||
UserRepository protocol using SQLAlchemy async sessions.
|
||||
|
||||
Example:
|
||||
async with get_db_session() as db:
|
||||
repo = PostgresUserRepository(db)
|
||||
user = await repo.get_by_id(user_id)
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.user import User
|
||||
from app.repositories.protocols import UNSET, UserEntry
|
||||
|
||||
|
||||
def _to_dto(model: User) -> UserEntry:
|
||||
"""Convert ORM model to DTO."""
|
||||
return UserEntry(
|
||||
id=model.id,
|
||||
email=model.email,
|
||||
display_name=model.display_name,
|
||||
avatar_url=model.avatar_url,
|
||||
oauth_provider=model.oauth_provider,
|
||||
oauth_id=model.oauth_id,
|
||||
is_premium=model.is_premium,
|
||||
premium_until=model.premium_until,
|
||||
last_login=model.last_login,
|
||||
created_at=model.created_at,
|
||||
updated_at=model.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class PostgresUserRepository:
|
||||
"""PostgreSQL implementation of UserRepository.
|
||||
|
||||
Uses SQLAlchemy async sessions for database access.
|
||||
|
||||
Attributes:
|
||||
_db: The async database session.
|
||||
"""
|
||||
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
"""Initialize with database session.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy async session.
|
||||
"""
|
||||
self._db = db
|
||||
|
||||
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
|
||||
"""Get a user by their ID.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
UserEntry if found, None otherwise.
|
||||
"""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_dto(model) if model else None
|
||||
|
||||
async def get_by_email(self, email: str) -> UserEntry | None:
|
||||
"""Get a user by their email address.
|
||||
|
||||
Args:
|
||||
email: The user's email address.
|
||||
|
||||
Returns:
|
||||
UserEntry if found, None otherwise.
|
||||
"""
|
||||
result = await self._db.execute(select(User).where(User.email == email))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_dto(model) if model else None
|
||||
|
||||
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
|
||||
"""Get a user by their OAuth provider and ID.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider name (google, discord).
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
UserEntry if found, None otherwise.
|
||||
"""
|
||||
result = await self._db.execute(
|
||||
select(User).where(
|
||||
User.oauth_provider == provider,
|
||||
User.oauth_id == oauth_id,
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_dto(model) if model else None
|
||||
|
||||
async def create(
|
||||
self,
|
||||
email: str,
|
||||
display_name: str,
|
||||
oauth_provider: str,
|
||||
oauth_id: str,
|
||||
avatar_url: str | None = None,
|
||||
) -> UserEntry:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
display_name: Public display name.
|
||||
oauth_provider: OAuth provider name.
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
avatar_url: Optional avatar URL.
|
||||
|
||||
Returns:
|
||||
The created UserEntry.
|
||||
"""
|
||||
user = User(
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
oauth_provider=oauth_provider,
|
||||
oauth_id=oauth_id,
|
||||
avatar_url=avatar_url,
|
||||
)
|
||||
self._db.add(user)
|
||||
await self._db.commit()
|
||||
await self._db.refresh(user)
|
||||
return _to_dto(user)
|
||||
|
||||
async def update(
|
||||
self,
|
||||
user_id: UUID,
|
||||
display_name: str | None = None,
|
||||
avatar_url: str | None = UNSET, # type: ignore[assignment]
|
||||
oauth_provider: str | None = None,
|
||||
oauth_id: str | None = None,
|
||||
) -> UserEntry | None:
|
||||
"""Update user profile fields.
|
||||
|
||||
Only provided (non-None/non-UNSET) fields are updated.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
display_name: New display name (None keeps existing).
|
||||
avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
|
||||
oauth_provider: New OAuth provider (None keeps existing).
|
||||
oauth_id: New OAuth ID (None keeps existing).
|
||||
|
||||
Returns:
|
||||
Updated UserEntry, or None if user not found.
|
||||
"""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
return None
|
||||
|
||||
if display_name is not None:
|
||||
user.display_name = display_name
|
||||
if avatar_url is not UNSET:
|
||||
user.avatar_url = avatar_url
|
||||
if oauth_provider is not None:
|
||||
user.oauth_provider = oauth_provider
|
||||
if oauth_id is not None:
|
||||
user.oauth_id = oauth_id
|
||||
|
||||
await self._db.commit()
|
||||
await self._db.refresh(user)
|
||||
return _to_dto(user)
|
||||
|
||||
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
|
||||
"""Update the user's last login timestamp to now.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
Updated UserEntry, or None if user not found.
|
||||
"""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
return None
|
||||
|
||||
user.last_login = datetime.now(UTC)
|
||||
await self._db.commit()
|
||||
await self._db.refresh(user)
|
||||
return _to_dto(user)
|
||||
|
||||
async def update_premium(
|
||||
self,
|
||||
user_id: UUID,
|
||||
is_premium: bool,
|
||||
premium_until: datetime | None,
|
||||
) -> UserEntry | None:
|
||||
"""Update user's premium subscription status.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
is_premium: Whether user has premium.
|
||||
premium_until: When premium expires, or None if not premium.
|
||||
|
||||
Returns:
|
||||
Updated UserEntry, or None if user not found.
|
||||
"""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
return None
|
||||
|
||||
user.is_premium = is_premium
|
||||
user.premium_until = premium_until
|
||||
|
||||
await self._db.commit()
|
||||
await self._db.refresh(user)
|
||||
return _to_dto(user)
|
||||
|
||||
async def delete(self, user_id: UUID) -> bool:
|
||||
"""Delete a user account.
|
||||
|
||||
This will cascade delete all related data (decks, collection, etc.)
|
||||
based on database constraints.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
return False
|
||||
|
||||
await self._db.delete(user)
|
||||
await self._db.commit()
|
||||
return True
|
||||
@ -102,6 +102,46 @@ class DeckEntry:
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserEntry:
|
||||
"""Storage-agnostic representation of a user account.
|
||||
|
||||
This DTO decouples the service layer from the ORM model,
|
||||
enabling different storage backends (PostgreSQL, SQLite, JSON)
|
||||
to be used interchangeably.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
email: str
|
||||
display_name: str
|
||||
avatar_url: str | None
|
||||
oauth_provider: str
|
||||
oauth_id: str
|
||||
is_premium: bool
|
||||
premium_until: datetime | None
|
||||
last_login: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinkedAccountEntry:
|
||||
"""Storage-agnostic representation of a linked OAuth account.
|
||||
|
||||
Users can link multiple OAuth providers (e.g., Google + Discord)
|
||||
to a single account for flexible login options.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
user_id: UUID
|
||||
provider: str
|
||||
oauth_id: str
|
||||
email: str | None
|
||||
display_name: str | None
|
||||
avatar_url: str | None
|
||||
linked_at: datetime
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Repository Protocols
|
||||
# =============================================================================
|
||||
@ -342,3 +382,211 @@ class DeckRepository(Protocol):
|
||||
Tuple of (has_starter, starter_type).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class UserRepository(Protocol):
|
||||
"""Protocol for user account data access.
|
||||
|
||||
Implementations handle storage-specific details (PostgreSQL, SQLite, JSON).
|
||||
Services use this protocol for business logic without knowing storage details.
|
||||
|
||||
Note: Business logic like get_or_create_from_oauth belongs in the service layer,
|
||||
not in the repository.
|
||||
"""
|
||||
|
||||
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
|
||||
"""Get a user by their ID.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
UserEntry if found, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_by_email(self, email: str) -> UserEntry | None:
|
||||
"""Get a user by their email address.
|
||||
|
||||
Args:
|
||||
email: The user's email address.
|
||||
|
||||
Returns:
|
||||
UserEntry if found, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
|
||||
"""Get a user by their OAuth provider and ID.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider name (google, discord).
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
UserEntry if found, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def create(
|
||||
self,
|
||||
email: str,
|
||||
display_name: str,
|
||||
oauth_provider: str,
|
||||
oauth_id: str,
|
||||
avatar_url: str | None = None,
|
||||
) -> UserEntry:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
display_name: Public display name.
|
||||
oauth_provider: OAuth provider name.
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
avatar_url: Optional avatar URL.
|
||||
|
||||
Returns:
|
||||
The created UserEntry.
|
||||
"""
|
||||
...
|
||||
|
||||
async def update(
|
||||
self,
|
||||
user_id: UUID,
|
||||
display_name: str | None = None,
|
||||
avatar_url: str | None = UNSET, # type: ignore[assignment]
|
||||
oauth_provider: str | None = None,
|
||||
oauth_id: str | None = None,
|
||||
) -> UserEntry | None:
|
||||
"""Update user profile fields.
|
||||
|
||||
Only provided (non-None/non-UNSET) fields are updated.
|
||||
Use UNSET (default) to keep existing value for nullable fields,
|
||||
or None to explicitly clear them.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
display_name: New display name (None keeps existing).
|
||||
avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
|
||||
oauth_provider: New OAuth provider (None keeps existing).
|
||||
oauth_id: New OAuth ID (None keeps existing).
|
||||
|
||||
Returns:
|
||||
Updated UserEntry, or None if user not found.
|
||||
"""
|
||||
...
|
||||
|
||||
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
|
||||
"""Update the user's last login timestamp to now.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
Updated UserEntry, or None if user not found.
|
||||
"""
|
||||
...
|
||||
|
||||
async def update_premium(
|
||||
self,
|
||||
user_id: UUID,
|
||||
is_premium: bool,
|
||||
premium_until: datetime | None,
|
||||
) -> UserEntry | None:
|
||||
"""Update user's premium subscription status.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
is_premium: Whether user has premium.
|
||||
premium_until: When premium expires, or None if not premium.
|
||||
|
||||
Returns:
|
||||
Updated UserEntry, or None if user not found.
|
||||
"""
|
||||
...
|
||||
|
||||
async def delete(self, user_id: UUID) -> bool:
|
||||
"""Delete a user account.
|
||||
|
||||
This will cascade delete all related data (decks, collection, etc.)
|
||||
based on database constraints.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class LinkedAccountRepository(Protocol):
|
||||
"""Protocol for OAuth linked accounts data access.
|
||||
|
||||
Users can link multiple OAuth providers to a single account.
|
||||
The primary OAuth provider is stored on the User model itself;
|
||||
additional linked providers are stored as LinkedAccount records.
|
||||
"""
|
||||
|
||||
async def get_by_provider(
|
||||
self,
|
||||
provider: str,
|
||||
oauth_id: str,
|
||||
) -> LinkedAccountEntry | None:
|
||||
"""Get a linked account by provider and OAuth ID.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider name (google, discord).
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
LinkedAccountEntry if found, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
|
||||
"""Get all linked accounts for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
List of LinkedAccountEntry, ordered by provider.
|
||||
"""
|
||||
...
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_id: UUID,
|
||||
provider: str,
|
||||
oauth_id: str,
|
||||
email: str | None = None,
|
||||
display_name: str | None = None,
|
||||
avatar_url: str | None = None,
|
||||
) -> LinkedAccountEntry:
|
||||
"""Link an OAuth provider to a user account.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
provider: OAuth provider name.
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
email: Email from the OAuth provider.
|
||||
display_name: Display name from the OAuth provider.
|
||||
avatar_url: Avatar URL from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
The created LinkedAccountEntry.
|
||||
"""
|
||||
...
|
||||
|
||||
async def delete(self, user_id: UUID, provider: str) -> bool:
|
||||
"""Unlink an OAuth provider from a user account.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
provider: OAuth provider name to unlink.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
...
|
||||
|
||||
@ -1,32 +1,37 @@
|
||||
"""User service for Mantimon TCG.
|
||||
|
||||
This module provides async CRUD operations for user accounts,
|
||||
including OAuth-based user creation and premium status management.
|
||||
This module provides business logic for user account operations,
|
||||
including OAuth-based user creation, account linking, and premium
|
||||
status management.
|
||||
|
||||
All database operations use async SQLAlchemy sessions.
|
||||
The service layer contains business logic while repositories handle
|
||||
pure data access. This separation enables testing and different
|
||||
storage backends.
|
||||
|
||||
Example:
|
||||
from app.services.user_service import user_service
|
||||
from app.services.user_service import UserService
|
||||
from app.repositories.postgres import PostgresUserRepository, PostgresLinkedAccountRepository
|
||||
|
||||
# Get user by ID
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
# Create service with injected repositories
|
||||
user_repo = PostgresUserRepository(db)
|
||||
linked_repo = PostgresLinkedAccountRepository(db)
|
||||
service = UserService(user_repo, linked_repo)
|
||||
|
||||
# Create from OAuth
|
||||
user = await user_service.create_from_oauth(db, oauth_info)
|
||||
|
||||
# Update premium status
|
||||
user = await user_service.update_premium(db, user_id, premium_until)
|
||||
# Use service
|
||||
user = await service.get_by_id(user_id)
|
||||
user, created = await service.get_or_create_from_oauth(oauth_info)
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||
from app.db.models.user import User
|
||||
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
||||
from app.repositories.protocols import (
|
||||
LinkedAccountEntry,
|
||||
LinkedAccountRepository,
|
||||
UserEntry,
|
||||
UserRepository,
|
||||
)
|
||||
from app.schemas.user import OAuthUserInfo, UserUpdate
|
||||
|
||||
|
||||
class AccountLinkingError(Exception):
|
||||
@ -38,117 +43,119 @@ class AccountLinkingError(Exception):
|
||||
class UserService:
|
||||
"""Service for user account operations.
|
||||
|
||||
Provides async methods for user CRUD, OAuth-based creation,
|
||||
and premium subscription management.
|
||||
Provides business logic for user CRUD, OAuth-based creation,
|
||||
account linking, and premium subscription management.
|
||||
|
||||
Attributes:
|
||||
_user_repo: Repository for user data access.
|
||||
_linked_repo: Repository for linked account data access.
|
||||
"""
|
||||
|
||||
async def get_by_id(self, db: AsyncSession, user_id: UUID) -> User | None:
|
||||
def __init__(
|
||||
self,
|
||||
user_repository: UserRepository,
|
||||
linked_account_repository: LinkedAccountRepository,
|
||||
) -> None:
|
||||
"""Initialize with repository dependencies.
|
||||
|
||||
Args:
|
||||
user_repository: Repository for user data access.
|
||||
linked_account_repository: Repository for linked account data access.
|
||||
"""
|
||||
self._user_repo = user_repository
|
||||
self._linked_repo = linked_account_repository
|
||||
|
||||
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
|
||||
"""Get a user by their ID.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise.
|
||||
UserEntry if found, None otherwise.
|
||||
|
||||
Example:
|
||||
user = await user_service.get_by_id(db, user_id)
|
||||
user = await service.get_by_id(user_id)
|
||||
if user:
|
||||
print(f"Found user: {user.display_name}")
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
return result.scalar_one_or_none()
|
||||
return await self._user_repo.get_by_id(user_id)
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
||||
async def get_by_email(self, email: str) -> UserEntry | None:
|
||||
"""Get a user by their email address.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
email: The user's email address.
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise.
|
||||
UserEntry if found, None otherwise.
|
||||
|
||||
Example:
|
||||
user = await user_service.get_by_email(db, "player@example.com")
|
||||
user = await service.get_by_email("player@example.com")
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
return await self._user_repo.get_by_email(email)
|
||||
|
||||
async def get_by_oauth(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider: str,
|
||||
oauth_id: str,
|
||||
) -> User | None:
|
||||
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
|
||||
"""Get a user by their OAuth provider and ID.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
provider: OAuth provider name (google, discord).
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise.
|
||||
UserEntry if found, None otherwise.
|
||||
|
||||
Example:
|
||||
user = await user_service.get_by_oauth(db, "google", "123456789")
|
||||
user = await service.get_by_oauth("google", "123456789")
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.oauth_provider == provider,
|
||||
User.oauth_id == oauth_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return await self._user_repo.get_by_oauth(provider, oauth_id)
|
||||
|
||||
async def create(self, db: AsyncSession, user_data: UserCreate) -> User:
|
||||
async def create(
|
||||
self,
|
||||
email: str,
|
||||
display_name: str,
|
||||
oauth_provider: str,
|
||||
oauth_id: str,
|
||||
avatar_url: str | None = None,
|
||||
) -> UserEntry:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user_data: User creation data.
|
||||
email: User's email address.
|
||||
display_name: Public display name.
|
||||
oauth_provider: OAuth provider name.
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
avatar_url: Optional avatar URL.
|
||||
|
||||
Returns:
|
||||
The created User instance.
|
||||
The created UserEntry.
|
||||
|
||||
Example:
|
||||
user_data = UserCreate(
|
||||
user = await service.create(
|
||||
email="player@example.com",
|
||||
display_name="Player1",
|
||||
oauth_provider="google",
|
||||
oauth_id="123456789"
|
||||
)
|
||||
user = await user_service.create(db, user_data)
|
||||
"""
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
display_name=user_data.display_name,
|
||||
avatar_url=user_data.avatar_url,
|
||||
oauth_provider=user_data.oauth_provider,
|
||||
oauth_id=user_data.oauth_id,
|
||||
return await self._user_repo.create(
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
oauth_provider=oauth_provider,
|
||||
oauth_id=oauth_id,
|
||||
avatar_url=avatar_url,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def create_from_oauth(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
oauth_info: OAuthUserInfo,
|
||||
) -> User:
|
||||
async def create_from_oauth(self, oauth_info: OAuthUserInfo) -> UserEntry:
|
||||
"""Create a new user from OAuth provider info.
|
||||
|
||||
Convenience method that converts OAuthUserInfo to UserCreate.
|
||||
Convenience method that extracts fields from OAuthUserInfo.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
oauth_info: Normalized OAuth user information.
|
||||
|
||||
Returns:
|
||||
The created User instance.
|
||||
The created UserEntry.
|
||||
|
||||
Example:
|
||||
oauth_info = OAuthUserInfo(
|
||||
@ -158,209 +165,212 @@ class UserService:
|
||||
name="Player One",
|
||||
avatar_url="https://..."
|
||||
)
|
||||
user = await user_service.create_from_oauth(db, oauth_info)
|
||||
user = await service.create_from_oauth(oauth_info)
|
||||
"""
|
||||
user_data = oauth_info.to_user_create()
|
||||
return await self.create(db, user_data)
|
||||
return await self.create(
|
||||
email=oauth_info.email,
|
||||
display_name=oauth_info.name,
|
||||
oauth_provider=oauth_info.provider,
|
||||
oauth_id=oauth_info.oauth_id,
|
||||
avatar_url=oauth_info.avatar_url,
|
||||
)
|
||||
|
||||
async def get_or_create_from_oauth(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
oauth_info: OAuthUserInfo,
|
||||
) -> tuple[User, bool]:
|
||||
) -> tuple[UserEntry, bool]:
|
||||
"""Get existing user or create new one from OAuth info.
|
||||
|
||||
First checks for existing user by OAuth provider+ID, then by email
|
||||
(for account linking), and finally creates a new user if not found.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
oauth_info: Normalized OAuth user information.
|
||||
|
||||
Returns:
|
||||
Tuple of (User, created) where created is True if new user.
|
||||
Tuple of (UserEntry, created) where created is True if new user.
|
||||
|
||||
Example:
|
||||
user, created = await user_service.get_or_create_from_oauth(db, oauth_info)
|
||||
user, created = await service.get_or_create_from_oauth(oauth_info)
|
||||
if created:
|
||||
print("Welcome, new user!")
|
||||
else:
|
||||
print("Welcome back!")
|
||||
"""
|
||||
# First, check by OAuth provider + ID (exact match)
|
||||
user = await self.get_by_oauth(db, oauth_info.provider, oauth_info.oauth_id)
|
||||
user = await self._user_repo.get_by_oauth(oauth_info.provider, oauth_info.oauth_id)
|
||||
if user:
|
||||
return user, False
|
||||
|
||||
# Check by email for potential account linking
|
||||
# If user exists with same email but different OAuth, update their OAuth
|
||||
user = await self.get_by_email(db, oauth_info.email)
|
||||
user = await self._user_repo.get_by_email(oauth_info.email)
|
||||
if user:
|
||||
# Update OAuth credentials for existing user
|
||||
# This links the new OAuth provider to the existing account
|
||||
user.oauth_provider = oauth_info.provider
|
||||
user.oauth_id = oauth_info.oauth_id
|
||||
# Optionally update avatar if not set
|
||||
if not user.avatar_url and oauth_info.avatar_url:
|
||||
user.avatar_url = oauth_info.avatar_url
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user, False
|
||||
updated_user = await self._user_repo.update(
|
||||
user_id=user.id,
|
||||
oauth_provider=oauth_info.provider,
|
||||
oauth_id=oauth_info.oauth_id,
|
||||
avatar_url=oauth_info.avatar_url if not user.avatar_url else None,
|
||||
)
|
||||
return updated_user or user, False
|
||||
|
||||
# Create new user
|
||||
user = await self.create_from_oauth(db, oauth_info)
|
||||
user = await self.create_from_oauth(oauth_info)
|
||||
return user, True
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
user_id: UUID,
|
||||
update_data: UserUpdate,
|
||||
) -> User:
|
||||
) -> UserEntry | None:
|
||||
"""Update user profile fields.
|
||||
|
||||
Only updates fields that are provided (not None).
|
||||
Only updates fields that are explicitly provided. Uses UNSET pattern
|
||||
for avatar_url to distinguish "not provided" from "set to None".
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to update.
|
||||
user_id: The user's UUID.
|
||||
update_data: Fields to update.
|
||||
|
||||
Returns:
|
||||
The updated User instance.
|
||||
The updated UserEntry, or None if not found.
|
||||
|
||||
Example:
|
||||
update_data = UserUpdate(display_name="New Name")
|
||||
user = await user_service.update(db, user, update_data)
|
||||
user = await service.update(user_id, update_data)
|
||||
"""
|
||||
if update_data.display_name is not None:
|
||||
user.display_name = update_data.display_name
|
||||
if update_data.avatar_url is not None:
|
||||
user.avatar_url = update_data.avatar_url
|
||||
from app.repositories.protocols import UNSET
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
# Use UNSET for avatar_url unless explicitly provided
|
||||
avatar_url = (
|
||||
update_data.avatar_url if "avatar_url" in update_data.model_fields_set else UNSET
|
||||
)
|
||||
|
||||
async def update_last_login(self, db: AsyncSession, user: User) -> User:
|
||||
return await self._user_repo.update(
|
||||
user_id=user_id,
|
||||
display_name=update_data.display_name,
|
||||
avatar_url=avatar_url,
|
||||
)
|
||||
|
||||
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
|
||||
"""Update the user's last login timestamp.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to update.
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The updated User instance.
|
||||
The updated UserEntry, or None if not found.
|
||||
|
||||
Example:
|
||||
user = await user_service.update_last_login(db, user)
|
||||
user = await service.update_last_login(user_id)
|
||||
"""
|
||||
user.last_login = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
return await self._user_repo.update_last_login(user_id)
|
||||
|
||||
async def update_premium(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
user_id: UUID,
|
||||
premium_until: datetime | None,
|
||||
) -> User:
|
||||
) -> UserEntry | None:
|
||||
"""Update user's premium subscription status.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to update.
|
||||
user_id: The user's UUID.
|
||||
premium_until: When premium expires, or None to remove premium.
|
||||
|
||||
Returns:
|
||||
The updated User instance.
|
||||
The updated UserEntry, or None if not found.
|
||||
|
||||
Example:
|
||||
# Grant 30 days of premium
|
||||
expires = datetime.now(UTC) + timedelta(days=30)
|
||||
user = await user_service.update_premium(db, user, expires)
|
||||
user = await service.update_premium(user_id, expires)
|
||||
|
||||
# Remove premium
|
||||
user = await user_service.update_premium(db, user, None)
|
||||
user = await service.update_premium(user_id, None)
|
||||
"""
|
||||
if premium_until is not None:
|
||||
user.is_premium = True
|
||||
user.premium_until = premium_until
|
||||
else:
|
||||
user.is_premium = False
|
||||
user.premium_until = None
|
||||
is_premium = premium_until is not None
|
||||
return await self._user_repo.update_premium(
|
||||
user_id=user_id,
|
||||
is_premium=is_premium,
|
||||
premium_until=premium_until,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def delete(self, db: AsyncSession, user: User) -> None:
|
||||
async def delete(self, user_id: UUID) -> bool:
|
||||
"""Delete a user account.
|
||||
|
||||
This will cascade delete all related data (decks, collection, etc.)
|
||||
based on the model relationships.
|
||||
based on the database constraints.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to delete.
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
|
||||
Example:
|
||||
await user_service.delete(db, user)
|
||||
success = await service.delete(user_id)
|
||||
"""
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
return await self._user_repo.delete(user_id)
|
||||
|
||||
# =========================================================================
|
||||
# Linked Account Operations
|
||||
# =========================================================================
|
||||
|
||||
async def get_linked_accounts(self, user_id: UUID) -> list[LinkedAccountEntry]:
|
||||
"""Get all linked OAuth accounts for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
List of LinkedAccountEntry.
|
||||
"""
|
||||
return await self._linked_repo.get_by_user(user_id)
|
||||
|
||||
async def get_linked_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
provider: str,
|
||||
oauth_id: str,
|
||||
) -> OAuthLinkedAccount | None:
|
||||
) -> LinkedAccountEntry | None:
|
||||
"""Get a linked account by provider and OAuth ID.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
provider: OAuth provider name (google, discord).
|
||||
oauth_id: Unique ID from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
OAuthLinkedAccount if found, None otherwise.
|
||||
LinkedAccountEntry if found, None otherwise.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(OAuthLinkedAccount).where(
|
||||
OAuthLinkedAccount.provider == provider,
|
||||
OAuthLinkedAccount.oauth_id == oauth_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return await self._linked_repo.get_by_provider(provider, oauth_id)
|
||||
|
||||
async def link_oauth_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
user_id: UUID,
|
||||
user_oauth_provider: str,
|
||||
oauth_info: OAuthUserInfo,
|
||||
) -> OAuthLinkedAccount:
|
||||
) -> LinkedAccountEntry:
|
||||
"""Link an additional OAuth provider to a user account.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to link the account to.
|
||||
user_id: The user's UUID.
|
||||
user_oauth_provider: The user's primary OAuth provider.
|
||||
oauth_info: OAuth information from the provider.
|
||||
|
||||
Returns:
|
||||
The created OAuthLinkedAccount.
|
||||
The created LinkedAccountEntry.
|
||||
|
||||
Raises:
|
||||
AccountLinkingError: If provider is already linked to this or another user.
|
||||
|
||||
Example:
|
||||
linked = await user_service.link_oauth_account(db, user, discord_info)
|
||||
linked = await service.link_oauth_account(user.id, user.oauth_provider, discord_info)
|
||||
"""
|
||||
# Check if this provider+oauth_id is already linked to any user
|
||||
existing = await self.get_linked_account(db, oauth_info.provider, oauth_info.oauth_id)
|
||||
existing = await self._linked_repo.get_by_provider(oauth_info.provider, oauth_info.oauth_id)
|
||||
if existing:
|
||||
if str(existing.user_id) == str(user.id):
|
||||
if existing.user_id == user_id:
|
||||
raise AccountLinkingError(
|
||||
f"{oauth_info.provider.title()} account is already linked to your account"
|
||||
)
|
||||
@ -369,36 +379,33 @@ class UserService:
|
||||
)
|
||||
|
||||
# Check if this is the user's primary OAuth provider
|
||||
if user.oauth_provider == oauth_info.provider:
|
||||
if user_oauth_provider == oauth_info.provider:
|
||||
raise AccountLinkingError(
|
||||
f"{oauth_info.provider.title()} is your primary login provider"
|
||||
)
|
||||
|
||||
# Check if user already has this provider linked
|
||||
for linked in user.linked_accounts:
|
||||
linked_accounts = await self._linked_repo.get_by_user(user_id)
|
||||
for linked in linked_accounts:
|
||||
if linked.provider == oauth_info.provider:
|
||||
raise AccountLinkingError(
|
||||
f"You already have a {oauth_info.provider.title()} account linked"
|
||||
)
|
||||
|
||||
# Create the linked account
|
||||
linked_account = OAuthLinkedAccount(
|
||||
user_id=str(user.id),
|
||||
return await self._linked_repo.create(
|
||||
user_id=user_id,
|
||||
provider=oauth_info.provider,
|
||||
oauth_id=oauth_info.oauth_id,
|
||||
email=oauth_info.email,
|
||||
display_name=oauth_info.name,
|
||||
avatar_url=oauth_info.avatar_url,
|
||||
)
|
||||
db.add(linked_account)
|
||||
await db.commit()
|
||||
await db.refresh(linked_account)
|
||||
return linked_account
|
||||
|
||||
async def unlink_oauth_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
user_id: UUID,
|
||||
user_oauth_provider: str,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""Unlink an OAuth provider from a user account.
|
||||
@ -406,8 +413,8 @@ class UserService:
|
||||
Cannot unlink the primary OAuth provider.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
user: The user to unlink from.
|
||||
user_id: The user's UUID.
|
||||
user_oauth_provider: The user's primary OAuth provider.
|
||||
provider: OAuth provider name to unlink.
|
||||
|
||||
Returns:
|
||||
@ -417,23 +424,12 @@ class UserService:
|
||||
AccountLinkingError: If trying to unlink the primary provider.
|
||||
|
||||
Example:
|
||||
success = await user_service.unlink_oauth_account(db, user, "discord")
|
||||
success = await service.unlink_oauth_account(user.id, user.oauth_provider, "discord")
|
||||
"""
|
||||
# Cannot unlink primary provider
|
||||
if user.oauth_provider == provider:
|
||||
if user_oauth_provider == provider:
|
||||
raise AccountLinkingError(
|
||||
f"Cannot unlink {provider.title()} - it is your primary login provider"
|
||||
)
|
||||
|
||||
# Find and delete the linked account
|
||||
for linked in user.linked_accounts:
|
||||
if linked.provider == provider:
|
||||
await db.delete(linked)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Global service instance
|
||||
user_service = UserService()
|
||||
return await self._linked_repo.delete(user_id, provider)
|
||||
|
||||
@ -2,14 +2,19 @@
|
||||
|
||||
Tests the authentication endpoints including OAuth redirects,
|
||||
token refresh, and logout.
|
||||
|
||||
Uses FastAPI's dependency override pattern for proper dependency injection testing.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.deps import get_user_service
|
||||
|
||||
|
||||
class TestGoogleAuthRedirect:
|
||||
"""Tests for GET /api/auth/google endpoint."""
|
||||
@ -50,11 +55,32 @@ class TestDiscordAuthRedirect:
|
||||
assert "not configured" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_service_instance():
|
||||
"""Create a mock UserService for dependency injection.
|
||||
|
||||
Returns a MagicMock with async methods configured.
|
||||
"""
|
||||
mock = MagicMock()
|
||||
mock.get_by_id = AsyncMock()
|
||||
mock.get_by_email = AsyncMock()
|
||||
mock.get_by_oauth = AsyncMock()
|
||||
mock.get_or_create_from_oauth = AsyncMock()
|
||||
mock.update_last_login = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
class TestRefreshTokens:
|
||||
"""Tests for POST /api/auth/refresh endpoint."""
|
||||
|
||||
def test_returns_new_access_token(
|
||||
self, client: TestClient, test_user, refresh_token_data, mock_get_redis
|
||||
self,
|
||||
app,
|
||||
client: TestClient,
|
||||
test_user,
|
||||
refresh_token_data,
|
||||
mock_get_redis,
|
||||
mock_user_service_instance,
|
||||
):
|
||||
"""Test that refresh endpoint returns new access token for valid refresh token.
|
||||
|
||||
@ -70,25 +96,31 @@ class TestRefreshTokens:
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(setup_token())
|
||||
|
||||
# Mock user service to return our test user
|
||||
with patch("app.api.auth.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Configure mock to return test user
|
||||
# Convert to UserEntry-like object
|
||||
mock_user_entry = MagicMock()
|
||||
mock_user_entry.id = test_user.id
|
||||
mock_user_entry.email = test_user.email
|
||||
mock_user_service_instance.get_by_id.return_value = mock_user_entry
|
||||
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
# Override the dependency on the test app
|
||||
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["refresh_token"] == refresh_token_data["token"]
|
||||
assert data["token_type"] == "bearer"
|
||||
assert "expires_in" in data
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["refresh_token"] == refresh_token_data["token"]
|
||||
assert data["token_type"] == "bearer"
|
||||
assert "expires_in" in data
|
||||
finally:
|
||||
# Clean up override
|
||||
app.dependency_overrides.pop(get_user_service, None)
|
||||
|
||||
def test_returns_401_for_invalid_token(self, client: TestClient):
|
||||
"""Test that refresh endpoint returns 401 for invalid refresh token."""
|
||||
@ -107,21 +139,23 @@ class TestRefreshTokens:
|
||||
A refresh token not in Redis (revoked/expired) should be rejected.
|
||||
"""
|
||||
# Don't store the token in Redis - simulating revocation
|
||||
# The mock_get_redis is already patched via conftest's app fixture
|
||||
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "revoked" in response.json()["detail"]
|
||||
|
||||
def test_returns_401_for_deleted_user(
|
||||
self, client: TestClient, refresh_token_data, mock_get_redis
|
||||
self,
|
||||
app,
|
||||
client: TestClient,
|
||||
refresh_token_data,
|
||||
mock_get_redis,
|
||||
mock_user_service_instance,
|
||||
):
|
||||
"""Test that refresh endpoint returns 401 if user no longer exists."""
|
||||
# Store the token
|
||||
@ -134,21 +168,23 @@ class TestRefreshTokens:
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(setup_token())
|
||||
|
||||
# Mock user service to return None (user deleted)
|
||||
with patch("app.api.auth.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=None)
|
||||
# Configure mock to return None (user deleted)
|
||||
mock_user_service_instance.get_by_id.return_value = None
|
||||
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
# Override the dependency on the test app
|
||||
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "User not found" in response.json()["detail"]
|
||||
try:
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "User not found" in response.json()["detail"]
|
||||
finally:
|
||||
# Clean up override
|
||||
app.dependency_overrides.pop(get_user_service, None)
|
||||
|
||||
|
||||
class TestLogout:
|
||||
@ -171,14 +207,10 @@ class TestLogout:
|
||||
key = asyncio.get_event_loop().run_until_complete(setup_and_check())
|
||||
|
||||
# Logout
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/logout",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/auth/logout",
|
||||
json={"refresh_token": refresh_token_data["token"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
@ -214,7 +246,9 @@ class TestLogoutAll:
|
||||
response = client.post("/api/auth/logout-all")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_revokes_all_tokens(self, client: TestClient, test_user, access_token, mock_get_redis):
|
||||
def test_revokes_all_tokens(
|
||||
self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
|
||||
):
|
||||
"""Test that logout-all revokes all refresh tokens for user.
|
||||
|
||||
Should delete all tokens matching the user's ID pattern.
|
||||
@ -233,18 +267,16 @@ class TestLogoutAll:
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
||||
|
||||
# Mock dependencies
|
||||
with patch("app.api.deps.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up mock db session to return test user when queried
|
||||
# The get_current_user dependency now does a direct DB query
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with (
|
||||
patch("app.api.auth.get_redis", mock_get_redis),
|
||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
||||
):
|
||||
response = client.post(
|
||||
"/api/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/auth/logout-all",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
|
||||
@ -1,32 +1,55 @@
|
||||
"""Tests for users API endpoints.
|
||||
|
||||
Tests the user profile management endpoints.
|
||||
|
||||
Uses FastAPI's dependency override pattern for proper dependency injection testing.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.deps import get_user_service
|
||||
from app.services.user_service import AccountLinkingError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_service_instance():
|
||||
"""Create a mock UserService for dependency injection.
|
||||
|
||||
Returns a MagicMock with async methods configured.
|
||||
"""
|
||||
mock = MagicMock()
|
||||
mock.get_by_id = AsyncMock()
|
||||
mock.get_by_email = AsyncMock()
|
||||
mock.get_by_oauth = AsyncMock()
|
||||
mock.update = AsyncMock()
|
||||
mock.unlink_oauth_account = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for GET /api/users/me endpoint."""
|
||||
|
||||
def test_returns_user_profile(self, client: TestClient, test_user, access_token):
|
||||
def test_returns_user_profile(
|
||||
self, app, client: TestClient, test_user, access_token, mock_db_session
|
||||
):
|
||||
"""Test that endpoint returns user profile for authenticated user.
|
||||
|
||||
Should return the user's profile information.
|
||||
"""
|
||||
with patch("app.api.deps.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up mock db session to return test user when queried
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
response = client.get(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response = client.get(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
@ -52,47 +75,93 @@ class TestGetCurrentUser:
|
||||
class TestUpdateCurrentUser:
|
||||
"""Tests for PATCH /api/users/me endpoint."""
|
||||
|
||||
def test_updates_display_name(self, client: TestClient, test_user, access_token):
|
||||
def test_updates_display_name(
|
||||
self,
|
||||
app,
|
||||
client: TestClient,
|
||||
test_user,
|
||||
access_token,
|
||||
mock_db_session,
|
||||
mock_user_service_instance,
|
||||
):
|
||||
"""Test that endpoint updates display_name when provided."""
|
||||
updated_user = test_user
|
||||
# Create an updated user mock
|
||||
updated_user = MagicMock()
|
||||
updated_user.id = test_user.id
|
||||
updated_user.email = test_user.email
|
||||
updated_user.display_name = "New Name"
|
||||
updated_user.avatar_url = test_user.avatar_url
|
||||
updated_user.is_premium = test_user.is_premium
|
||||
updated_user.premium_until = test_user.premium_until
|
||||
updated_user.created_at = test_user.created_at
|
||||
|
||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up db session to return test user for authentication
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch("app.api.users.user_service") as mock_user_service:
|
||||
mock_user_service.update = AsyncMock(return_value=updated_user)
|
||||
# Set up user service mock
|
||||
mock_user_service_instance.update.return_value = updated_user
|
||||
|
||||
response = client.patch(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"display_name": "New Name"},
|
||||
)
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["display_name"] == "New Name"
|
||||
try:
|
||||
response = client.patch(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"display_name": "New Name"},
|
||||
)
|
||||
|
||||
def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["display_name"] == "New Name"
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_user_service, None)
|
||||
|
||||
def test_updates_avatar_url(
|
||||
self,
|
||||
app,
|
||||
client: TestClient,
|
||||
test_user,
|
||||
access_token,
|
||||
mock_db_session,
|
||||
mock_user_service_instance,
|
||||
):
|
||||
"""Test that endpoint updates avatar_url when provided."""
|
||||
updated_user = test_user
|
||||
# Create an updated user mock
|
||||
updated_user = MagicMock()
|
||||
updated_user.id = test_user.id
|
||||
updated_user.email = test_user.email
|
||||
updated_user.display_name = test_user.display_name
|
||||
updated_user.avatar_url = "https://new-avatar.com/img.jpg"
|
||||
updated_user.is_premium = test_user.is_premium
|
||||
updated_user.premium_until = test_user.premium_until
|
||||
updated_user.created_at = test_user.created_at
|
||||
|
||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up db session to return test user for authentication
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch("app.api.users.user_service") as mock_user_service:
|
||||
mock_user_service.update = AsyncMock(return_value=updated_user)
|
||||
# Set up user service mock
|
||||
mock_user_service_instance.update.return_value = updated_user
|
||||
|
||||
response = client.patch(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"avatar_url": "https://new-avatar.com/img.jpg"},
|
||||
)
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
|
||||
try:
|
||||
response = client.patch(
|
||||
"/api/users/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"avatar_url": "https://new-avatar.com/img.jpg"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_user_service, None)
|
||||
|
||||
def test_requires_authentication(self, client: TestClient):
|
||||
"""Test that endpoint returns 401 without authentication."""
|
||||
@ -106,18 +175,22 @@ class TestUpdateCurrentUser:
|
||||
class TestGetLinkedAccounts:
|
||||
"""Tests for GET /api/users/me/linked-accounts endpoint."""
|
||||
|
||||
def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
|
||||
def test_returns_linked_accounts(
|
||||
self, app, client: TestClient, test_user, access_token, mock_db_session
|
||||
):
|
||||
"""Test that endpoint returns list of linked OAuth accounts.
|
||||
|
||||
Should include the primary provider and any linked accounts.
|
||||
"""
|
||||
with patch("app.api.deps.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up db session to return test user for authentication
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
response = client.get(
|
||||
"/api/users/me/linked-accounts",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response = client.get(
|
||||
"/api/users/me/linked-accounts",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
@ -135,7 +208,7 @@ class TestGetActiveSessions:
|
||||
"""Tests for GET /api/users/me/sessions endpoint."""
|
||||
|
||||
def test_returns_session_count(
|
||||
self, client: TestClient, test_user, access_token, mock_get_redis
|
||||
self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
|
||||
):
|
||||
"""Test that endpoint returns count of active sessions.
|
||||
|
||||
@ -154,14 +227,16 @@ class TestGetActiveSessions:
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
||||
|
||||
with patch("app.api.deps.user_service") as mock_user_service:
|
||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up db session to return test user for authentication
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch("app.services.token_store.get_redis", mock_get_redis):
|
||||
response = client.get(
|
||||
"/api/users/me/sessions",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
with patch("app.services.token_store.get_redis", mock_get_redis):
|
||||
response = client.get(
|
||||
"/api/users/me/sessions",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
@ -177,78 +252,128 @@ class TestGetActiveSessions:
|
||||
class TestUnlinkOAuthAccount:
|
||||
"""Tests for DELETE /api/users/me/link/{provider} endpoint."""
|
||||
|
||||
def test_unlinks_provider_successfully(self, client: TestClient, test_user, access_token):
|
||||
def test_unlinks_provider_successfully(
|
||||
self,
|
||||
app,
|
||||
client: TestClient,
|
||||
test_user,
|
||||
access_token,
|
||||
mock_db_session,
|
||||
mock_user_service_instance,
|
||||
):
|
||||
"""Test that endpoint successfully unlinks a provider.
|
||||
|
||||
Should return 204 when provider is unlinked.
|
||||
"""
|
||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up db session to return test user for authentication
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch("app.api.users.user_service") as mock_user_service:
|
||||
mock_user_service.unlink_oauth_account = AsyncMock(return_value=True)
|
||||
# Set up user service mock
|
||||
mock_user_service_instance.unlink_oauth_account.return_value = True
|
||||
|
||||
response = client.delete(
|
||||
"/api/users/me/link/discord",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
try:
|
||||
response = client.delete(
|
||||
"/api/users/me/link/discord",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
def test_returns_404_if_not_linked(self, client: TestClient, test_user, access_token):
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_user_service, None)
|
||||
|
||||
def test_returns_404_if_not_linked(
|
||||
self,
|
||||
app,
|
||||
client: TestClient,
|
||||
test_user,
|
||||
access_token,
|
||||
mock_db_session,
|
||||
mock_user_service_instance,
|
||||
):
|
||||
"""Test that endpoint returns 404 if provider isn't linked.
|
||||
|
||||
Should return 404 when trying to unlink a provider that isn't linked.
|
||||
"""
|
||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up db session to return test user for authentication
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch("app.api.users.user_service") as mock_user_service:
|
||||
mock_user_service.unlink_oauth_account = AsyncMock(return_value=False)
|
||||
# Set up user service mock
|
||||
mock_user_service_instance.unlink_oauth_account.return_value = False
|
||||
|
||||
response = client.delete(
|
||||
"/api/users/me/link/discord",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not linked" in response.json()["detail"].lower()
|
||||
try:
|
||||
response = client.delete(
|
||||
"/api/users/me/link/discord",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
def test_returns_400_for_primary_provider(self, client: TestClient, test_user, access_token):
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not linked" in response.json()["detail"].lower()
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_user_service, None)
|
||||
|
||||
def test_returns_400_for_primary_provider(
|
||||
self,
|
||||
app,
|
||||
client: TestClient,
|
||||
test_user,
|
||||
access_token,
|
||||
mock_db_session,
|
||||
mock_user_service_instance,
|
||||
):
|
||||
"""Test that endpoint returns 400 when trying to unlink primary provider.
|
||||
|
||||
Cannot unlink the provider used to create the account.
|
||||
"""
|
||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up db session to return test user for authentication
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch("app.api.users.user_service") as mock_user_service:
|
||||
mock_user_service.unlink_oauth_account = AsyncMock(
|
||||
side_effect=AccountLinkingError(
|
||||
"Cannot unlink Google - it is your primary login provider"
|
||||
)
|
||||
)
|
||||
# Set up user service mock to raise AccountLinkingError
|
||||
mock_user_service_instance.unlink_oauth_account.side_effect = AccountLinkingError(
|
||||
"Cannot unlink Google - it is your primary login provider"
|
||||
)
|
||||
|
||||
response = client.delete(
|
||||
"/api/users/me/link/google",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
# Override the dependency
|
||||
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "primary" in response.json()["detail"].lower()
|
||||
try:
|
||||
response = client.delete(
|
||||
"/api/users/me/link/google",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
def test_returns_400_for_unknown_provider(self, client: TestClient, test_user, access_token):
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "primary" in response.json()["detail"].lower()
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_user_service, None)
|
||||
|
||||
def test_returns_400_for_unknown_provider(
|
||||
self, app, client: TestClient, test_user, access_token, mock_db_session
|
||||
):
|
||||
"""Test that endpoint returns 400 for unknown provider.
|
||||
|
||||
Only 'google' and 'discord' are valid providers.
|
||||
"""
|
||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
||||
# Set up db session to return test user for authentication
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = test_user
|
||||
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
response = client.delete(
|
||||
"/api/users/me/link/twitter",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response = client.delete(
|
||||
"/api/users/me/link/twitter",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "unknown provider" in response.json()["detail"].lower()
|
||||
|
||||
@ -2,25 +2,43 @@
|
||||
|
||||
Tests the user service CRUD operations and OAuth-based user creation.
|
||||
Uses real Postgres via the db_session fixture from conftest.
|
||||
|
||||
The UserService now uses dependency injection with repositories injected
|
||||
via constructor, so we create a fresh service instance per test.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db.models import User
|
||||
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
||||
from app.services.user_service import AccountLinkingError, user_service
|
||||
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
|
||||
from app.repositories.postgres.user import PostgresUserRepository
|
||||
from app.schemas.user import OAuthUserInfo, UserUpdate
|
||||
from app.services.user_service import AccountLinkingError, UserService
|
||||
|
||||
# Import db_session fixture from db conftest
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_service(db_session):
|
||||
"""Create UserService with real PostgreSQL repositories.
|
||||
|
||||
This fixture provides a properly constructed UserService for each test,
|
||||
following the dependency injection pattern used in production.
|
||||
"""
|
||||
user_repo = PostgresUserRepository(db_session)
|
||||
linked_repo = PostgresLinkedAccountRepository(db_session)
|
||||
return UserService(user_repo, linked_repo)
|
||||
|
||||
|
||||
class TestGetById:
|
||||
"""Tests for get_by_id method."""
|
||||
|
||||
async def test_returns_user_when_found(self, db_session):
|
||||
async def test_returns_user_when_found(self, db_session, user_service):
|
||||
"""Test that get_by_id returns user when it exists.
|
||||
|
||||
Creates a user and verifies it can be retrieved by ID.
|
||||
@ -36,26 +54,24 @@ class TestGetById:
|
||||
await db_session.commit()
|
||||
|
||||
# Retrieve by ID
|
||||
from uuid import UUID
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
result = await user_service.get_by_id(db_session, user_id)
|
||||
result = await user_service.get_by_id(user_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.email == "test@example.com"
|
||||
|
||||
async def test_returns_none_when_not_found(self, db_session):
|
||||
async def test_returns_none_when_not_found(self, user_service):
|
||||
"""Test that get_by_id returns None for nonexistent users."""
|
||||
from uuid import uuid4
|
||||
|
||||
result = await user_service.get_by_id(db_session, uuid4())
|
||||
result = await user_service.get_by_id(uuid4())
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetByEmail:
|
||||
"""Tests for get_by_email method."""
|
||||
|
||||
async def test_returns_user_when_found(self, db_session):
|
||||
async def test_returns_user_when_found(self, db_session, user_service):
|
||||
"""Test that get_by_email returns user when it exists."""
|
||||
user = User(
|
||||
email="findme@example.com",
|
||||
@ -66,21 +82,21 @@ class TestGetByEmail:
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
result = await user_service.get_by_email(db_session, "findme@example.com")
|
||||
result = await user_service.get_by_email("findme@example.com")
|
||||
|
||||
assert result is not None
|
||||
assert result.display_name == "Find Me"
|
||||
|
||||
async def test_returns_none_when_not_found(self, db_session):
|
||||
async def test_returns_none_when_not_found(self, user_service):
|
||||
"""Test that get_by_email returns None for nonexistent emails."""
|
||||
result = await user_service.get_by_email(db_session, "nobody@example.com")
|
||||
result = await user_service.get_by_email("nobody@example.com")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetByOAuth:
|
||||
"""Tests for get_by_oauth method."""
|
||||
|
||||
async def test_returns_user_when_found(self, db_session):
|
||||
async def test_returns_user_when_found(self, db_session, user_service):
|
||||
"""Test that get_by_oauth returns user for matching provider+id."""
|
||||
user = User(
|
||||
email="oauth@example.com",
|
||||
@ -91,12 +107,12 @@ class TestGetByOAuth:
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
result = await user_service.get_by_oauth(db_session, "google", "google-unique-id")
|
||||
result = await user_service.get_by_oauth("google", "google-unique-id")
|
||||
|
||||
assert result is not None
|
||||
assert result.email == "oauth@example.com"
|
||||
|
||||
async def test_returns_none_for_wrong_provider(self, db_session):
|
||||
async def test_returns_none_for_wrong_provider(self, db_session, user_service):
|
||||
"""Test that get_by_oauth returns None if provider doesn't match."""
|
||||
user = User(
|
||||
email="oauth2@example.com",
|
||||
@ -108,30 +124,28 @@ class TestGetByOAuth:
|
||||
await db_session.commit()
|
||||
|
||||
# Same ID, different provider
|
||||
result = await user_service.get_by_oauth(db_session, "discord", "google-id-2")
|
||||
result = await user_service.get_by_oauth("discord", "google-id-2")
|
||||
assert result is None
|
||||
|
||||
async def test_returns_none_when_not_found(self, db_session):
|
||||
async def test_returns_none_when_not_found(self, user_service):
|
||||
"""Test that get_by_oauth returns None for nonexistent OAuth."""
|
||||
result = await user_service.get_by_oauth(db_session, "google", "nonexistent")
|
||||
result = await user_service.get_by_oauth("google", "nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCreate:
|
||||
"""Tests for create method."""
|
||||
|
||||
async def test_creates_user_with_all_fields(self, db_session):
|
||||
async def test_creates_user_with_all_fields(self, user_service):
|
||||
"""Test that create properly persists all user fields."""
|
||||
user_data = UserCreate(
|
||||
result = await user_service.create(
|
||||
email="new@example.com",
|
||||
display_name="New User",
|
||||
avatar_url="https://example.com/avatar.jpg",
|
||||
oauth_provider="discord",
|
||||
oauth_id="discord-new-id",
|
||||
avatar_url="https://example.com/avatar.jpg",
|
||||
)
|
||||
|
||||
result = await user_service.create(db_session, user_data)
|
||||
|
||||
assert result.id is not None
|
||||
assert result.email == "new@example.com"
|
||||
assert result.display_name == "New User"
|
||||
@ -141,24 +155,22 @@ class TestCreate:
|
||||
assert result.is_premium is False
|
||||
assert result.premium_until is None
|
||||
|
||||
async def test_creates_user_without_avatar(self, db_session):
|
||||
async def test_creates_user_without_avatar(self, user_service):
|
||||
"""Test that create works without optional avatar_url."""
|
||||
user_data = UserCreate(
|
||||
result = await user_service.create(
|
||||
email="noavatar@example.com",
|
||||
display_name="No Avatar",
|
||||
oauth_provider="google",
|
||||
oauth_id="google-no-avatar",
|
||||
)
|
||||
|
||||
result = await user_service.create(db_session, user_data)
|
||||
|
||||
assert result.avatar_url is None
|
||||
|
||||
|
||||
class TestCreateFromOAuth:
|
||||
"""Tests for create_from_oauth method."""
|
||||
|
||||
async def test_creates_user_from_oauth_info(self, db_session):
|
||||
async def test_creates_user_from_oauth_info(self, user_service):
|
||||
"""Test that create_from_oauth converts OAuthUserInfo to User."""
|
||||
oauth_info = OAuthUserInfo(
|
||||
provider="google",
|
||||
@ -168,7 +180,7 @@ class TestCreateFromOAuth:
|
||||
avatar_url="https://google.com/avatar.jpg",
|
||||
)
|
||||
|
||||
result = await user_service.create_from_oauth(db_session, oauth_info)
|
||||
result = await user_service.create_from_oauth(oauth_info)
|
||||
|
||||
assert result.email == "oauthcreate@example.com"
|
||||
assert result.display_name == "OAuth Created User"
|
||||
@ -179,7 +191,7 @@ class TestCreateFromOAuth:
|
||||
class TestGetOrCreateFromOAuth:
|
||||
"""Tests for get_or_create_from_oauth method."""
|
||||
|
||||
async def test_returns_existing_user_by_oauth(self, db_session):
|
||||
async def test_returns_existing_user_by_oauth(self, db_session, user_service):
|
||||
"""Test that existing user is returned when OAuth matches.
|
||||
|
||||
Verifies the method returns (user, False) for existing users.
|
||||
@ -202,12 +214,12 @@ class TestGetOrCreateFromOAuth:
|
||||
name="Existing",
|
||||
)
|
||||
|
||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||
result, created = await user_service.get_or_create_from_oauth(oauth_info)
|
||||
|
||||
assert created is False
|
||||
assert result.id == existing.id
|
||||
assert str(result.id) == str(existing.id)
|
||||
|
||||
async def test_links_existing_user_by_email(self, db_session):
|
||||
async def test_links_existing_user_by_email(self, db_session, user_service):
|
||||
"""Test that OAuth is linked when email matches existing user.
|
||||
|
||||
If a user exists with the same email but different OAuth,
|
||||
@ -232,15 +244,15 @@ class TestGetOrCreateFromOAuth:
|
||||
avatar_url="https://discord.com/avatar.jpg",
|
||||
)
|
||||
|
||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||
result, created = await user_service.get_or_create_from_oauth(oauth_info)
|
||||
|
||||
assert created is False
|
||||
assert result.id == existing.id
|
||||
assert str(result.id) == str(existing.id)
|
||||
# OAuth should be updated to Discord
|
||||
assert result.oauth_provider == "discord"
|
||||
assert result.oauth_id == "discord-link-id"
|
||||
|
||||
async def test_creates_new_user_when_not_found(self, db_session):
|
||||
async def test_creates_new_user_when_not_found(self, user_service):
|
||||
"""Test that new user is created when no match exists.
|
||||
|
||||
Verifies the method returns (user, True) for new users.
|
||||
@ -252,7 +264,7 @@ class TestGetOrCreateFromOAuth:
|
||||
name="Brand New",
|
||||
)
|
||||
|
||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
||||
result, created = await user_service.get_or_create_from_oauth(oauth_info)
|
||||
|
||||
assert created is True
|
||||
assert result.email == "brandnew@example.com"
|
||||
@ -261,7 +273,7 @@ class TestGetOrCreateFromOAuth:
|
||||
class TestUpdate:
|
||||
"""Tests for update method."""
|
||||
|
||||
async def test_updates_display_name(self, db_session):
|
||||
async def test_updates_display_name(self, db_session, user_service):
|
||||
"""Test that update changes display_name when provided."""
|
||||
user = User(
|
||||
email="update@example.com",
|
||||
@ -272,12 +284,13 @@ class TestUpdate:
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
update_data = UserUpdate(display_name="New Name")
|
||||
result = await user_service.update(db_session, user, update_data)
|
||||
result = await user_service.update(user_id, update_data)
|
||||
|
||||
assert result.display_name == "New Name"
|
||||
|
||||
async def test_updates_avatar_url(self, db_session):
|
||||
async def test_updates_avatar_url(self, db_session, user_service):
|
||||
"""Test that update changes avatar_url when provided."""
|
||||
user = User(
|
||||
email="avatar@example.com",
|
||||
@ -288,12 +301,13 @@ class TestUpdate:
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
|
||||
result = await user_service.update(db_session, user, update_data)
|
||||
result = await user_service.update(user_id, update_data)
|
||||
|
||||
assert result.avatar_url == "https://new-avatar.com/img.jpg"
|
||||
|
||||
async def test_ignores_none_values(self, db_session):
|
||||
async def test_ignores_none_values(self, db_session, user_service):
|
||||
"""Test that update doesn't change fields set to None.
|
||||
|
||||
Only explicitly provided fields should be updated.
|
||||
@ -309,8 +323,9 @@ class TestUpdate:
|
||||
await db_session.commit()
|
||||
|
||||
# Update only display_name, leave avatar alone
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
update_data = UserUpdate(display_name="Changed")
|
||||
result = await user_service.update(db_session, user, update_data)
|
||||
result = await user_service.update(user_id, update_data)
|
||||
|
||||
assert result.display_name == "Changed"
|
||||
assert result.avatar_url == "https://keep.com/avatar.jpg"
|
||||
@ -319,7 +334,7 @@ class TestUpdate:
|
||||
class TestUpdateLastLogin:
|
||||
"""Tests for update_last_login method."""
|
||||
|
||||
async def test_updates_last_login_timestamp(self, db_session):
|
||||
async def test_updates_last_login_timestamp(self, db_session, user_service):
|
||||
"""Test that update_last_login sets current timestamp."""
|
||||
user = User(
|
||||
email="login@example.com",
|
||||
@ -332,8 +347,9 @@ class TestUpdateLastLogin:
|
||||
|
||||
assert user.last_login is None
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
before = datetime.now(UTC)
|
||||
result = await user_service.update_last_login(db_session, user)
|
||||
result = await user_service.update_last_login(user_id)
|
||||
after = datetime.now(UTC)
|
||||
|
||||
assert result.last_login is not None
|
||||
@ -344,7 +360,7 @@ class TestUpdateLastLogin:
|
||||
class TestUpdatePremium:
|
||||
"""Tests for update_premium method."""
|
||||
|
||||
async def test_grants_premium(self, db_session):
|
||||
async def test_grants_premium(self, db_session, user_service):
|
||||
"""Test that update_premium sets premium status and expiration."""
|
||||
user = User(
|
||||
email="premium@example.com",
|
||||
@ -357,13 +373,14 @@ class TestUpdatePremium:
|
||||
|
||||
assert user.is_premium is False
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
expires = datetime.now(UTC) + timedelta(days=30)
|
||||
result = await user_service.update_premium(db_session, user, expires)
|
||||
result = await user_service.update_premium(user_id, expires)
|
||||
|
||||
assert result.is_premium is True
|
||||
assert result.premium_until == expires
|
||||
|
||||
async def test_removes_premium(self, db_session):
|
||||
async def test_removes_premium(self, db_session, user_service):
|
||||
"""Test that update_premium with None removes premium status."""
|
||||
user = User(
|
||||
email="unpremium@example.com",
|
||||
@ -376,7 +393,8 @@ class TestUpdatePremium:
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
result = await user_service.update_premium(db_session, user, None)
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
result = await user_service.update_premium(user_id, None)
|
||||
|
||||
assert result.is_premium is False
|
||||
assert result.premium_until is None
|
||||
@ -385,7 +403,7 @@ class TestUpdatePremium:
|
||||
class TestDelete:
|
||||
"""Tests for delete method."""
|
||||
|
||||
async def test_deletes_user(self, db_session):
|
||||
async def test_deletes_user(self, db_session, user_service):
|
||||
"""Test that delete removes user from database."""
|
||||
user = User(
|
||||
email="delete@example.com",
|
||||
@ -396,22 +414,18 @@ class TestDelete:
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
user_id = user.id
|
||||
await user_service.delete(db_session, user)
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
await user_service.delete(user_id)
|
||||
|
||||
# Verify user is gone
|
||||
from uuid import UUID
|
||||
|
||||
result = await user_service.get_by_id(
|
||||
db_session, UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
)
|
||||
result = await user_service.get_by_id(user_id)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetLinkedAccount:
|
||||
"""Tests for get_linked_account method."""
|
||||
|
||||
async def test_returns_linked_account_when_found(self, db_session):
|
||||
async def test_returns_linked_account_when_found(self, db_session, user_service):
|
||||
"""Test that get_linked_account returns account when it exists.
|
||||
|
||||
Creates a user with a linked account and verifies it can be retrieved.
|
||||
@ -437,22 +451,22 @@ class TestGetLinkedAccount:
|
||||
await db_session.commit()
|
||||
|
||||
# Retrieve linked account
|
||||
result = await user_service.get_linked_account(db_session, "discord", "discord-linked-123")
|
||||
result = await user_service.get_linked_account("discord", "discord-linked-123")
|
||||
|
||||
assert result is not None
|
||||
assert result.provider == "discord"
|
||||
assert result.oauth_id == "discord-linked-123"
|
||||
|
||||
async def test_returns_none_when_not_found(self, db_session):
|
||||
async def test_returns_none_when_not_found(self, user_service):
|
||||
"""Test that get_linked_account returns None for nonexistent accounts."""
|
||||
result = await user_service.get_linked_account(db_session, "discord", "nonexistent-id")
|
||||
result = await user_service.get_linked_account("discord", "nonexistent-id")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLinkOAuthAccount:
|
||||
"""Tests for link_oauth_account method."""
|
||||
|
||||
async def test_links_new_provider(self, db_session):
|
||||
async def test_links_new_provider(self, db_session, user_service):
|
||||
"""Test that link_oauth_account successfully links a new provider.
|
||||
|
||||
Creates a Google user and links Discord to them.
|
||||
@ -468,6 +482,8 @@ class TestLinkOAuthAccount:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
|
||||
# Link Discord
|
||||
discord_info = OAuthUserInfo(
|
||||
provider="discord",
|
||||
@ -477,16 +493,16 @@ class TestLinkOAuthAccount:
|
||||
avatar_url="https://discord.com/avatar.png",
|
||||
)
|
||||
|
||||
result = await user_service.link_oauth_account(db_session, user, discord_info)
|
||||
result = await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
|
||||
|
||||
assert result is not None
|
||||
assert result.provider == "discord"
|
||||
assert result.oauth_id == "discord-456"
|
||||
assert result.email == "discord@example.com"
|
||||
assert result.display_name == "Discord Name"
|
||||
assert str(result.user_id) == str(user.id)
|
||||
assert str(result.user_id) == str(user_id)
|
||||
|
||||
async def test_raises_error_if_already_linked_to_same_user(self, db_session):
|
||||
async def test_raises_error_if_already_linked_to_same_user(self, db_session, user_service):
|
||||
"""Test that linking same provider twice raises error.
|
||||
|
||||
A user cannot have the same provider linked multiple times.
|
||||
@ -501,6 +517,8 @@ class TestLinkOAuthAccount:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
|
||||
# Link Discord first time
|
||||
discord_info = OAuthUserInfo(
|
||||
provider="discord",
|
||||
@ -508,16 +526,15 @@ class TestLinkOAuthAccount:
|
||||
email="first@discord.com",
|
||||
name="First",
|
||||
)
|
||||
await user_service.link_oauth_account(db_session, user, discord_info)
|
||||
await db_session.refresh(user)
|
||||
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
|
||||
|
||||
# Try to link same Discord account again
|
||||
with pytest.raises(AccountLinkingError) as exc_info:
|
||||
await user_service.link_oauth_account(db_session, user, discord_info)
|
||||
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
|
||||
|
||||
assert "already linked to your account" in str(exc_info.value)
|
||||
|
||||
async def test_raises_error_if_linked_to_another_user(self, db_session):
|
||||
async def test_raises_error_if_linked_to_another_user(self, db_session, user_service):
|
||||
"""Test that linking account already linked to another user raises error.
|
||||
|
||||
The same OAuth provider+ID cannot be linked to multiple users.
|
||||
@ -533,13 +550,15 @@ class TestLinkOAuthAccount:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user1)
|
||||
|
||||
user1_id = UUID(user1.id) if isinstance(user1.id, str) else user1.id
|
||||
|
||||
discord_info = OAuthUserInfo(
|
||||
provider="discord",
|
||||
oauth_id="shared-discord",
|
||||
email="shared@discord.com",
|
||||
name="Shared",
|
||||
)
|
||||
await user_service.link_oauth_account(db_session, user1, discord_info)
|
||||
await user_service.link_oauth_account(user1_id, user1.oauth_provider, discord_info)
|
||||
|
||||
# Create second user
|
||||
user2 = User(
|
||||
@ -552,13 +571,15 @@ class TestLinkOAuthAccount:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user2)
|
||||
|
||||
user2_id = UUID(user2.id) if isinstance(user2.id, str) else user2.id
|
||||
|
||||
# Try to link same Discord account to second user
|
||||
with pytest.raises(AccountLinkingError) as exc_info:
|
||||
await user_service.link_oauth_account(db_session, user2, discord_info)
|
||||
await user_service.link_oauth_account(user2_id, user2.oauth_provider, discord_info)
|
||||
|
||||
assert "already linked to another user" in str(exc_info.value)
|
||||
|
||||
async def test_raises_error_if_linking_primary_provider(self, db_session):
|
||||
async def test_raises_error_if_linking_primary_provider(self, db_session, user_service):
|
||||
"""Test that linking the same provider as primary raises error.
|
||||
|
||||
User cannot link Google if they already signed up with Google.
|
||||
@ -573,6 +594,8 @@ class TestLinkOAuthAccount:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
|
||||
# Try to link another Google account
|
||||
google_info = OAuthUserInfo(
|
||||
provider="google",
|
||||
@ -582,7 +605,7 @@ class TestLinkOAuthAccount:
|
||||
)
|
||||
|
||||
with pytest.raises(AccountLinkingError) as exc_info:
|
||||
await user_service.link_oauth_account(db_session, user, google_info)
|
||||
await user_service.link_oauth_account(user_id, user.oauth_provider, google_info)
|
||||
|
||||
assert "primary login provider" in str(exc_info.value)
|
||||
|
||||
@ -590,7 +613,7 @@ class TestLinkOAuthAccount:
|
||||
class TestUnlinkOAuthAccount:
|
||||
"""Tests for unlink_oauth_account method."""
|
||||
|
||||
async def test_unlinks_linked_account(self, db_session):
|
||||
async def test_unlinks_linked_account(self, db_session, user_service):
|
||||
"""Test that unlink_oauth_account removes a linked account.
|
||||
|
||||
Links Discord then unlinks it successfully.
|
||||
@ -605,6 +628,8 @@ class TestUnlinkOAuthAccount:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
|
||||
# Link Discord
|
||||
discord_info = OAuthUserInfo(
|
||||
provider="discord",
|
||||
@ -612,22 +637,18 @@ class TestUnlinkOAuthAccount:
|
||||
email="discord@unlink.com",
|
||||
name="Discord Unlink",
|
||||
)
|
||||
await user_service.link_oauth_account(db_session, user, discord_info)
|
||||
await db_session.refresh(user)
|
||||
|
||||
# Verify linked
|
||||
assert len(user.linked_accounts) == 1
|
||||
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
|
||||
|
||||
# Unlink
|
||||
result = await user_service.unlink_oauth_account(db_session, user, "discord")
|
||||
result = await user_service.unlink_oauth_account(user_id, user.oauth_provider, "discord")
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify unlinked
|
||||
linked = await user_service.get_linked_account(db_session, "discord", "discord-unlink")
|
||||
linked = await user_service.get_linked_account("discord", "discord-unlink")
|
||||
assert linked is None
|
||||
|
||||
async def test_returns_false_if_not_linked(self, db_session):
|
||||
async def test_returns_false_if_not_linked(self, db_session, user_service):
|
||||
"""Test that unlink returns False if provider isn't linked."""
|
||||
user = User(
|
||||
email="not-linked@example.com",
|
||||
@ -639,11 +660,12 @@ class TestUnlinkOAuthAccount:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
result = await user_service.unlink_oauth_account(db_session, user, "discord")
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
result = await user_service.unlink_oauth_account(user_id, user.oauth_provider, "discord")
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_raises_error_if_unlinking_primary(self, db_session):
|
||||
async def test_raises_error_if_unlinking_primary(self, db_session, user_service):
|
||||
"""Test that unlinking primary provider raises error.
|
||||
|
||||
User cannot unlink their primary OAuth provider.
|
||||
@ -658,7 +680,9 @@ class TestUnlinkOAuthAccount:
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
|
||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||
|
||||
with pytest.raises(AccountLinkingError) as exc_info:
|
||||
await user_service.unlink_oauth_account(db_session, user, "google")
|
||||
await user_service.unlink_oauth_account(user_id, user.oauth_provider, "google")
|
||||
|
||||
assert "primary login provider" in str(exc_info.value)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user