Implement UserRepository pattern with dependency injection
- Add UserRepository and LinkedAccountRepository protocols to protocols.py - Add UserEntry and LinkedAccountEntry DTOs for service layer decoupling - Implement PostgresUserRepository and PostgresLinkedAccountRepository - Refactor UserService to use constructor-injected repositories - Add get_user_service factory and UserServiceDep to API deps - Update auth.py and users.py endpoints to use UserServiceDep - Rewrite tests to use FastAPI dependency overrides (no monkey patching) This follows the established repository pattern used by DeckService and CollectionService, enabling future offline fork support. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
f6e8ab5f67
commit
7fcb86ff51
@ -31,7 +31,7 @@ import secrets
|
|||||||
from fastapi import APIRouter, HTTPException, Query, status
|
from fastapi import APIRouter, HTTPException, Query, status
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
from app.api.deps import CurrentUser, DbSession
|
from app.api.deps import CurrentUser, UserServiceDep
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.db.redis import get_redis
|
from app.db.redis import get_redis
|
||||||
from app.schemas.auth import RefreshTokenRequest, TokenResponse
|
from app.schemas.auth import RefreshTokenRequest, TokenResponse
|
||||||
@ -45,7 +45,7 @@ from app.services.jwt_service import (
|
|||||||
from app.services.oauth.discord import DiscordOAuthError, discord_oauth
|
from app.services.oauth.discord import DiscordOAuthError, discord_oauth
|
||||||
from app.services.oauth.google import GoogleOAuthError, google_oauth
|
from app.services.oauth.google import GoogleOAuthError, google_oauth
|
||||||
from app.services.token_store import token_store
|
from app.services.token_store import token_store
|
||||||
from app.services.user_service import AccountLinkingError, user_service
|
from app.services.user_service import AccountLinkingError
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ async def google_auth_redirect(
|
|||||||
|
|
||||||
@router.get("/google/callback")
|
@router.get("/google/callback")
|
||||||
async def google_auth_callback(
|
async def google_auth_callback(
|
||||||
db: DbSession,
|
user_service: UserServiceDep,
|
||||||
code: str = Query(..., description="Authorization code from Google"),
|
code: str = Query(..., description="Authorization code from Google"),
|
||||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||||
) -> TokenResponse:
|
) -> TokenResponse:
|
||||||
@ -182,10 +182,10 @@ async def google_auth_callback(
|
|||||||
user_info = await google_oauth.get_user_info(code, oauth_callback)
|
user_info = await google_oauth.get_user_info(code, oauth_callback)
|
||||||
|
|
||||||
# Get or create user
|
# Get or create user
|
||||||
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
user, _created = await user_service.get_or_create_from_oauth(user_info)
|
||||||
|
|
||||||
# Update last login
|
# Update last login
|
||||||
await user_service.update_last_login(db, user)
|
await user_service.update_last_login(user.id)
|
||||||
|
|
||||||
# Create tokens
|
# Create tokens
|
||||||
return await _create_tokens_for_user(user.id)
|
return await _create_tokens_for_user(user.id)
|
||||||
@ -243,7 +243,7 @@ async def discord_auth_redirect(
|
|||||||
|
|
||||||
@router.get("/discord/callback")
|
@router.get("/discord/callback")
|
||||||
async def discord_auth_callback(
|
async def discord_auth_callback(
|
||||||
db: DbSession,
|
user_service: UserServiceDep,
|
||||||
code: str = Query(..., description="Authorization code from Discord"),
|
code: str = Query(..., description="Authorization code from Discord"),
|
||||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||||
) -> RedirectResponse:
|
) -> RedirectResponse:
|
||||||
@ -276,10 +276,10 @@ async def discord_auth_callback(
|
|||||||
user_info = await discord_oauth.get_user_info(code, oauth_callback)
|
user_info = await discord_oauth.get_user_info(code, oauth_callback)
|
||||||
|
|
||||||
# Get or create user
|
# Get or create user
|
||||||
user, created = await user_service.get_or_create_from_oauth(db, user_info)
|
user, _created = await user_service.get_or_create_from_oauth(user_info)
|
||||||
|
|
||||||
# Update last login
|
# Update last login
|
||||||
await user_service.update_last_login(db, user)
|
await user_service.update_last_login(user.id)
|
||||||
|
|
||||||
# Create tokens
|
# Create tokens
|
||||||
tokens = await _create_tokens_for_user(user.id)
|
tokens = await _create_tokens_for_user(user.id)
|
||||||
@ -307,7 +307,7 @@ async def discord_auth_callback(
|
|||||||
|
|
||||||
@router.post("/refresh", response_model=TokenResponse)
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
async def refresh_tokens(
|
async def refresh_tokens(
|
||||||
db: DbSession,
|
user_service: UserServiceDep,
|
||||||
request: RefreshTokenRequest,
|
request: RefreshTokenRequest,
|
||||||
) -> TokenResponse:
|
) -> TokenResponse:
|
||||||
"""Refresh access token using refresh token.
|
"""Refresh access token using refresh token.
|
||||||
@ -344,7 +344,7 @@ async def refresh_tokens(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Verify user still exists
|
# Verify user still exists
|
||||||
user = await user_service.get_by_id(db, user_id)
|
user = await user_service.get_by_id(user_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -489,7 +489,7 @@ async def google_link_redirect(
|
|||||||
|
|
||||||
@router.get("/link/google/callback")
|
@router.get("/link/google/callback")
|
||||||
async def google_link_callback(
|
async def google_link_callback(
|
||||||
db: DbSession,
|
user_service: UserServiceDep,
|
||||||
code: str = Query(..., description="Authorization code from Google"),
|
code: str = Query(..., description="Authorization code from Google"),
|
||||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||||
) -> RedirectResponse:
|
) -> RedirectResponse:
|
||||||
@ -523,7 +523,7 @@ async def google_link_callback(
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
user_id = UUID(user_id_str)
|
user_id = UUID(user_id_str)
|
||||||
user = await user_service.get_by_id(db, user_id)
|
user = await user_service.get_by_id(user_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{redirect_uri}?error=user_not_found",
|
url=f"{redirect_uri}?error=user_not_found",
|
||||||
@ -531,7 +531,7 @@ async def google_link_callback(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Link the account
|
# Link the account
|
||||||
await user_service.link_oauth_account(db, user, oauth_info)
|
await user_service.link_oauth_account(user.id, user.oauth_provider, oauth_info)
|
||||||
|
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{redirect_uri}?linked=google",
|
url=f"{redirect_uri}?linked=google",
|
||||||
@ -600,7 +600,7 @@ async def discord_link_redirect(
|
|||||||
|
|
||||||
@router.get("/link/discord/callback")
|
@router.get("/link/discord/callback")
|
||||||
async def discord_link_callback(
|
async def discord_link_callback(
|
||||||
db: DbSession,
|
user_service: UserServiceDep,
|
||||||
code: str = Query(..., description="Authorization code from Discord"),
|
code: str = Query(..., description="Authorization code from Discord"),
|
||||||
state: str = Query(..., description="State parameter for CSRF validation"),
|
state: str = Query(..., description="State parameter for CSRF validation"),
|
||||||
) -> RedirectResponse:
|
) -> RedirectResponse:
|
||||||
@ -634,7 +634,7 @@ async def discord_link_callback(
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
user_id = UUID(user_id_str)
|
user_id = UUID(user_id_str)
|
||||||
user = await user_service.get_by_id(db, user_id)
|
user = await user_service.get_by_id(user_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{redirect_uri}?error=user_not_found",
|
url=f"{redirect_uri}?error=user_not_found",
|
||||||
@ -642,7 +642,7 @@ async def discord_link_callback(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Link the account
|
# Link the account
|
||||||
await user_service.link_oauth_account(db, user, oauth_info)
|
await user_service.link_oauth_account(user.id, user.oauth_provider, oauth_info)
|
||||||
|
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=f"{redirect_uri}?linked=discord",
|
url=f"{redirect_uri}?linked=discord",
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from typing import Annotated
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@ -35,13 +36,15 @@ from app.db import get_session
|
|||||||
from app.db.models import User
|
from app.db.models import User
|
||||||
from app.repositories.postgres.collection import PostgresCollectionRepository
|
from app.repositories.postgres.collection import PostgresCollectionRepository
|
||||||
from app.repositories.postgres.deck import PostgresDeckRepository
|
from app.repositories.postgres.deck import PostgresDeckRepository
|
||||||
|
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
|
||||||
|
from app.repositories.postgres.user import PostgresUserRepository
|
||||||
from app.services.card_service import CardService, get_card_service
|
from app.services.card_service import CardService, get_card_service
|
||||||
from app.services.collection_service import CollectionService
|
from app.services.collection_service import CollectionService
|
||||||
from app.services.deck_service import DeckService
|
from app.services.deck_service import DeckService
|
||||||
from app.services.game_service import GameService, game_service
|
from app.services.game_service import GameService, game_service
|
||||||
from app.services.game_state_manager import GameStateManager, game_state_manager
|
from app.services.game_state_manager import GameStateManager, game_state_manager
|
||||||
from app.services.jwt_service import verify_access_token
|
from app.services.jwt_service import verify_access_token
|
||||||
from app.services.user_service import user_service
|
from app.services.user_service import UserService
|
||||||
|
|
||||||
# OAuth2 scheme for extracting Bearer token from Authorization header
|
# OAuth2 scheme for extracting Bearer token from Authorization header
|
||||||
# tokenUrl is for OpenAPI docs - points to where tokens are obtained
|
# tokenUrl is for OpenAPI docs - points to where tokens are obtained
|
||||||
@ -148,8 +151,9 @@ async def get_current_user(
|
|||||||
if user_id is None:
|
if user_id is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
# Fetch user from database
|
# Fetch user from database (direct query for auth - not business logic)
|
||||||
user = await user_service.get_by_id(db, user_id)
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
if user is None:
|
if user is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
@ -187,7 +191,9 @@ async def get_optional_user(
|
|||||||
if user_id is None:
|
if user_id is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return await user_service.get_by_id(db, user_id)
|
# Direct query for auth - not business logic
|
||||||
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
async def get_current_premium_user(
|
async def get_current_premium_user(
|
||||||
@ -333,6 +339,31 @@ def get_game_state_manager_dep() -> GameStateManager:
|
|||||||
return game_state_manager
|
return game_state_manager
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_service(
|
||||||
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> UserService:
|
||||||
|
"""Get UserService with PostgreSQL repositories.
|
||||||
|
|
||||||
|
Creates a UserService instance with user and linked account repositories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session from request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserService configured for PostgreSQL.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@router.post("/auth/google/callback")
|
||||||
|
async def google_callback(
|
||||||
|
user_service: UserService = Depends(get_user_service),
|
||||||
|
):
|
||||||
|
user, created = await user_service.get_or_create_from_oauth(oauth_info)
|
||||||
|
"""
|
||||||
|
user_repo = PostgresUserRepository(db)
|
||||||
|
linked_repo = PostgresLinkedAccountRepository(db)
|
||||||
|
return UserService(user_repo, linked_repo)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Type Aliases for Cleaner Endpoint Signatures
|
# Type Aliases for Cleaner Endpoint Signatures
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -351,6 +382,7 @@ CollectionServiceDep = Annotated[CollectionService, Depends(get_collection_servi
|
|||||||
CardServiceDep = Annotated[CardService, Depends(get_card_service_dep)]
|
CardServiceDep = Annotated[CardService, Depends(get_card_service_dep)]
|
||||||
GameServiceDep = Annotated[GameService, Depends(get_game_service_dep)]
|
GameServiceDep = Annotated[GameService, Depends(get_game_service_dep)]
|
||||||
GameStateManagerDep = Annotated[GameStateManager, Depends(get_game_state_manager_dep)]
|
GameStateManagerDep = Annotated[GameStateManager, Depends(get_game_state_manager_dep)]
|
||||||
|
UserServiceDep = Annotated[UserService, Depends(get_user_service)]
|
||||||
|
|
||||||
# Admin authentication
|
# Admin authentication
|
||||||
AdminAuth = Annotated[None, Depends(verify_admin_token)]
|
AdminAuth = Annotated[None, Depends(verify_admin_token)]
|
||||||
|
|||||||
@ -26,12 +26,12 @@ Example:
|
|||||||
from fastapi import APIRouter, HTTPException, status
|
from fastapi import APIRouter, HTTPException, status
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.api.deps import CurrentUser, DbSession, DeckServiceDep
|
from app.api.deps import CurrentUser, DeckServiceDep, UserServiceDep
|
||||||
from app.schemas.deck import DeckResponse, StarterDeckSelectRequest, StarterStatusResponse
|
from app.schemas.deck import DeckResponse, StarterDeckSelectRequest, StarterStatusResponse
|
||||||
from app.schemas.user import UserResponse, UserUpdate
|
from app.schemas.user import UserResponse, UserUpdate
|
||||||
from app.services.deck_service import DeckLimitExceededError, StarterAlreadySelectedError
|
from app.services.deck_service import DeckLimitExceededError, StarterAlreadySelectedError
|
||||||
from app.services.token_store import token_store
|
from app.services.token_store import token_store
|
||||||
from app.services.user_service import AccountLinkingError, user_service
|
from app.services.user_service import AccountLinkingError
|
||||||
|
|
||||||
router = APIRouter(prefix="/users", tags=["users"])
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ async def get_current_user_profile(
|
|||||||
@router.patch("/me", response_model=UserResponse)
|
@router.patch("/me", response_model=UserResponse)
|
||||||
async def update_current_user_profile(
|
async def update_current_user_profile(
|
||||||
user: CurrentUser,
|
user: CurrentUser,
|
||||||
db: DbSession,
|
user_service: UserServiceDep,
|
||||||
update_data: UserUpdate,
|
update_data: UserUpdate,
|
||||||
) -> UserResponse:
|
) -> UserResponse:
|
||||||
"""Update the current user's profile.
|
"""Update the current user's profile.
|
||||||
@ -78,7 +78,7 @@ async def update_current_user_profile(
|
|||||||
Returns:
|
Returns:
|
||||||
Updated user profile.
|
Updated user profile.
|
||||||
"""
|
"""
|
||||||
updated_user = await user_service.update(db, user, update_data)
|
updated_user = await user_service.update(user.id, update_data)
|
||||||
return UserResponse.model_validate(updated_user)
|
return UserResponse.model_validate(updated_user)
|
||||||
|
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ async def get_active_sessions(
|
|||||||
@router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/me/link/{provider}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def unlink_oauth_account(
|
async def unlink_oauth_account(
|
||||||
user: CurrentUser,
|
user: CurrentUser,
|
||||||
db: DbSession,
|
user_service: UserServiceDep,
|
||||||
provider: str,
|
provider: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Unlink an OAuth provider from the current user's account.
|
"""Unlink an OAuth provider from the current user's account.
|
||||||
@ -162,7 +162,7 @@ async def unlink_oauth_account(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
unlinked = await user_service.unlink_oauth_account(db, user, provider)
|
unlinked = await user_service.unlink_oauth_account(user.id, user.oauth_provider, provider)
|
||||||
if not unlinked:
|
if not unlinked:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
|||||||
@ -10,11 +10,12 @@ The protocol pattern enables:
|
|||||||
- Offline fork support without rewriting service layer
|
- Offline fork support without rewriting service layer
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from app.repositories import CollectionRepository, DeckRepository
|
from app.repositories import CollectionRepository, DeckRepository, UserRepository
|
||||||
from app.repositories.postgres import PostgresCollectionRepository
|
from app.repositories.postgres import PostgresCollectionRepository, PostgresUserRepository
|
||||||
|
|
||||||
# In production (dependency injection)
|
# In production (dependency injection)
|
||||||
repo = PostgresCollectionRepository(db_session)
|
repo = PostgresCollectionRepository(db_session)
|
||||||
|
user_repo = PostgresUserRepository(db_session)
|
||||||
|
|
||||||
# In tests
|
# In tests
|
||||||
repo = MockCollectionRepository()
|
repo = MockCollectionRepository()
|
||||||
@ -23,9 +24,13 @@ Usage:
|
|||||||
from app.repositories.protocols import (
|
from app.repositories.protocols import (
|
||||||
CollectionRepository,
|
CollectionRepository,
|
||||||
DeckRepository,
|
DeckRepository,
|
||||||
|
LinkedAccountRepository,
|
||||||
|
UserRepository,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CollectionRepository",
|
"CollectionRepository",
|
||||||
"DeckRepository",
|
"DeckRepository",
|
||||||
|
"LinkedAccountRepository",
|
||||||
|
"UserRepository",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -8,11 +8,15 @@ Usage:
|
|||||||
from app.repositories.postgres import (
|
from app.repositories.postgres import (
|
||||||
PostgresCollectionRepository,
|
PostgresCollectionRepository,
|
||||||
PostgresDeckRepository,
|
PostgresDeckRepository,
|
||||||
|
PostgresLinkedAccountRepository,
|
||||||
|
PostgresUserRepository,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create repository with database session
|
# Create repository with database session
|
||||||
collection_repo = PostgresCollectionRepository(db_session)
|
collection_repo = PostgresCollectionRepository(db_session)
|
||||||
deck_repo = PostgresDeckRepository(db_session)
|
deck_repo = PostgresDeckRepository(db_session)
|
||||||
|
user_repo = PostgresUserRepository(db_session)
|
||||||
|
linked_repo = PostgresLinkedAccountRepository(db_session)
|
||||||
|
|
||||||
# Use via service layer
|
# Use via service layer
|
||||||
service = CollectionService(collection_repo)
|
service = CollectionService(collection_repo)
|
||||||
@ -20,8 +24,12 @@ Usage:
|
|||||||
|
|
||||||
from app.repositories.postgres.collection import PostgresCollectionRepository
|
from app.repositories.postgres.collection import PostgresCollectionRepository
|
||||||
from app.repositories.postgres.deck import PostgresDeckRepository
|
from app.repositories.postgres.deck import PostgresDeckRepository
|
||||||
|
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
|
||||||
|
from app.repositories.postgres.user import PostgresUserRepository
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PostgresCollectionRepository",
|
"PostgresCollectionRepository",
|
||||||
"PostgresDeckRepository",
|
"PostgresDeckRepository",
|
||||||
|
"PostgresLinkedAccountRepository",
|
||||||
|
"PostgresUserRepository",
|
||||||
]
|
]
|
||||||
|
|||||||
149
backend/app/repositories/postgres/linked_account.py
Normal file
149
backend/app/repositories/postgres/linked_account.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
"""PostgreSQL implementation of LinkedAccountRepository.
|
||||||
|
|
||||||
|
This module provides the PostgreSQL-specific implementation of the
|
||||||
|
LinkedAccountRepository protocol using SQLAlchemy async sessions.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
async with get_db_session() as db:
|
||||||
|
repo = PostgresLinkedAccountRepository(db)
|
||||||
|
accounts = await repo.get_by_user(user_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||||
|
from app.repositories.protocols import LinkedAccountEntry
|
||||||
|
|
||||||
|
|
||||||
|
def _to_dto(model: OAuthLinkedAccount) -> LinkedAccountEntry:
|
||||||
|
"""Convert ORM model to DTO."""
|
||||||
|
return LinkedAccountEntry(
|
||||||
|
id=model.id,
|
||||||
|
user_id=UUID(model.user_id) if isinstance(model.user_id, str) else model.user_id,
|
||||||
|
provider=model.provider,
|
||||||
|
oauth_id=model.oauth_id,
|
||||||
|
email=model.email,
|
||||||
|
display_name=model.display_name,
|
||||||
|
avatar_url=model.avatar_url,
|
||||||
|
linked_at=model.linked_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresLinkedAccountRepository:
|
||||||
|
"""PostgreSQL implementation of LinkedAccountRepository.
|
||||||
|
|
||||||
|
Uses SQLAlchemy async sessions for database access.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_db: The async database session.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
|
"""Initialize with database session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: SQLAlchemy async session.
|
||||||
|
"""
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
async def get_by_provider(
|
||||||
|
self,
|
||||||
|
provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
) -> LinkedAccountEntry | None:
|
||||||
|
"""Get a linked account by provider and OAuth ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: OAuth provider name (google, discord).
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LinkedAccountEntry if found, None otherwise.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(OAuthLinkedAccount).where(
|
||||||
|
OAuthLinkedAccount.provider == provider,
|
||||||
|
OAuthLinkedAccount.oauth_id == oauth_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
return _to_dto(model) if model else None
|
||||||
|
|
||||||
|
async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
|
||||||
|
"""Get all linked accounts for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LinkedAccountEntry, ordered by provider.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(OAuthLinkedAccount)
|
||||||
|
.where(OAuthLinkedAccount.user_id == str(user_id))
|
||||||
|
.order_by(OAuthLinkedAccount.provider)
|
||||||
|
)
|
||||||
|
return [_to_dto(model) for model in result.scalars().all()]
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
email: str | None = None,
|
||||||
|
display_name: str | None = None,
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
) -> LinkedAccountEntry:
|
||||||
|
"""Link an OAuth provider to a user account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
provider: OAuth provider name.
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
email: Email from the OAuth provider.
|
||||||
|
display_name: Display name from the OAuth provider.
|
||||||
|
avatar_url: Avatar URL from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created LinkedAccountEntry.
|
||||||
|
"""
|
||||||
|
linked_account = OAuthLinkedAccount(
|
||||||
|
user_id=str(user_id),
|
||||||
|
provider=provider,
|
||||||
|
oauth_id=oauth_id,
|
||||||
|
email=email,
|
||||||
|
display_name=display_name,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
)
|
||||||
|
self._db.add(linked_account)
|
||||||
|
await self._db.commit()
|
||||||
|
await self._db.refresh(linked_account)
|
||||||
|
return _to_dto(linked_account)
|
||||||
|
|
||||||
|
async def delete(self, user_id: UUID, provider: str) -> bool:
|
||||||
|
"""Unlink an OAuth provider from a user account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
provider: OAuth provider name to unlink.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(OAuthLinkedAccount).where(
|
||||||
|
OAuthLinkedAccount.user_id == str(user_id),
|
||||||
|
OAuthLinkedAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
linked_account = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if linked_account is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._db.delete(linked_account)
|
||||||
|
await self._db.commit()
|
||||||
|
return True
|
||||||
243
backend/app/repositories/postgres/user.py
Normal file
243
backend/app/repositories/postgres/user.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
"""PostgreSQL implementation of UserRepository.
|
||||||
|
|
||||||
|
This module provides the PostgreSQL-specific implementation of the
|
||||||
|
UserRepository protocol using SQLAlchemy async sessions.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
async with get_db_session() as db:
|
||||||
|
repo = PostgresUserRepository(db)
|
||||||
|
user = await repo.get_by_id(user_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db.models.user import User
|
||||||
|
from app.repositories.protocols import UNSET, UserEntry
|
||||||
|
|
||||||
|
|
||||||
|
def _to_dto(model: User) -> UserEntry:
|
||||||
|
"""Convert ORM model to DTO."""
|
||||||
|
return UserEntry(
|
||||||
|
id=model.id,
|
||||||
|
email=model.email,
|
||||||
|
display_name=model.display_name,
|
||||||
|
avatar_url=model.avatar_url,
|
||||||
|
oauth_provider=model.oauth_provider,
|
||||||
|
oauth_id=model.oauth_id,
|
||||||
|
is_premium=model.is_premium,
|
||||||
|
premium_until=model.premium_until,
|
||||||
|
last_login=model.last_login,
|
||||||
|
created_at=model.created_at,
|
||||||
|
updated_at=model.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresUserRepository:
|
||||||
|
"""PostgreSQL implementation of UserRepository.
|
||||||
|
|
||||||
|
Uses SQLAlchemy async sessions for database access.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_db: The async database session.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
|
"""Initialize with database session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: SQLAlchemy async session.
|
||||||
|
"""
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
|
||||||
|
"""Get a user by their ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserEntry if found, None otherwise.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
return _to_dto(model) if model else None
|
||||||
|
|
||||||
|
async def get_by_email(self, email: str) -> UserEntry | None:
|
||||||
|
"""Get a user by their email address.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: The user's email address.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserEntry if found, None otherwise.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(select(User).where(User.email == email))
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
return _to_dto(model) if model else None
|
||||||
|
|
||||||
|
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
|
||||||
|
"""Get a user by their OAuth provider and ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: OAuth provider name (google, discord).
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserEntry if found, None otherwise.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(User).where(
|
||||||
|
User.oauth_provider == provider,
|
||||||
|
User.oauth_id == oauth_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
return _to_dto(model) if model else None
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self,
|
||||||
|
email: str,
|
||||||
|
display_name: str,
|
||||||
|
oauth_provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
) -> UserEntry:
|
||||||
|
"""Create a new user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: User's email address.
|
||||||
|
display_name: Public display name.
|
||||||
|
oauth_provider: OAuth provider name.
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
avatar_url: Optional avatar URL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created UserEntry.
|
||||||
|
"""
|
||||||
|
user = User(
|
||||||
|
email=email,
|
||||||
|
display_name=display_name,
|
||||||
|
oauth_provider=oauth_provider,
|
||||||
|
oauth_id=oauth_id,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
)
|
||||||
|
self._db.add(user)
|
||||||
|
await self._db.commit()
|
||||||
|
await self._db.refresh(user)
|
||||||
|
return _to_dto(user)
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
display_name: str | None = None,
|
||||||
|
avatar_url: str | None = UNSET, # type: ignore[assignment]
|
||||||
|
oauth_provider: str | None = None,
|
||||||
|
oauth_id: str | None = None,
|
||||||
|
) -> UserEntry | None:
|
||||||
|
"""Update user profile fields.
|
||||||
|
|
||||||
|
Only provided (non-None/non-UNSET) fields are updated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
display_name: New display name (None keeps existing).
|
||||||
|
avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
|
||||||
|
oauth_provider: New OAuth provider (None keeps existing).
|
||||||
|
oauth_id: New OAuth ID (None keeps existing).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserEntry, or None if user not found.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if display_name is not None:
|
||||||
|
user.display_name = display_name
|
||||||
|
if avatar_url is not UNSET:
|
||||||
|
user.avatar_url = avatar_url
|
||||||
|
if oauth_provider is not None:
|
||||||
|
user.oauth_provider = oauth_provider
|
||||||
|
if oauth_id is not None:
|
||||||
|
user.oauth_id = oauth_id
|
||||||
|
|
||||||
|
await self._db.commit()
|
||||||
|
await self._db.refresh(user)
|
||||||
|
return _to_dto(user)
|
||||||
|
|
||||||
|
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
|
||||||
|
"""Update the user's last login timestamp to now.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserEntry, or None if user not found.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user.last_login = datetime.now(UTC)
|
||||||
|
await self._db.commit()
|
||||||
|
await self._db.refresh(user)
|
||||||
|
return _to_dto(user)
|
||||||
|
|
||||||
|
async def update_premium(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
is_premium: bool,
|
||||||
|
premium_until: datetime | None,
|
||||||
|
) -> UserEntry | None:
|
||||||
|
"""Update user's premium subscription status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
is_premium: Whether user has premium.
|
||||||
|
premium_until: When premium expires, or None if not premium.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserEntry, or None if user not found.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user.is_premium = is_premium
|
||||||
|
user.premium_until = premium_until
|
||||||
|
|
||||||
|
await self._db.commit()
|
||||||
|
await self._db.refresh(user)
|
||||||
|
return _to_dto(user)
|
||||||
|
|
||||||
|
async def delete(self, user_id: UUID) -> bool:
|
||||||
|
"""Delete a user account.
|
||||||
|
|
||||||
|
This will cascade delete all related data (decks, collection, etc.)
|
||||||
|
based on database constraints.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._db.delete(user)
|
||||||
|
await self._db.commit()
|
||||||
|
return True
|
||||||
@ -102,6 +102,46 @@ class DeckEntry:
|
|||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserEntry:
|
||||||
|
"""Storage-agnostic representation of a user account.
|
||||||
|
|
||||||
|
This DTO decouples the service layer from the ORM model,
|
||||||
|
enabling different storage backends (PostgreSQL, SQLite, JSON)
|
||||||
|
to be used interchangeably.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
email: str
|
||||||
|
display_name: str
|
||||||
|
avatar_url: str | None
|
||||||
|
oauth_provider: str
|
||||||
|
oauth_id: str
|
||||||
|
is_premium: bool
|
||||||
|
premium_until: datetime | None
|
||||||
|
last_login: datetime | None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LinkedAccountEntry:
|
||||||
|
"""Storage-agnostic representation of a linked OAuth account.
|
||||||
|
|
||||||
|
Users can link multiple OAuth providers (e.g., Google + Discord)
|
||||||
|
to a single account for flexible login options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
user_id: UUID
|
||||||
|
provider: str
|
||||||
|
oauth_id: str
|
||||||
|
email: str | None
|
||||||
|
display_name: str | None
|
||||||
|
avatar_url: str | None
|
||||||
|
linked_at: datetime
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Repository Protocols
|
# Repository Protocols
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -342,3 +382,211 @@ class DeckRepository(Protocol):
|
|||||||
Tuple of (has_starter, starter_type).
|
Tuple of (has_starter, starter_type).
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class UserRepository(Protocol):
|
||||||
|
"""Protocol for user account data access.
|
||||||
|
|
||||||
|
Implementations handle storage-specific details (PostgreSQL, SQLite, JSON).
|
||||||
|
Services use this protocol for business logic without knowing storage details.
|
||||||
|
|
||||||
|
Note: Business logic like get_or_create_from_oauth belongs in the service layer,
|
||||||
|
not in the repository.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
|
||||||
|
"""Get a user by their ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserEntry if found, None otherwise.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_by_email(self, email: str) -> UserEntry | None:
|
||||||
|
"""Get a user by their email address.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: The user's email address.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserEntry if found, None otherwise.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
|
||||||
|
"""Get a user by their OAuth provider and ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: OAuth provider name (google, discord).
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserEntry if found, None otherwise.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self,
|
||||||
|
email: str,
|
||||||
|
display_name: str,
|
||||||
|
oauth_provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
) -> UserEntry:
|
||||||
|
"""Create a new user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: User's email address.
|
||||||
|
display_name: Public display name.
|
||||||
|
oauth_provider: OAuth provider name.
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
avatar_url: Optional avatar URL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created UserEntry.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
display_name: str | None = None,
|
||||||
|
avatar_url: str | None = UNSET, # type: ignore[assignment]
|
||||||
|
oauth_provider: str | None = None,
|
||||||
|
oauth_id: str | None = None,
|
||||||
|
) -> UserEntry | None:
|
||||||
|
"""Update user profile fields.
|
||||||
|
|
||||||
|
Only provided (non-None/non-UNSET) fields are updated.
|
||||||
|
Use UNSET (default) to keep existing value for nullable fields,
|
||||||
|
or None to explicitly clear them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
display_name: New display name (None keeps existing).
|
||||||
|
avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
|
||||||
|
oauth_provider: New OAuth provider (None keeps existing).
|
||||||
|
oauth_id: New OAuth ID (None keeps existing).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserEntry, or None if user not found.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
|
||||||
|
"""Update the user's last login timestamp to now.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserEntry, or None if user not found.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def update_premium(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
is_premium: bool,
|
||||||
|
premium_until: datetime | None,
|
||||||
|
) -> UserEntry | None:
|
||||||
|
"""Update user's premium subscription status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
is_premium: Whether user has premium.
|
||||||
|
premium_until: When premium expires, or None if not premium.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserEntry, or None if user not found.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def delete(self, user_id: UUID) -> bool:
|
||||||
|
"""Delete a user account.
|
||||||
|
|
||||||
|
This will cascade delete all related data (decks, collection, etc.)
|
||||||
|
based on database constraints.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class LinkedAccountRepository(Protocol):
|
||||||
|
"""Protocol for OAuth linked accounts data access.
|
||||||
|
|
||||||
|
Users can link multiple OAuth providers to a single account.
|
||||||
|
The primary OAuth provider is stored on the User model itself;
|
||||||
|
additional linked providers are stored as LinkedAccount records.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_by_provider(
|
||||||
|
self,
|
||||||
|
provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
) -> LinkedAccountEntry | None:
|
||||||
|
"""Get a linked account by provider and OAuth ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: OAuth provider name (google, discord).
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LinkedAccountEntry if found, None otherwise.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
|
||||||
|
"""Get all linked accounts for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LinkedAccountEntry, ordered by provider.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self,
|
||||||
|
user_id: UUID,
|
||||||
|
provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
email: str | None = None,
|
||||||
|
display_name: str | None = None,
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
) -> LinkedAccountEntry:
|
||||||
|
"""Link an OAuth provider to a user account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
provider: OAuth provider name.
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
email: Email from the OAuth provider.
|
||||||
|
display_name: Display name from the OAuth provider.
|
||||||
|
avatar_url: Avatar URL from the OAuth provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created LinkedAccountEntry.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def delete(self, user_id: UUID, provider: str) -> bool:
|
||||||
|
"""Unlink an OAuth provider from a user account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
provider: OAuth provider name to unlink.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@ -1,32 +1,37 @@
|
|||||||
"""User service for Mantimon TCG.
|
"""User service for Mantimon TCG.
|
||||||
|
|
||||||
This module provides async CRUD operations for user accounts,
|
This module provides business logic for user account operations,
|
||||||
including OAuth-based user creation and premium status management.
|
including OAuth-based user creation, account linking, and premium
|
||||||
|
status management.
|
||||||
|
|
||||||
All database operations use async SQLAlchemy sessions.
|
The service layer contains business logic while repositories handle
|
||||||
|
pure data access. This separation enables testing and different
|
||||||
|
storage backends.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
from app.services.user_service import user_service
|
from app.services.user_service import UserService
|
||||||
|
from app.repositories.postgres import PostgresUserRepository, PostgresLinkedAccountRepository
|
||||||
|
|
||||||
# Get user by ID
|
# Create service with injected repositories
|
||||||
user = await user_service.get_by_id(db, user_id)
|
user_repo = PostgresUserRepository(db)
|
||||||
|
linked_repo = PostgresLinkedAccountRepository(db)
|
||||||
|
service = UserService(user_repo, linked_repo)
|
||||||
|
|
||||||
# Create from OAuth
|
# Use service
|
||||||
user = await user_service.create_from_oauth(db, oauth_info)
|
user = await service.get_by_id(user_id)
|
||||||
|
user, created = await service.get_or_create_from_oauth(oauth_info)
|
||||||
# Update premium status
|
|
||||||
user = await user_service.update_premium(db, user_id, premium_until)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import select
|
from app.repositories.protocols import (
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
LinkedAccountEntry,
|
||||||
|
LinkedAccountRepository,
|
||||||
from app.db.models.oauth_account import OAuthLinkedAccount
|
UserEntry,
|
||||||
from app.db.models.user import User
|
UserRepository,
|
||||||
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
)
|
||||||
|
from app.schemas.user import OAuthUserInfo, UserUpdate
|
||||||
|
|
||||||
|
|
||||||
class AccountLinkingError(Exception):
|
class AccountLinkingError(Exception):
|
||||||
@ -38,117 +43,119 @@ class AccountLinkingError(Exception):
|
|||||||
class UserService:
|
class UserService:
|
||||||
"""Service for user account operations.
|
"""Service for user account operations.
|
||||||
|
|
||||||
Provides async methods for user CRUD, OAuth-based creation,
|
Provides business logic for user CRUD, OAuth-based creation,
|
||||||
and premium subscription management.
|
account linking, and premium subscription management.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_user_repo: Repository for user data access.
|
||||||
|
_linked_repo: Repository for linked account data access.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_by_id(self, db: AsyncSession, user_id: UUID) -> User | None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_repository: UserRepository,
|
||||||
|
linked_account_repository: LinkedAccountRepository,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize with repository dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_repository: Repository for user data access.
|
||||||
|
linked_account_repository: Repository for linked account data access.
|
||||||
|
"""
|
||||||
|
self._user_repo = user_repository
|
||||||
|
self._linked_repo = linked_account_repository
|
||||||
|
|
||||||
|
async def get_by_id(self, user_id: UUID) -> UserEntry | None:
|
||||||
"""Get a user by their ID.
|
"""Get a user by their ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
|
||||||
user_id: The user's UUID.
|
user_id: The user's UUID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
User if found, None otherwise.
|
UserEntry if found, None otherwise.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
user = await user_service.get_by_id(db, user_id)
|
user = await service.get_by_id(user_id)
|
||||||
if user:
|
if user:
|
||||||
print(f"Found user: {user.display_name}")
|
print(f"Found user: {user.display_name}")
|
||||||
"""
|
"""
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
return await self._user_repo.get_by_id(user_id)
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
async def get_by_email(self, email: str) -> UserEntry | None:
|
||||||
"""Get a user by their email address.
|
"""Get a user by their email address.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
|
||||||
email: The user's email address.
|
email: The user's email address.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
User if found, None otherwise.
|
UserEntry if found, None otherwise.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
user = await user_service.get_by_email(db, "player@example.com")
|
user = await service.get_by_email("player@example.com")
|
||||||
"""
|
"""
|
||||||
result = await db.execute(select(User).where(User.email == email))
|
return await self._user_repo.get_by_email(email)
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def get_by_oauth(
|
async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
provider: str,
|
|
||||||
oauth_id: str,
|
|
||||||
) -> User | None:
|
|
||||||
"""Get a user by their OAuth provider and ID.
|
"""Get a user by their OAuth provider and ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
|
||||||
provider: OAuth provider name (google, discord).
|
provider: OAuth provider name (google, discord).
|
||||||
oauth_id: Unique ID from the OAuth provider.
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
User if found, None otherwise.
|
UserEntry if found, None otherwise.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
user = await user_service.get_by_oauth(db, "google", "123456789")
|
user = await service.get_by_oauth("google", "123456789")
|
||||||
"""
|
"""
|
||||||
result = await db.execute(
|
return await self._user_repo.get_by_oauth(provider, oauth_id)
|
||||||
select(User).where(
|
|
||||||
User.oauth_provider == provider,
|
|
||||||
User.oauth_id == oauth_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def create(self, db: AsyncSession, user_data: UserCreate) -> User:
|
async def create(
|
||||||
|
self,
|
||||||
|
email: str,
|
||||||
|
display_name: str,
|
||||||
|
oauth_provider: str,
|
||||||
|
oauth_id: str,
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
) -> UserEntry:
|
||||||
"""Create a new user.
|
"""Create a new user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
email: User's email address.
|
||||||
user_data: User creation data.
|
display_name: Public display name.
|
||||||
|
oauth_provider: OAuth provider name.
|
||||||
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
avatar_url: Optional avatar URL.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The created User instance.
|
The created UserEntry.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
user_data = UserCreate(
|
user = await service.create(
|
||||||
email="player@example.com",
|
email="player@example.com",
|
||||||
display_name="Player1",
|
display_name="Player1",
|
||||||
oauth_provider="google",
|
oauth_provider="google",
|
||||||
oauth_id="123456789"
|
oauth_id="123456789"
|
||||||
)
|
)
|
||||||
user = await user_service.create(db, user_data)
|
|
||||||
"""
|
"""
|
||||||
user = User(
|
return await self._user_repo.create(
|
||||||
email=user_data.email,
|
email=email,
|
||||||
display_name=user_data.display_name,
|
display_name=display_name,
|
||||||
avatar_url=user_data.avatar_url,
|
oauth_provider=oauth_provider,
|
||||||
oauth_provider=user_data.oauth_provider,
|
oauth_id=oauth_id,
|
||||||
oauth_id=user_data.oauth_id,
|
avatar_url=avatar_url,
|
||||||
)
|
)
|
||||||
db.add(user)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user)
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def create_from_oauth(
|
async def create_from_oauth(self, oauth_info: OAuthUserInfo) -> UserEntry:
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
oauth_info: OAuthUserInfo,
|
|
||||||
) -> User:
|
|
||||||
"""Create a new user from OAuth provider info.
|
"""Create a new user from OAuth provider info.
|
||||||
|
|
||||||
Convenience method that converts OAuthUserInfo to UserCreate.
|
Convenience method that extracts fields from OAuthUserInfo.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
|
||||||
oauth_info: Normalized OAuth user information.
|
oauth_info: Normalized OAuth user information.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The created User instance.
|
The created UserEntry.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
oauth_info = OAuthUserInfo(
|
oauth_info = OAuthUserInfo(
|
||||||
@ -158,209 +165,212 @@ class UserService:
|
|||||||
name="Player One",
|
name="Player One",
|
||||||
avatar_url="https://..."
|
avatar_url="https://..."
|
||||||
)
|
)
|
||||||
user = await user_service.create_from_oauth(db, oauth_info)
|
user = await service.create_from_oauth(oauth_info)
|
||||||
"""
|
"""
|
||||||
user_data = oauth_info.to_user_create()
|
return await self.create(
|
||||||
return await self.create(db, user_data)
|
email=oauth_info.email,
|
||||||
|
display_name=oauth_info.name,
|
||||||
|
oauth_provider=oauth_info.provider,
|
||||||
|
oauth_id=oauth_info.oauth_id,
|
||||||
|
avatar_url=oauth_info.avatar_url,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_or_create_from_oauth(
|
async def get_or_create_from_oauth(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
|
||||||
oauth_info: OAuthUserInfo,
|
oauth_info: OAuthUserInfo,
|
||||||
) -> tuple[User, bool]:
|
) -> tuple[UserEntry, bool]:
|
||||||
"""Get existing user or create new one from OAuth info.
|
"""Get existing user or create new one from OAuth info.
|
||||||
|
|
||||||
First checks for existing user by OAuth provider+ID, then by email
|
First checks for existing user by OAuth provider+ID, then by email
|
||||||
(for account linking), and finally creates a new user if not found.
|
(for account linking), and finally creates a new user if not found.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
|
||||||
oauth_info: Normalized OAuth user information.
|
oauth_info: Normalized OAuth user information.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (User, created) where created is True if new user.
|
Tuple of (UserEntry, created) where created is True if new user.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
user, created = await user_service.get_or_create_from_oauth(db, oauth_info)
|
user, created = await service.get_or_create_from_oauth(oauth_info)
|
||||||
if created:
|
if created:
|
||||||
print("Welcome, new user!")
|
print("Welcome, new user!")
|
||||||
else:
|
else:
|
||||||
print("Welcome back!")
|
print("Welcome back!")
|
||||||
"""
|
"""
|
||||||
# First, check by OAuth provider + ID (exact match)
|
# First, check by OAuth provider + ID (exact match)
|
||||||
user = await self.get_by_oauth(db, oauth_info.provider, oauth_info.oauth_id)
|
user = await self._user_repo.get_by_oauth(oauth_info.provider, oauth_info.oauth_id)
|
||||||
if user:
|
if user:
|
||||||
return user, False
|
return user, False
|
||||||
|
|
||||||
# Check by email for potential account linking
|
# Check by email for potential account linking
|
||||||
# If user exists with same email but different OAuth, update their OAuth
|
# If user exists with same email but different OAuth, update their OAuth
|
||||||
user = await self.get_by_email(db, oauth_info.email)
|
user = await self._user_repo.get_by_email(oauth_info.email)
|
||||||
if user:
|
if user:
|
||||||
# Update OAuth credentials for existing user
|
# Update OAuth credentials for existing user
|
||||||
# This links the new OAuth provider to the existing account
|
# This links the new OAuth provider to the existing account
|
||||||
user.oauth_provider = oauth_info.provider
|
updated_user = await self._user_repo.update(
|
||||||
user.oauth_id = oauth_info.oauth_id
|
user_id=user.id,
|
||||||
# Optionally update avatar if not set
|
oauth_provider=oauth_info.provider,
|
||||||
if not user.avatar_url and oauth_info.avatar_url:
|
oauth_id=oauth_info.oauth_id,
|
||||||
user.avatar_url = oauth_info.avatar_url
|
avatar_url=oauth_info.avatar_url if not user.avatar_url else None,
|
||||||
await db.commit()
|
)
|
||||||
await db.refresh(user)
|
return updated_user or user, False
|
||||||
return user, False
|
|
||||||
|
|
||||||
# Create new user
|
# Create new user
|
||||||
user = await self.create_from_oauth(db, oauth_info)
|
user = await self.create_from_oauth(oauth_info)
|
||||||
return user, True
|
return user, True
|
||||||
|
|
||||||
async def update(
|
async def update(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
user_id: UUID,
|
||||||
user: User,
|
|
||||||
update_data: UserUpdate,
|
update_data: UserUpdate,
|
||||||
) -> User:
|
) -> UserEntry | None:
|
||||||
"""Update user profile fields.
|
"""Update user profile fields.
|
||||||
|
|
||||||
Only updates fields that are provided (not None).
|
Only updates fields that are explicitly provided. Uses UNSET pattern
|
||||||
|
for avatar_url to distinguish "not provided" from "set to None".
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
user_id: The user's UUID.
|
||||||
user: The user to update.
|
|
||||||
update_data: Fields to update.
|
update_data: Fields to update.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The updated User instance.
|
The updated UserEntry, or None if not found.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
update_data = UserUpdate(display_name="New Name")
|
update_data = UserUpdate(display_name="New Name")
|
||||||
user = await user_service.update(db, user, update_data)
|
user = await service.update(user_id, update_data)
|
||||||
"""
|
"""
|
||||||
if update_data.display_name is not None:
|
from app.repositories.protocols import UNSET
|
||||||
user.display_name = update_data.display_name
|
|
||||||
if update_data.avatar_url is not None:
|
|
||||||
user.avatar_url = update_data.avatar_url
|
|
||||||
|
|
||||||
await db.commit()
|
# Use UNSET for avatar_url unless explicitly provided
|
||||||
await db.refresh(user)
|
avatar_url = (
|
||||||
return user
|
update_data.avatar_url if "avatar_url" in update_data.model_fields_set else UNSET
|
||||||
|
)
|
||||||
|
|
||||||
async def update_last_login(self, db: AsyncSession, user: User) -> User:
|
return await self._user_repo.update(
|
||||||
|
user_id=user_id,
|
||||||
|
display_name=update_data.display_name,
|
||||||
|
avatar_url=avatar_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_last_login(self, user_id: UUID) -> UserEntry | None:
|
||||||
"""Update the user's last login timestamp.
|
"""Update the user's last login timestamp.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
user_id: The user's UUID.
|
||||||
user: The user to update.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The updated User instance.
|
The updated UserEntry, or None if not found.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
user = await user_service.update_last_login(db, user)
|
user = await service.update_last_login(user_id)
|
||||||
"""
|
"""
|
||||||
user.last_login = datetime.now(UTC)
|
return await self._user_repo.update_last_login(user_id)
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user)
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def update_premium(
|
async def update_premium(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
user_id: UUID,
|
||||||
user: User,
|
|
||||||
premium_until: datetime | None,
|
premium_until: datetime | None,
|
||||||
) -> User:
|
) -> UserEntry | None:
|
||||||
"""Update user's premium subscription status.
|
"""Update user's premium subscription status.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
user_id: The user's UUID.
|
||||||
user: The user to update.
|
|
||||||
premium_until: When premium expires, or None to remove premium.
|
premium_until: When premium expires, or None to remove premium.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The updated User instance.
|
The updated UserEntry, or None if not found.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
# Grant 30 days of premium
|
# Grant 30 days of premium
|
||||||
expires = datetime.now(UTC) + timedelta(days=30)
|
expires = datetime.now(UTC) + timedelta(days=30)
|
||||||
user = await user_service.update_premium(db, user, expires)
|
user = await service.update_premium(user_id, expires)
|
||||||
|
|
||||||
# Remove premium
|
# Remove premium
|
||||||
user = await user_service.update_premium(db, user, None)
|
user = await service.update_premium(user_id, None)
|
||||||
"""
|
"""
|
||||||
if premium_until is not None:
|
is_premium = premium_until is not None
|
||||||
user.is_premium = True
|
return await self._user_repo.update_premium(
|
||||||
user.premium_until = premium_until
|
user_id=user_id,
|
||||||
else:
|
is_premium=is_premium,
|
||||||
user.is_premium = False
|
premium_until=premium_until,
|
||||||
user.premium_until = None
|
)
|
||||||
|
|
||||||
await db.commit()
|
async def delete(self, user_id: UUID) -> bool:
|
||||||
await db.refresh(user)
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def delete(self, db: AsyncSession, user: User) -> None:
|
|
||||||
"""Delete a user account.
|
"""Delete a user account.
|
||||||
|
|
||||||
This will cascade delete all related data (decks, collection, etc.)
|
This will cascade delete all related data (decks, collection, etc.)
|
||||||
based on the model relationships.
|
based on the database constraints.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
user_id: The user's UUID.
|
||||||
user: The user to delete.
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
await user_service.delete(db, user)
|
success = await service.delete(user_id)
|
||||||
"""
|
"""
|
||||||
await db.delete(user)
|
return await self._user_repo.delete(user_id)
|
||||||
await db.commit()
|
|
||||||
|
# =========================================================================
|
||||||
|
# Linked Account Operations
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def get_linked_accounts(self, user_id: UUID) -> list[LinkedAccountEntry]:
|
||||||
|
"""Get all linked OAuth accounts for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's UUID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LinkedAccountEntry.
|
||||||
|
"""
|
||||||
|
return await self._linked_repo.get_by_user(user_id)
|
||||||
|
|
||||||
async def get_linked_account(
|
async def get_linked_account(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
|
||||||
provider: str,
|
provider: str,
|
||||||
oauth_id: str,
|
oauth_id: str,
|
||||||
) -> OAuthLinkedAccount | None:
|
) -> LinkedAccountEntry | None:
|
||||||
"""Get a linked account by provider and OAuth ID.
|
"""Get a linked account by provider and OAuth ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
|
||||||
provider: OAuth provider name (google, discord).
|
provider: OAuth provider name (google, discord).
|
||||||
oauth_id: Unique ID from the OAuth provider.
|
oauth_id: Unique ID from the OAuth provider.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
OAuthLinkedAccount if found, None otherwise.
|
LinkedAccountEntry if found, None otherwise.
|
||||||
"""
|
"""
|
||||||
result = await db.execute(
|
return await self._linked_repo.get_by_provider(provider, oauth_id)
|
||||||
select(OAuthLinkedAccount).where(
|
|
||||||
OAuthLinkedAccount.provider == provider,
|
|
||||||
OAuthLinkedAccount.oauth_id == oauth_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
async def link_oauth_account(
|
async def link_oauth_account(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
user_id: UUID,
|
||||||
user: User,
|
user_oauth_provider: str,
|
||||||
oauth_info: OAuthUserInfo,
|
oauth_info: OAuthUserInfo,
|
||||||
) -> OAuthLinkedAccount:
|
) -> LinkedAccountEntry:
|
||||||
"""Link an additional OAuth provider to a user account.
|
"""Link an additional OAuth provider to a user account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
user_id: The user's UUID.
|
||||||
user: The user to link the account to.
|
user_oauth_provider: The user's primary OAuth provider.
|
||||||
oauth_info: OAuth information from the provider.
|
oauth_info: OAuth information from the provider.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The created OAuthLinkedAccount.
|
The created LinkedAccountEntry.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AccountLinkingError: If provider is already linked to this or another user.
|
AccountLinkingError: If provider is already linked to this or another user.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
linked = await user_service.link_oauth_account(db, user, discord_info)
|
linked = await service.link_oauth_account(user.id, user.oauth_provider, discord_info)
|
||||||
"""
|
"""
|
||||||
# Check if this provider+oauth_id is already linked to any user
|
# Check if this provider+oauth_id is already linked to any user
|
||||||
existing = await self.get_linked_account(db, oauth_info.provider, oauth_info.oauth_id)
|
existing = await self._linked_repo.get_by_provider(oauth_info.provider, oauth_info.oauth_id)
|
||||||
if existing:
|
if existing:
|
||||||
if str(existing.user_id) == str(user.id):
|
if existing.user_id == user_id:
|
||||||
raise AccountLinkingError(
|
raise AccountLinkingError(
|
||||||
f"{oauth_info.provider.title()} account is already linked to your account"
|
f"{oauth_info.provider.title()} account is already linked to your account"
|
||||||
)
|
)
|
||||||
@ -369,36 +379,33 @@ class UserService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if this is the user's primary OAuth provider
|
# Check if this is the user's primary OAuth provider
|
||||||
if user.oauth_provider == oauth_info.provider:
|
if user_oauth_provider == oauth_info.provider:
|
||||||
raise AccountLinkingError(
|
raise AccountLinkingError(
|
||||||
f"{oauth_info.provider.title()} is your primary login provider"
|
f"{oauth_info.provider.title()} is your primary login provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if user already has this provider linked
|
# Check if user already has this provider linked
|
||||||
for linked in user.linked_accounts:
|
linked_accounts = await self._linked_repo.get_by_user(user_id)
|
||||||
|
for linked in linked_accounts:
|
||||||
if linked.provider == oauth_info.provider:
|
if linked.provider == oauth_info.provider:
|
||||||
raise AccountLinkingError(
|
raise AccountLinkingError(
|
||||||
f"You already have a {oauth_info.provider.title()} account linked"
|
f"You already have a {oauth_info.provider.title()} account linked"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the linked account
|
# Create the linked account
|
||||||
linked_account = OAuthLinkedAccount(
|
return await self._linked_repo.create(
|
||||||
user_id=str(user.id),
|
user_id=user_id,
|
||||||
provider=oauth_info.provider,
|
provider=oauth_info.provider,
|
||||||
oauth_id=oauth_info.oauth_id,
|
oauth_id=oauth_info.oauth_id,
|
||||||
email=oauth_info.email,
|
email=oauth_info.email,
|
||||||
display_name=oauth_info.name,
|
display_name=oauth_info.name,
|
||||||
avatar_url=oauth_info.avatar_url,
|
avatar_url=oauth_info.avatar_url,
|
||||||
)
|
)
|
||||||
db.add(linked_account)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(linked_account)
|
|
||||||
return linked_account
|
|
||||||
|
|
||||||
async def unlink_oauth_account(
|
async def unlink_oauth_account(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
user_id: UUID,
|
||||||
user: User,
|
user_oauth_provider: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Unlink an OAuth provider from a user account.
|
"""Unlink an OAuth provider from a user account.
|
||||||
@ -406,8 +413,8 @@ class UserService:
|
|||||||
Cannot unlink the primary OAuth provider.
|
Cannot unlink the primary OAuth provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Async database session.
|
user_id: The user's UUID.
|
||||||
user: The user to unlink from.
|
user_oauth_provider: The user's primary OAuth provider.
|
||||||
provider: OAuth provider name to unlink.
|
provider: OAuth provider name to unlink.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -417,23 +424,12 @@ class UserService:
|
|||||||
AccountLinkingError: If trying to unlink the primary provider.
|
AccountLinkingError: If trying to unlink the primary provider.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
success = await user_service.unlink_oauth_account(db, user, "discord")
|
success = await service.unlink_oauth_account(user.id, user.oauth_provider, "discord")
|
||||||
"""
|
"""
|
||||||
# Cannot unlink primary provider
|
# Cannot unlink primary provider
|
||||||
if user.oauth_provider == provider:
|
if user_oauth_provider == provider:
|
||||||
raise AccountLinkingError(
|
raise AccountLinkingError(
|
||||||
f"Cannot unlink {provider.title()} - it is your primary login provider"
|
f"Cannot unlink {provider.title()} - it is your primary login provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find and delete the linked account
|
return await self._linked_repo.delete(user_id, provider)
|
||||||
for linked in user.linked_accounts:
|
|
||||||
if linked.provider == provider:
|
|
||||||
await db.delete(linked)
|
|
||||||
await db.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Global service instance
|
|
||||||
user_service = UserService()
|
|
||||||
|
|||||||
@ -2,14 +2,19 @@
|
|||||||
|
|
||||||
Tests the authentication endpoints including OAuth redirects,
|
Tests the authentication endpoints including OAuth redirects,
|
||||||
token refresh, and logout.
|
token refresh, and logout.
|
||||||
|
|
||||||
|
Uses FastAPI's dependency override pattern for proper dependency injection testing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.api.deps import get_user_service
|
||||||
|
|
||||||
|
|
||||||
class TestGoogleAuthRedirect:
|
class TestGoogleAuthRedirect:
|
||||||
"""Tests for GET /api/auth/google endpoint."""
|
"""Tests for GET /api/auth/google endpoint."""
|
||||||
@ -50,11 +55,32 @@ class TestDiscordAuthRedirect:
|
|||||||
assert "not configured" in response.json()["detail"]
|
assert "not configured" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_user_service_instance():
|
||||||
|
"""Create a mock UserService for dependency injection.
|
||||||
|
|
||||||
|
Returns a MagicMock with async methods configured.
|
||||||
|
"""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.get_by_id = AsyncMock()
|
||||||
|
mock.get_by_email = AsyncMock()
|
||||||
|
mock.get_by_oauth = AsyncMock()
|
||||||
|
mock.get_or_create_from_oauth = AsyncMock()
|
||||||
|
mock.update_last_login = AsyncMock()
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
class TestRefreshTokens:
|
class TestRefreshTokens:
|
||||||
"""Tests for POST /api/auth/refresh endpoint."""
|
"""Tests for POST /api/auth/refresh endpoint."""
|
||||||
|
|
||||||
def test_returns_new_access_token(
|
def test_returns_new_access_token(
|
||||||
self, client: TestClient, test_user, refresh_token_data, mock_get_redis
|
self,
|
||||||
|
app,
|
||||||
|
client: TestClient,
|
||||||
|
test_user,
|
||||||
|
refresh_token_data,
|
||||||
|
mock_get_redis,
|
||||||
|
mock_user_service_instance,
|
||||||
):
|
):
|
||||||
"""Test that refresh endpoint returns new access token for valid refresh token.
|
"""Test that refresh endpoint returns new access token for valid refresh token.
|
||||||
|
|
||||||
@ -70,25 +96,31 @@ class TestRefreshTokens:
|
|||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(setup_token())
|
asyncio.get_event_loop().run_until_complete(setup_token())
|
||||||
|
|
||||||
# Mock user service to return our test user
|
# Configure mock to return test user
|
||||||
with patch("app.api.auth.user_service") as mock_user_service:
|
# Convert to UserEntry-like object
|
||||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_user_entry = MagicMock()
|
||||||
|
mock_user_entry.id = test_user.id
|
||||||
|
mock_user_entry.email = test_user.email
|
||||||
|
mock_user_service_instance.get_by_id.return_value = mock_user_entry
|
||||||
|
|
||||||
with (
|
# Override the dependency on the test app
|
||||||
patch("app.api.auth.get_redis", mock_get_redis),
|
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
|
||||||
):
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/refresh",
|
|
||||||
json={"refresh_token": refresh_token_data["token"]},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
try:
|
||||||
data = response.json()
|
response = client.post(
|
||||||
assert "access_token" in data
|
"/api/auth/refresh",
|
||||||
assert data["refresh_token"] == refresh_token_data["token"]
|
json={"refresh_token": refresh_token_data["token"]},
|
||||||
assert data["token_type"] == "bearer"
|
)
|
||||||
assert "expires_in" in data
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert data["refresh_token"] == refresh_token_data["token"]
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
assert "expires_in" in data
|
||||||
|
finally:
|
||||||
|
# Clean up override
|
||||||
|
app.dependency_overrides.pop(get_user_service, None)
|
||||||
|
|
||||||
def test_returns_401_for_invalid_token(self, client: TestClient):
|
def test_returns_401_for_invalid_token(self, client: TestClient):
|
||||||
"""Test that refresh endpoint returns 401 for invalid refresh token."""
|
"""Test that refresh endpoint returns 401 for invalid refresh token."""
|
||||||
@ -107,21 +139,23 @@ class TestRefreshTokens:
|
|||||||
A refresh token not in Redis (revoked/expired) should be rejected.
|
A refresh token not in Redis (revoked/expired) should be rejected.
|
||||||
"""
|
"""
|
||||||
# Don't store the token in Redis - simulating revocation
|
# Don't store the token in Redis - simulating revocation
|
||||||
|
# The mock_get_redis is already patched via conftest's app fixture
|
||||||
|
|
||||||
with (
|
response = client.post(
|
||||||
patch("app.api.auth.get_redis", mock_get_redis),
|
"/api/auth/refresh",
|
||||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
json={"refresh_token": refresh_token_data["token"]},
|
||||||
):
|
)
|
||||||
response = client.post(
|
|
||||||
"/api/auth/refresh",
|
|
||||||
json={"refresh_token": refresh_token_data["token"]},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
assert "revoked" in response.json()["detail"]
|
assert "revoked" in response.json()["detail"]
|
||||||
|
|
||||||
def test_returns_401_for_deleted_user(
|
def test_returns_401_for_deleted_user(
|
||||||
self, client: TestClient, refresh_token_data, mock_get_redis
|
self,
|
||||||
|
app,
|
||||||
|
client: TestClient,
|
||||||
|
refresh_token_data,
|
||||||
|
mock_get_redis,
|
||||||
|
mock_user_service_instance,
|
||||||
):
|
):
|
||||||
"""Test that refresh endpoint returns 401 if user no longer exists."""
|
"""Test that refresh endpoint returns 401 if user no longer exists."""
|
||||||
# Store the token
|
# Store the token
|
||||||
@ -134,21 +168,23 @@ class TestRefreshTokens:
|
|||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(setup_token())
|
asyncio.get_event_loop().run_until_complete(setup_token())
|
||||||
|
|
||||||
# Mock user service to return None (user deleted)
|
# Configure mock to return None (user deleted)
|
||||||
with patch("app.api.auth.user_service") as mock_user_service:
|
mock_user_service_instance.get_by_id.return_value = None
|
||||||
mock_user_service.get_by_id = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
with (
|
# Override the dependency on the test app
|
||||||
patch("app.api.auth.get_redis", mock_get_redis),
|
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
|
||||||
):
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/refresh",
|
|
||||||
json={"refresh_token": refresh_token_data["token"]},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
try:
|
||||||
assert "User not found" in response.json()["detail"]
|
response = client.post(
|
||||||
|
"/api/auth/refresh",
|
||||||
|
json={"refresh_token": refresh_token_data["token"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
assert "User not found" in response.json()["detail"]
|
||||||
|
finally:
|
||||||
|
# Clean up override
|
||||||
|
app.dependency_overrides.pop(get_user_service, None)
|
||||||
|
|
||||||
|
|
||||||
class TestLogout:
|
class TestLogout:
|
||||||
@ -171,14 +207,10 @@ class TestLogout:
|
|||||||
key = asyncio.get_event_loop().run_until_complete(setup_and_check())
|
key = asyncio.get_event_loop().run_until_complete(setup_and_check())
|
||||||
|
|
||||||
# Logout
|
# Logout
|
||||||
with (
|
response = client.post(
|
||||||
patch("app.api.auth.get_redis", mock_get_redis),
|
"/api/auth/logout",
|
||||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
json={"refresh_token": refresh_token_data["token"]},
|
||||||
):
|
)
|
||||||
response = client.post(
|
|
||||||
"/api/auth/logout",
|
|
||||||
json={"refresh_token": refresh_token_data["token"]},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
@ -214,7 +246,9 @@ class TestLogoutAll:
|
|||||||
response = client.post("/api/auth/logout-all")
|
response = client.post("/api/auth/logout-all")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
def test_revokes_all_tokens(self, client: TestClient, test_user, access_token, mock_get_redis):
|
def test_revokes_all_tokens(
|
||||||
|
self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
|
||||||
|
):
|
||||||
"""Test that logout-all revokes all refresh tokens for user.
|
"""Test that logout-all revokes all refresh tokens for user.
|
||||||
|
|
||||||
Should delete all tokens matching the user's ID pattern.
|
Should delete all tokens matching the user's ID pattern.
|
||||||
@ -233,18 +267,16 @@ class TestLogoutAll:
|
|||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
||||||
|
|
||||||
# Mock dependencies
|
# Set up mock db session to return test user when queried
|
||||||
with patch("app.api.deps.user_service") as mock_user_service:
|
# The get_current_user dependency now does a direct DB query
|
||||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
with (
|
response = client.post(
|
||||||
patch("app.api.auth.get_redis", mock_get_redis),
|
"/api/auth/logout-all",
|
||||||
patch("app.services.token_store.get_redis", mock_get_redis),
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
):
|
)
|
||||||
response = client.post(
|
|
||||||
"/api/auth/logout-all",
|
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
|||||||
@ -1,32 +1,55 @@
|
|||||||
"""Tests for users API endpoints.
|
"""Tests for users API endpoints.
|
||||||
|
|
||||||
Tests the user profile management endpoints.
|
Tests the user profile management endpoints.
|
||||||
|
|
||||||
|
Uses FastAPI's dependency override pattern for proper dependency injection testing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.api.deps import get_user_service
|
||||||
from app.services.user_service import AccountLinkingError
|
from app.services.user_service import AccountLinkingError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_user_service_instance():
|
||||||
|
"""Create a mock UserService for dependency injection.
|
||||||
|
|
||||||
|
Returns a MagicMock with async methods configured.
|
||||||
|
"""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.get_by_id = AsyncMock()
|
||||||
|
mock.get_by_email = AsyncMock()
|
||||||
|
mock.get_by_oauth = AsyncMock()
|
||||||
|
mock.update = AsyncMock()
|
||||||
|
mock.unlink_oauth_account = AsyncMock()
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
class TestGetCurrentUser:
|
class TestGetCurrentUser:
|
||||||
"""Tests for GET /api/users/me endpoint."""
|
"""Tests for GET /api/users/me endpoint."""
|
||||||
|
|
||||||
def test_returns_user_profile(self, client: TestClient, test_user, access_token):
|
def test_returns_user_profile(
|
||||||
|
self, app, client: TestClient, test_user, access_token, mock_db_session
|
||||||
|
):
|
||||||
"""Test that endpoint returns user profile for authenticated user.
|
"""Test that endpoint returns user profile for authenticated user.
|
||||||
|
|
||||||
Should return the user's profile information.
|
Should return the user's profile information.
|
||||||
"""
|
"""
|
||||||
with patch("app.api.deps.user_service") as mock_user_service:
|
# Set up mock db session to return test user when queried
|
||||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/users/me",
|
"/api/users/me",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@ -52,47 +75,93 @@ class TestGetCurrentUser:
|
|||||||
class TestUpdateCurrentUser:
|
class TestUpdateCurrentUser:
|
||||||
"""Tests for PATCH /api/users/me endpoint."""
|
"""Tests for PATCH /api/users/me endpoint."""
|
||||||
|
|
||||||
def test_updates_display_name(self, client: TestClient, test_user, access_token):
|
def test_updates_display_name(
|
||||||
|
self,
|
||||||
|
app,
|
||||||
|
client: TestClient,
|
||||||
|
test_user,
|
||||||
|
access_token,
|
||||||
|
mock_db_session,
|
||||||
|
mock_user_service_instance,
|
||||||
|
):
|
||||||
"""Test that endpoint updates display_name when provided."""
|
"""Test that endpoint updates display_name when provided."""
|
||||||
updated_user = test_user
|
# Create an updated user mock
|
||||||
|
updated_user = MagicMock()
|
||||||
|
updated_user.id = test_user.id
|
||||||
|
updated_user.email = test_user.email
|
||||||
updated_user.display_name = "New Name"
|
updated_user.display_name = "New Name"
|
||||||
|
updated_user.avatar_url = test_user.avatar_url
|
||||||
|
updated_user.is_premium = test_user.is_premium
|
||||||
|
updated_user.premium_until = test_user.premium_until
|
||||||
|
updated_user.created_at = test_user.created_at
|
||||||
|
|
||||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
# Set up db session to return test user for authentication
|
||||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
with patch("app.api.users.user_service") as mock_user_service:
|
# Set up user service mock
|
||||||
mock_user_service.update = AsyncMock(return_value=updated_user)
|
mock_user_service_instance.update.return_value = updated_user
|
||||||
|
|
||||||
response = client.patch(
|
# Override the dependency
|
||||||
"/api/users/me",
|
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
json={"display_name": "New Name"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
try:
|
||||||
data = response.json()
|
response = client.patch(
|
||||||
assert data["display_name"] == "New Name"
|
"/api/users/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json={"display_name": "New Name"},
|
||||||
|
)
|
||||||
|
|
||||||
def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["display_name"] == "New Name"
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides.pop(get_user_service, None)
|
||||||
|
|
||||||
|
def test_updates_avatar_url(
|
||||||
|
self,
|
||||||
|
app,
|
||||||
|
client: TestClient,
|
||||||
|
test_user,
|
||||||
|
access_token,
|
||||||
|
mock_db_session,
|
||||||
|
mock_user_service_instance,
|
||||||
|
):
|
||||||
"""Test that endpoint updates avatar_url when provided."""
|
"""Test that endpoint updates avatar_url when provided."""
|
||||||
updated_user = test_user
|
# Create an updated user mock
|
||||||
|
updated_user = MagicMock()
|
||||||
|
updated_user.id = test_user.id
|
||||||
|
updated_user.email = test_user.email
|
||||||
|
updated_user.display_name = test_user.display_name
|
||||||
updated_user.avatar_url = "https://new-avatar.com/img.jpg"
|
updated_user.avatar_url = "https://new-avatar.com/img.jpg"
|
||||||
|
updated_user.is_premium = test_user.is_premium
|
||||||
|
updated_user.premium_until = test_user.premium_until
|
||||||
|
updated_user.created_at = test_user.created_at
|
||||||
|
|
||||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
# Set up db session to return test user for authentication
|
||||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
with patch("app.api.users.user_service") as mock_user_service:
|
# Set up user service mock
|
||||||
mock_user_service.update = AsyncMock(return_value=updated_user)
|
mock_user_service_instance.update.return_value = updated_user
|
||||||
|
|
||||||
response = client.patch(
|
# Override the dependency
|
||||||
"/api/users/me",
|
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
json={"avatar_url": "https://new-avatar.com/img.jpg"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
try:
|
||||||
data = response.json()
|
response = client.patch(
|
||||||
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
|
"/api/users/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json={"avatar_url": "https://new-avatar.com/img.jpg"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides.pop(get_user_service, None)
|
||||||
|
|
||||||
def test_requires_authentication(self, client: TestClient):
|
def test_requires_authentication(self, client: TestClient):
|
||||||
"""Test that endpoint returns 401 without authentication."""
|
"""Test that endpoint returns 401 without authentication."""
|
||||||
@ -106,18 +175,22 @@ class TestUpdateCurrentUser:
|
|||||||
class TestGetLinkedAccounts:
|
class TestGetLinkedAccounts:
|
||||||
"""Tests for GET /api/users/me/linked-accounts endpoint."""
|
"""Tests for GET /api/users/me/linked-accounts endpoint."""
|
||||||
|
|
||||||
def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
|
def test_returns_linked_accounts(
|
||||||
|
self, app, client: TestClient, test_user, access_token, mock_db_session
|
||||||
|
):
|
||||||
"""Test that endpoint returns list of linked OAuth accounts.
|
"""Test that endpoint returns list of linked OAuth accounts.
|
||||||
|
|
||||||
Should include the primary provider and any linked accounts.
|
Should include the primary provider and any linked accounts.
|
||||||
"""
|
"""
|
||||||
with patch("app.api.deps.user_service") as mock_user_service:
|
# Set up db session to return test user for authentication
|
||||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/users/me/linked-accounts",
|
"/api/users/me/linked-accounts",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@ -135,7 +208,7 @@ class TestGetActiveSessions:
|
|||||||
"""Tests for GET /api/users/me/sessions endpoint."""
|
"""Tests for GET /api/users/me/sessions endpoint."""
|
||||||
|
|
||||||
def test_returns_session_count(
|
def test_returns_session_count(
|
||||||
self, client: TestClient, test_user, access_token, mock_get_redis
|
self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
|
||||||
):
|
):
|
||||||
"""Test that endpoint returns count of active sessions.
|
"""Test that endpoint returns count of active sessions.
|
||||||
|
|
||||||
@ -154,14 +227,16 @@ class TestGetActiveSessions:
|
|||||||
|
|
||||||
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
asyncio.get_event_loop().run_until_complete(setup_tokens())
|
||||||
|
|
||||||
with patch("app.api.deps.user_service") as mock_user_service:
|
# Set up db session to return test user for authentication
|
||||||
mock_user_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
with patch("app.services.token_store.get_redis", mock_get_redis):
|
with patch("app.services.token_store.get_redis", mock_get_redis):
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/users/me/sessions",
|
"/api/users/me/sessions",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@ -177,78 +252,128 @@ class TestGetActiveSessions:
|
|||||||
class TestUnlinkOAuthAccount:
|
class TestUnlinkOAuthAccount:
|
||||||
"""Tests for DELETE /api/users/me/link/{provider} endpoint."""
|
"""Tests for DELETE /api/users/me/link/{provider} endpoint."""
|
||||||
|
|
||||||
def test_unlinks_provider_successfully(self, client: TestClient, test_user, access_token):
|
def test_unlinks_provider_successfully(
|
||||||
|
self,
|
||||||
|
app,
|
||||||
|
client: TestClient,
|
||||||
|
test_user,
|
||||||
|
access_token,
|
||||||
|
mock_db_session,
|
||||||
|
mock_user_service_instance,
|
||||||
|
):
|
||||||
"""Test that endpoint successfully unlinks a provider.
|
"""Test that endpoint successfully unlinks a provider.
|
||||||
|
|
||||||
Should return 204 when provider is unlinked.
|
Should return 204 when provider is unlinked.
|
||||||
"""
|
"""
|
||||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
# Set up db session to return test user for authentication
|
||||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
with patch("app.api.users.user_service") as mock_user_service:
|
# Set up user service mock
|
||||||
mock_user_service.unlink_oauth_account = AsyncMock(return_value=True)
|
mock_user_service_instance.unlink_oauth_account.return_value = True
|
||||||
|
|
||||||
response = client.delete(
|
# Override the dependency
|
||||||
"/api/users/me/link/discord",
|
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
try:
|
||||||
|
response = client.delete(
|
||||||
|
"/api/users/me/link/discord",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
def test_returns_404_if_not_linked(self, client: TestClient, test_user, access_token):
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides.pop(get_user_service, None)
|
||||||
|
|
||||||
|
def test_returns_404_if_not_linked(
|
||||||
|
self,
|
||||||
|
app,
|
||||||
|
client: TestClient,
|
||||||
|
test_user,
|
||||||
|
access_token,
|
||||||
|
mock_db_session,
|
||||||
|
mock_user_service_instance,
|
||||||
|
):
|
||||||
"""Test that endpoint returns 404 if provider isn't linked.
|
"""Test that endpoint returns 404 if provider isn't linked.
|
||||||
|
|
||||||
Should return 404 when trying to unlink a provider that isn't linked.
|
Should return 404 when trying to unlink a provider that isn't linked.
|
||||||
"""
|
"""
|
||||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
# Set up db session to return test user for authentication
|
||||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
with patch("app.api.users.user_service") as mock_user_service:
|
# Set up user service mock
|
||||||
mock_user_service.unlink_oauth_account = AsyncMock(return_value=False)
|
mock_user_service_instance.unlink_oauth_account.return_value = False
|
||||||
|
|
||||||
response = client.delete(
|
# Override the dependency
|
||||||
"/api/users/me/link/discord",
|
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
try:
|
||||||
assert "not linked" in response.json()["detail"].lower()
|
response = client.delete(
|
||||||
|
"/api/users/me/link/discord",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
def test_returns_400_for_primary_provider(self, client: TestClient, test_user, access_token):
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert "not linked" in response.json()["detail"].lower()
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides.pop(get_user_service, None)
|
||||||
|
|
||||||
|
def test_returns_400_for_primary_provider(
|
||||||
|
self,
|
||||||
|
app,
|
||||||
|
client: TestClient,
|
||||||
|
test_user,
|
||||||
|
access_token,
|
||||||
|
mock_db_session,
|
||||||
|
mock_user_service_instance,
|
||||||
|
):
|
||||||
"""Test that endpoint returns 400 when trying to unlink primary provider.
|
"""Test that endpoint returns 400 when trying to unlink primary provider.
|
||||||
|
|
||||||
Cannot unlink the provider used to create the account.
|
Cannot unlink the provider used to create the account.
|
||||||
"""
|
"""
|
||||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
# Set up db session to return test user for authentication
|
||||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
with patch("app.api.users.user_service") as mock_user_service:
|
# Set up user service mock to raise AccountLinkingError
|
||||||
mock_user_service.unlink_oauth_account = AsyncMock(
|
mock_user_service_instance.unlink_oauth_account.side_effect = AccountLinkingError(
|
||||||
side_effect=AccountLinkingError(
|
"Cannot unlink Google - it is your primary login provider"
|
||||||
"Cannot unlink Google - it is your primary login provider"
|
)
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.delete(
|
# Override the dependency
|
||||||
"/api/users/me/link/google",
|
app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
try:
|
||||||
assert "primary" in response.json()["detail"].lower()
|
response = client.delete(
|
||||||
|
"/api/users/me/link/google",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
def test_returns_400_for_unknown_provider(self, client: TestClient, test_user, access_token):
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert "primary" in response.json()["detail"].lower()
|
||||||
|
finally:
|
||||||
|
app.dependency_overrides.pop(get_user_service, None)
|
||||||
|
|
||||||
|
def test_returns_400_for_unknown_provider(
|
||||||
|
self, app, client: TestClient, test_user, access_token, mock_db_session
|
||||||
|
):
|
||||||
"""Test that endpoint returns 400 for unknown provider.
|
"""Test that endpoint returns 400 for unknown provider.
|
||||||
|
|
||||||
Only 'google' and 'discord' are valid providers.
|
Only 'google' and 'discord' are valid providers.
|
||||||
"""
|
"""
|
||||||
with patch("app.api.deps.user_service") as mock_deps_service:
|
# Set up db session to return test user for authentication
|
||||||
mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalar_one_or_none.return_value = test_user
|
||||||
|
mock_db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
"/api/users/me/link/twitter",
|
"/api/users/me/link/twitter",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
assert "unknown provider" in response.json()["detail"].lower()
|
assert "unknown provider" in response.json()["detail"].lower()
|
||||||
|
|||||||
@ -2,25 +2,43 @@
|
|||||||
|
|
||||||
Tests the user service CRUD operations and OAuth-based user creation.
|
Tests the user service CRUD operations and OAuth-based user creation.
|
||||||
Uses real Postgres via the db_session fixture from conftest.
|
Uses real Postgres via the db_session fixture from conftest.
|
||||||
|
|
||||||
|
The UserService now uses dependency injection with repositories injected
|
||||||
|
via constructor, so we create a fresh service instance per test.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.db.models import User
|
from app.db.models import User
|
||||||
from app.db.models.oauth_account import OAuthLinkedAccount
|
from app.db.models.oauth_account import OAuthLinkedAccount
|
||||||
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
from app.repositories.postgres.linked_account import PostgresLinkedAccountRepository
|
||||||
from app.services.user_service import AccountLinkingError, user_service
|
from app.repositories.postgres.user import PostgresUserRepository
|
||||||
|
from app.schemas.user import OAuthUserInfo, UserUpdate
|
||||||
|
from app.services.user_service import AccountLinkingError, UserService
|
||||||
|
|
||||||
# Import db_session fixture from db conftest
|
# Import db_session fixture from db conftest
|
||||||
pytestmark = pytest.mark.asyncio
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_service(db_session):
|
||||||
|
"""Create UserService with real PostgreSQL repositories.
|
||||||
|
|
||||||
|
This fixture provides a properly constructed UserService for each test,
|
||||||
|
following the dependency injection pattern used in production.
|
||||||
|
"""
|
||||||
|
user_repo = PostgresUserRepository(db_session)
|
||||||
|
linked_repo = PostgresLinkedAccountRepository(db_session)
|
||||||
|
return UserService(user_repo, linked_repo)
|
||||||
|
|
||||||
|
|
||||||
class TestGetById:
|
class TestGetById:
|
||||||
"""Tests for get_by_id method."""
|
"""Tests for get_by_id method."""
|
||||||
|
|
||||||
async def test_returns_user_when_found(self, db_session):
|
async def test_returns_user_when_found(self, db_session, user_service):
|
||||||
"""Test that get_by_id returns user when it exists.
|
"""Test that get_by_id returns user when it exists.
|
||||||
|
|
||||||
Creates a user and verifies it can be retrieved by ID.
|
Creates a user and verifies it can be retrieved by ID.
|
||||||
@ -36,26 +54,24 @@ class TestGetById:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Retrieve by ID
|
# Retrieve by ID
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
result = await user_service.get_by_id(db_session, user_id)
|
result = await user_service.get_by_id(user_id)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.email == "test@example.com"
|
assert result.email == "test@example.com"
|
||||||
|
|
||||||
async def test_returns_none_when_not_found(self, db_session):
|
async def test_returns_none_when_not_found(self, user_service):
|
||||||
"""Test that get_by_id returns None for nonexistent users."""
|
"""Test that get_by_id returns None for nonexistent users."""
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
result = await user_service.get_by_id(db_session, uuid4())
|
result = await user_service.get_by_id(uuid4())
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
class TestGetByEmail:
|
class TestGetByEmail:
|
||||||
"""Tests for get_by_email method."""
|
"""Tests for get_by_email method."""
|
||||||
|
|
||||||
async def test_returns_user_when_found(self, db_session):
|
async def test_returns_user_when_found(self, db_session, user_service):
|
||||||
"""Test that get_by_email returns user when it exists."""
|
"""Test that get_by_email returns user when it exists."""
|
||||||
user = User(
|
user = User(
|
||||||
email="findme@example.com",
|
email="findme@example.com",
|
||||||
@ -66,21 +82,21 @@ class TestGetByEmail:
|
|||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
result = await user_service.get_by_email(db_session, "findme@example.com")
|
result = await user_service.get_by_email("findme@example.com")
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.display_name == "Find Me"
|
assert result.display_name == "Find Me"
|
||||||
|
|
||||||
async def test_returns_none_when_not_found(self, db_session):
|
async def test_returns_none_when_not_found(self, user_service):
|
||||||
"""Test that get_by_email returns None for nonexistent emails."""
|
"""Test that get_by_email returns None for nonexistent emails."""
|
||||||
result = await user_service.get_by_email(db_session, "nobody@example.com")
|
result = await user_service.get_by_email("nobody@example.com")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
class TestGetByOAuth:
|
class TestGetByOAuth:
|
||||||
"""Tests for get_by_oauth method."""
|
"""Tests for get_by_oauth method."""
|
||||||
|
|
||||||
async def test_returns_user_when_found(self, db_session):
|
async def test_returns_user_when_found(self, db_session, user_service):
|
||||||
"""Test that get_by_oauth returns user for matching provider+id."""
|
"""Test that get_by_oauth returns user for matching provider+id."""
|
||||||
user = User(
|
user = User(
|
||||||
email="oauth@example.com",
|
email="oauth@example.com",
|
||||||
@ -91,12 +107,12 @@ class TestGetByOAuth:
|
|||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
result = await user_service.get_by_oauth(db_session, "google", "google-unique-id")
|
result = await user_service.get_by_oauth("google", "google-unique-id")
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.email == "oauth@example.com"
|
assert result.email == "oauth@example.com"
|
||||||
|
|
||||||
async def test_returns_none_for_wrong_provider(self, db_session):
|
async def test_returns_none_for_wrong_provider(self, db_session, user_service):
|
||||||
"""Test that get_by_oauth returns None if provider doesn't match."""
|
"""Test that get_by_oauth returns None if provider doesn't match."""
|
||||||
user = User(
|
user = User(
|
||||||
email="oauth2@example.com",
|
email="oauth2@example.com",
|
||||||
@ -108,30 +124,28 @@ class TestGetByOAuth:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Same ID, different provider
|
# Same ID, different provider
|
||||||
result = await user_service.get_by_oauth(db_session, "discord", "google-id-2")
|
result = await user_service.get_by_oauth("discord", "google-id-2")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
async def test_returns_none_when_not_found(self, db_session):
|
async def test_returns_none_when_not_found(self, user_service):
|
||||||
"""Test that get_by_oauth returns None for nonexistent OAuth."""
|
"""Test that get_by_oauth returns None for nonexistent OAuth."""
|
||||||
result = await user_service.get_by_oauth(db_session, "google", "nonexistent")
|
result = await user_service.get_by_oauth("google", "nonexistent")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
class TestCreate:
|
class TestCreate:
|
||||||
"""Tests for create method."""
|
"""Tests for create method."""
|
||||||
|
|
||||||
async def test_creates_user_with_all_fields(self, db_session):
|
async def test_creates_user_with_all_fields(self, user_service):
|
||||||
"""Test that create properly persists all user fields."""
|
"""Test that create properly persists all user fields."""
|
||||||
user_data = UserCreate(
|
result = await user_service.create(
|
||||||
email="new@example.com",
|
email="new@example.com",
|
||||||
display_name="New User",
|
display_name="New User",
|
||||||
avatar_url="https://example.com/avatar.jpg",
|
|
||||||
oauth_provider="discord",
|
oauth_provider="discord",
|
||||||
oauth_id="discord-new-id",
|
oauth_id="discord-new-id",
|
||||||
|
avatar_url="https://example.com/avatar.jpg",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await user_service.create(db_session, user_data)
|
|
||||||
|
|
||||||
assert result.id is not None
|
assert result.id is not None
|
||||||
assert result.email == "new@example.com"
|
assert result.email == "new@example.com"
|
||||||
assert result.display_name == "New User"
|
assert result.display_name == "New User"
|
||||||
@ -141,24 +155,22 @@ class TestCreate:
|
|||||||
assert result.is_premium is False
|
assert result.is_premium is False
|
||||||
assert result.premium_until is None
|
assert result.premium_until is None
|
||||||
|
|
||||||
async def test_creates_user_without_avatar(self, db_session):
|
async def test_creates_user_without_avatar(self, user_service):
|
||||||
"""Test that create works without optional avatar_url."""
|
"""Test that create works without optional avatar_url."""
|
||||||
user_data = UserCreate(
|
result = await user_service.create(
|
||||||
email="noavatar@example.com",
|
email="noavatar@example.com",
|
||||||
display_name="No Avatar",
|
display_name="No Avatar",
|
||||||
oauth_provider="google",
|
oauth_provider="google",
|
||||||
oauth_id="google-no-avatar",
|
oauth_id="google-no-avatar",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await user_service.create(db_session, user_data)
|
|
||||||
|
|
||||||
assert result.avatar_url is None
|
assert result.avatar_url is None
|
||||||
|
|
||||||
|
|
||||||
class TestCreateFromOAuth:
|
class TestCreateFromOAuth:
|
||||||
"""Tests for create_from_oauth method."""
|
"""Tests for create_from_oauth method."""
|
||||||
|
|
||||||
async def test_creates_user_from_oauth_info(self, db_session):
|
async def test_creates_user_from_oauth_info(self, user_service):
|
||||||
"""Test that create_from_oauth converts OAuthUserInfo to User."""
|
"""Test that create_from_oauth converts OAuthUserInfo to User."""
|
||||||
oauth_info = OAuthUserInfo(
|
oauth_info = OAuthUserInfo(
|
||||||
provider="google",
|
provider="google",
|
||||||
@ -168,7 +180,7 @@ class TestCreateFromOAuth:
|
|||||||
avatar_url="https://google.com/avatar.jpg",
|
avatar_url="https://google.com/avatar.jpg",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await user_service.create_from_oauth(db_session, oauth_info)
|
result = await user_service.create_from_oauth(oauth_info)
|
||||||
|
|
||||||
assert result.email == "oauthcreate@example.com"
|
assert result.email == "oauthcreate@example.com"
|
||||||
assert result.display_name == "OAuth Created User"
|
assert result.display_name == "OAuth Created User"
|
||||||
@ -179,7 +191,7 @@ class TestCreateFromOAuth:
|
|||||||
class TestGetOrCreateFromOAuth:
|
class TestGetOrCreateFromOAuth:
|
||||||
"""Tests for get_or_create_from_oauth method."""
|
"""Tests for get_or_create_from_oauth method."""
|
||||||
|
|
||||||
async def test_returns_existing_user_by_oauth(self, db_session):
|
async def test_returns_existing_user_by_oauth(self, db_session, user_service):
|
||||||
"""Test that existing user is returned when OAuth matches.
|
"""Test that existing user is returned when OAuth matches.
|
||||||
|
|
||||||
Verifies the method returns (user, False) for existing users.
|
Verifies the method returns (user, False) for existing users.
|
||||||
@ -202,12 +214,12 @@ class TestGetOrCreateFromOAuth:
|
|||||||
name="Existing",
|
name="Existing",
|
||||||
)
|
)
|
||||||
|
|
||||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
result, created = await user_service.get_or_create_from_oauth(oauth_info)
|
||||||
|
|
||||||
assert created is False
|
assert created is False
|
||||||
assert result.id == existing.id
|
assert str(result.id) == str(existing.id)
|
||||||
|
|
||||||
async def test_links_existing_user_by_email(self, db_session):
|
async def test_links_existing_user_by_email(self, db_session, user_service):
|
||||||
"""Test that OAuth is linked when email matches existing user.
|
"""Test that OAuth is linked when email matches existing user.
|
||||||
|
|
||||||
If a user exists with the same email but different OAuth,
|
If a user exists with the same email but different OAuth,
|
||||||
@ -232,15 +244,15 @@ class TestGetOrCreateFromOAuth:
|
|||||||
avatar_url="https://discord.com/avatar.jpg",
|
avatar_url="https://discord.com/avatar.jpg",
|
||||||
)
|
)
|
||||||
|
|
||||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
result, created = await user_service.get_or_create_from_oauth(oauth_info)
|
||||||
|
|
||||||
assert created is False
|
assert created is False
|
||||||
assert result.id == existing.id
|
assert str(result.id) == str(existing.id)
|
||||||
# OAuth should be updated to Discord
|
# OAuth should be updated to Discord
|
||||||
assert result.oauth_provider == "discord"
|
assert result.oauth_provider == "discord"
|
||||||
assert result.oauth_id == "discord-link-id"
|
assert result.oauth_id == "discord-link-id"
|
||||||
|
|
||||||
async def test_creates_new_user_when_not_found(self, db_session):
|
async def test_creates_new_user_when_not_found(self, user_service):
|
||||||
"""Test that new user is created when no match exists.
|
"""Test that new user is created when no match exists.
|
||||||
|
|
||||||
Verifies the method returns (user, True) for new users.
|
Verifies the method returns (user, True) for new users.
|
||||||
@ -252,7 +264,7 @@ class TestGetOrCreateFromOAuth:
|
|||||||
name="Brand New",
|
name="Brand New",
|
||||||
)
|
)
|
||||||
|
|
||||||
result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)
|
result, created = await user_service.get_or_create_from_oauth(oauth_info)
|
||||||
|
|
||||||
assert created is True
|
assert created is True
|
||||||
assert result.email == "brandnew@example.com"
|
assert result.email == "brandnew@example.com"
|
||||||
@ -261,7 +273,7 @@ class TestGetOrCreateFromOAuth:
|
|||||||
class TestUpdate:
|
class TestUpdate:
|
||||||
"""Tests for update method."""
|
"""Tests for update method."""
|
||||||
|
|
||||||
async def test_updates_display_name(self, db_session):
|
async def test_updates_display_name(self, db_session, user_service):
|
||||||
"""Test that update changes display_name when provided."""
|
"""Test that update changes display_name when provided."""
|
||||||
user = User(
|
user = User(
|
||||||
email="update@example.com",
|
email="update@example.com",
|
||||||
@ -272,12 +284,13 @@ class TestUpdate:
|
|||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
update_data = UserUpdate(display_name="New Name")
|
update_data = UserUpdate(display_name="New Name")
|
||||||
result = await user_service.update(db_session, user, update_data)
|
result = await user_service.update(user_id, update_data)
|
||||||
|
|
||||||
assert result.display_name == "New Name"
|
assert result.display_name == "New Name"
|
||||||
|
|
||||||
async def test_updates_avatar_url(self, db_session):
|
async def test_updates_avatar_url(self, db_session, user_service):
|
||||||
"""Test that update changes avatar_url when provided."""
|
"""Test that update changes avatar_url when provided."""
|
||||||
user = User(
|
user = User(
|
||||||
email="avatar@example.com",
|
email="avatar@example.com",
|
||||||
@ -288,12 +301,13 @@ class TestUpdate:
|
|||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
|
update_data = UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
|
||||||
result = await user_service.update(db_session, user, update_data)
|
result = await user_service.update(user_id, update_data)
|
||||||
|
|
||||||
assert result.avatar_url == "https://new-avatar.com/img.jpg"
|
assert result.avatar_url == "https://new-avatar.com/img.jpg"
|
||||||
|
|
||||||
async def test_ignores_none_values(self, db_session):
|
async def test_ignores_none_values(self, db_session, user_service):
|
||||||
"""Test that update doesn't change fields set to None.
|
"""Test that update doesn't change fields set to None.
|
||||||
|
|
||||||
Only explicitly provided fields should be updated.
|
Only explicitly provided fields should be updated.
|
||||||
@ -309,8 +323,9 @@ class TestUpdate:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Update only display_name, leave avatar alone
|
# Update only display_name, leave avatar alone
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
update_data = UserUpdate(display_name="Changed")
|
update_data = UserUpdate(display_name="Changed")
|
||||||
result = await user_service.update(db_session, user, update_data)
|
result = await user_service.update(user_id, update_data)
|
||||||
|
|
||||||
assert result.display_name == "Changed"
|
assert result.display_name == "Changed"
|
||||||
assert result.avatar_url == "https://keep.com/avatar.jpg"
|
assert result.avatar_url == "https://keep.com/avatar.jpg"
|
||||||
@ -319,7 +334,7 @@ class TestUpdate:
|
|||||||
class TestUpdateLastLogin:
|
class TestUpdateLastLogin:
|
||||||
"""Tests for update_last_login method."""
|
"""Tests for update_last_login method."""
|
||||||
|
|
||||||
async def test_updates_last_login_timestamp(self, db_session):
|
async def test_updates_last_login_timestamp(self, db_session, user_service):
|
||||||
"""Test that update_last_login sets current timestamp."""
|
"""Test that update_last_login sets current timestamp."""
|
||||||
user = User(
|
user = User(
|
||||||
email="login@example.com",
|
email="login@example.com",
|
||||||
@ -332,8 +347,9 @@ class TestUpdateLastLogin:
|
|||||||
|
|
||||||
assert user.last_login is None
|
assert user.last_login is None
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
before = datetime.now(UTC)
|
before = datetime.now(UTC)
|
||||||
result = await user_service.update_last_login(db_session, user)
|
result = await user_service.update_last_login(user_id)
|
||||||
after = datetime.now(UTC)
|
after = datetime.now(UTC)
|
||||||
|
|
||||||
assert result.last_login is not None
|
assert result.last_login is not None
|
||||||
@ -344,7 +360,7 @@ class TestUpdateLastLogin:
|
|||||||
class TestUpdatePremium:
|
class TestUpdatePremium:
|
||||||
"""Tests for update_premium method."""
|
"""Tests for update_premium method."""
|
||||||
|
|
||||||
async def test_grants_premium(self, db_session):
|
async def test_grants_premium(self, db_session, user_service):
|
||||||
"""Test that update_premium sets premium status and expiration."""
|
"""Test that update_premium sets premium status and expiration."""
|
||||||
user = User(
|
user = User(
|
||||||
email="premium@example.com",
|
email="premium@example.com",
|
||||||
@ -357,13 +373,14 @@ class TestUpdatePremium:
|
|||||||
|
|
||||||
assert user.is_premium is False
|
assert user.is_premium is False
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
expires = datetime.now(UTC) + timedelta(days=30)
|
expires = datetime.now(UTC) + timedelta(days=30)
|
||||||
result = await user_service.update_premium(db_session, user, expires)
|
result = await user_service.update_premium(user_id, expires)
|
||||||
|
|
||||||
assert result.is_premium is True
|
assert result.is_premium is True
|
||||||
assert result.premium_until == expires
|
assert result.premium_until == expires
|
||||||
|
|
||||||
async def test_removes_premium(self, db_session):
|
async def test_removes_premium(self, db_session, user_service):
|
||||||
"""Test that update_premium with None removes premium status."""
|
"""Test that update_premium with None removes premium status."""
|
||||||
user = User(
|
user = User(
|
||||||
email="unpremium@example.com",
|
email="unpremium@example.com",
|
||||||
@ -376,7 +393,8 @@ class TestUpdatePremium:
|
|||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
result = await user_service.update_premium(db_session, user, None)
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
result = await user_service.update_premium(user_id, None)
|
||||||
|
|
||||||
assert result.is_premium is False
|
assert result.is_premium is False
|
||||||
assert result.premium_until is None
|
assert result.premium_until is None
|
||||||
@ -385,7 +403,7 @@ class TestUpdatePremium:
|
|||||||
class TestDelete:
|
class TestDelete:
|
||||||
"""Tests for delete method."""
|
"""Tests for delete method."""
|
||||||
|
|
||||||
async def test_deletes_user(self, db_session):
|
async def test_deletes_user(self, db_session, user_service):
|
||||||
"""Test that delete removes user from database."""
|
"""Test that delete removes user from database."""
|
||||||
user = User(
|
user = User(
|
||||||
email="delete@example.com",
|
email="delete@example.com",
|
||||||
@ -396,22 +414,18 @@ class TestDelete:
|
|||||||
db_session.add(user)
|
db_session.add(user)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
await user_service.delete(db_session, user)
|
await user_service.delete(user_id)
|
||||||
|
|
||||||
# Verify user is gone
|
# Verify user is gone
|
||||||
from uuid import UUID
|
result = await user_service.get_by_id(user_id)
|
||||||
|
|
||||||
result = await user_service.get_by_id(
|
|
||||||
db_session, UUID(user_id) if isinstance(user_id, str) else user_id
|
|
||||||
)
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
class TestGetLinkedAccount:
|
class TestGetLinkedAccount:
|
||||||
"""Tests for get_linked_account method."""
|
"""Tests for get_linked_account method."""
|
||||||
|
|
||||||
async def test_returns_linked_account_when_found(self, db_session):
|
async def test_returns_linked_account_when_found(self, db_session, user_service):
|
||||||
"""Test that get_linked_account returns account when it exists.
|
"""Test that get_linked_account returns account when it exists.
|
||||||
|
|
||||||
Creates a user with a linked account and verifies it can be retrieved.
|
Creates a user with a linked account and verifies it can be retrieved.
|
||||||
@ -437,22 +451,22 @@ class TestGetLinkedAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Retrieve linked account
|
# Retrieve linked account
|
||||||
result = await user_service.get_linked_account(db_session, "discord", "discord-linked-123")
|
result = await user_service.get_linked_account("discord", "discord-linked-123")
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.provider == "discord"
|
assert result.provider == "discord"
|
||||||
assert result.oauth_id == "discord-linked-123"
|
assert result.oauth_id == "discord-linked-123"
|
||||||
|
|
||||||
async def test_returns_none_when_not_found(self, db_session):
|
async def test_returns_none_when_not_found(self, user_service):
|
||||||
"""Test that get_linked_account returns None for nonexistent accounts."""
|
"""Test that get_linked_account returns None for nonexistent accounts."""
|
||||||
result = await user_service.get_linked_account(db_session, "discord", "nonexistent-id")
|
result = await user_service.get_linked_account("discord", "nonexistent-id")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
class TestLinkOAuthAccount:
|
class TestLinkOAuthAccount:
|
||||||
"""Tests for link_oauth_account method."""
|
"""Tests for link_oauth_account method."""
|
||||||
|
|
||||||
async def test_links_new_provider(self, db_session):
|
async def test_links_new_provider(self, db_session, user_service):
|
||||||
"""Test that link_oauth_account successfully links a new provider.
|
"""Test that link_oauth_account successfully links a new provider.
|
||||||
|
|
||||||
Creates a Google user and links Discord to them.
|
Creates a Google user and links Discord to them.
|
||||||
@ -468,6 +482,8 @@ class TestLinkOAuthAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(user)
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
|
||||||
# Link Discord
|
# Link Discord
|
||||||
discord_info = OAuthUserInfo(
|
discord_info = OAuthUserInfo(
|
||||||
provider="discord",
|
provider="discord",
|
||||||
@ -477,16 +493,16 @@ class TestLinkOAuthAccount:
|
|||||||
avatar_url="https://discord.com/avatar.png",
|
avatar_url="https://discord.com/avatar.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await user_service.link_oauth_account(db_session, user, discord_info)
|
result = await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.provider == "discord"
|
assert result.provider == "discord"
|
||||||
assert result.oauth_id == "discord-456"
|
assert result.oauth_id == "discord-456"
|
||||||
assert result.email == "discord@example.com"
|
assert result.email == "discord@example.com"
|
||||||
assert result.display_name == "Discord Name"
|
assert result.display_name == "Discord Name"
|
||||||
assert str(result.user_id) == str(user.id)
|
assert str(result.user_id) == str(user_id)
|
||||||
|
|
||||||
async def test_raises_error_if_already_linked_to_same_user(self, db_session):
|
async def test_raises_error_if_already_linked_to_same_user(self, db_session, user_service):
|
||||||
"""Test that linking same provider twice raises error.
|
"""Test that linking same provider twice raises error.
|
||||||
|
|
||||||
A user cannot have the same provider linked multiple times.
|
A user cannot have the same provider linked multiple times.
|
||||||
@ -501,6 +517,8 @@ class TestLinkOAuthAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(user)
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
|
||||||
# Link Discord first time
|
# Link Discord first time
|
||||||
discord_info = OAuthUserInfo(
|
discord_info = OAuthUserInfo(
|
||||||
provider="discord",
|
provider="discord",
|
||||||
@ -508,16 +526,15 @@ class TestLinkOAuthAccount:
|
|||||||
email="first@discord.com",
|
email="first@discord.com",
|
||||||
name="First",
|
name="First",
|
||||||
)
|
)
|
||||||
await user_service.link_oauth_account(db_session, user, discord_info)
|
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
|
||||||
await db_session.refresh(user)
|
|
||||||
|
|
||||||
# Try to link same Discord account again
|
# Try to link same Discord account again
|
||||||
with pytest.raises(AccountLinkingError) as exc_info:
|
with pytest.raises(AccountLinkingError) as exc_info:
|
||||||
await user_service.link_oauth_account(db_session, user, discord_info)
|
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
|
||||||
|
|
||||||
assert "already linked to your account" in str(exc_info.value)
|
assert "already linked to your account" in str(exc_info.value)
|
||||||
|
|
||||||
async def test_raises_error_if_linked_to_another_user(self, db_session):
|
async def test_raises_error_if_linked_to_another_user(self, db_session, user_service):
|
||||||
"""Test that linking account already linked to another user raises error.
|
"""Test that linking account already linked to another user raises error.
|
||||||
|
|
||||||
The same OAuth provider+ID cannot be linked to multiple users.
|
The same OAuth provider+ID cannot be linked to multiple users.
|
||||||
@ -533,13 +550,15 @@ class TestLinkOAuthAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(user1)
|
await db_session.refresh(user1)
|
||||||
|
|
||||||
|
user1_id = UUID(user1.id) if isinstance(user1.id, str) else user1.id
|
||||||
|
|
||||||
discord_info = OAuthUserInfo(
|
discord_info = OAuthUserInfo(
|
||||||
provider="discord",
|
provider="discord",
|
||||||
oauth_id="shared-discord",
|
oauth_id="shared-discord",
|
||||||
email="shared@discord.com",
|
email="shared@discord.com",
|
||||||
name="Shared",
|
name="Shared",
|
||||||
)
|
)
|
||||||
await user_service.link_oauth_account(db_session, user1, discord_info)
|
await user_service.link_oauth_account(user1_id, user1.oauth_provider, discord_info)
|
||||||
|
|
||||||
# Create second user
|
# Create second user
|
||||||
user2 = User(
|
user2 = User(
|
||||||
@ -552,13 +571,15 @@ class TestLinkOAuthAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(user2)
|
await db_session.refresh(user2)
|
||||||
|
|
||||||
|
user2_id = UUID(user2.id) if isinstance(user2.id, str) else user2.id
|
||||||
|
|
||||||
# Try to link same Discord account to second user
|
# Try to link same Discord account to second user
|
||||||
with pytest.raises(AccountLinkingError) as exc_info:
|
with pytest.raises(AccountLinkingError) as exc_info:
|
||||||
await user_service.link_oauth_account(db_session, user2, discord_info)
|
await user_service.link_oauth_account(user2_id, user2.oauth_provider, discord_info)
|
||||||
|
|
||||||
assert "already linked to another user" in str(exc_info.value)
|
assert "already linked to another user" in str(exc_info.value)
|
||||||
|
|
||||||
async def test_raises_error_if_linking_primary_provider(self, db_session):
|
async def test_raises_error_if_linking_primary_provider(self, db_session, user_service):
|
||||||
"""Test that linking the same provider as primary raises error.
|
"""Test that linking the same provider as primary raises error.
|
||||||
|
|
||||||
User cannot link Google if they already signed up with Google.
|
User cannot link Google if they already signed up with Google.
|
||||||
@ -573,6 +594,8 @@ class TestLinkOAuthAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(user)
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
|
||||||
# Try to link another Google account
|
# Try to link another Google account
|
||||||
google_info = OAuthUserInfo(
|
google_info = OAuthUserInfo(
|
||||||
provider="google",
|
provider="google",
|
||||||
@ -582,7 +605,7 @@ class TestLinkOAuthAccount:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(AccountLinkingError) as exc_info:
|
with pytest.raises(AccountLinkingError) as exc_info:
|
||||||
await user_service.link_oauth_account(db_session, user, google_info)
|
await user_service.link_oauth_account(user_id, user.oauth_provider, google_info)
|
||||||
|
|
||||||
assert "primary login provider" in str(exc_info.value)
|
assert "primary login provider" in str(exc_info.value)
|
||||||
|
|
||||||
@ -590,7 +613,7 @@ class TestLinkOAuthAccount:
|
|||||||
class TestUnlinkOAuthAccount:
|
class TestUnlinkOAuthAccount:
|
||||||
"""Tests for unlink_oauth_account method."""
|
"""Tests for unlink_oauth_account method."""
|
||||||
|
|
||||||
async def test_unlinks_linked_account(self, db_session):
|
async def test_unlinks_linked_account(self, db_session, user_service):
|
||||||
"""Test that unlink_oauth_account removes a linked account.
|
"""Test that unlink_oauth_account removes a linked account.
|
||||||
|
|
||||||
Links Discord then unlinks it successfully.
|
Links Discord then unlinks it successfully.
|
||||||
@ -605,6 +628,8 @@ class TestUnlinkOAuthAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(user)
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
|
||||||
# Link Discord
|
# Link Discord
|
||||||
discord_info = OAuthUserInfo(
|
discord_info = OAuthUserInfo(
|
||||||
provider="discord",
|
provider="discord",
|
||||||
@ -612,22 +637,18 @@ class TestUnlinkOAuthAccount:
|
|||||||
email="discord@unlink.com",
|
email="discord@unlink.com",
|
||||||
name="Discord Unlink",
|
name="Discord Unlink",
|
||||||
)
|
)
|
||||||
await user_service.link_oauth_account(db_session, user, discord_info)
|
await user_service.link_oauth_account(user_id, user.oauth_provider, discord_info)
|
||||||
await db_session.refresh(user)
|
|
||||||
|
|
||||||
# Verify linked
|
|
||||||
assert len(user.linked_accounts) == 1
|
|
||||||
|
|
||||||
# Unlink
|
# Unlink
|
||||||
result = await user_service.unlink_oauth_account(db_session, user, "discord")
|
result = await user_service.unlink_oauth_account(user_id, user.oauth_provider, "discord")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
# Verify unlinked
|
# Verify unlinked
|
||||||
linked = await user_service.get_linked_account(db_session, "discord", "discord-unlink")
|
linked = await user_service.get_linked_account("discord", "discord-unlink")
|
||||||
assert linked is None
|
assert linked is None
|
||||||
|
|
||||||
async def test_returns_false_if_not_linked(self, db_session):
|
async def test_returns_false_if_not_linked(self, db_session, user_service):
|
||||||
"""Test that unlink returns False if provider isn't linked."""
|
"""Test that unlink returns False if provider isn't linked."""
|
||||||
user = User(
|
user = User(
|
||||||
email="not-linked@example.com",
|
email="not-linked@example.com",
|
||||||
@ -639,11 +660,12 @@ class TestUnlinkOAuthAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(user)
|
await db_session.refresh(user)
|
||||||
|
|
||||||
result = await user_service.unlink_oauth_account(db_session, user, "discord")
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
result = await user_service.unlink_oauth_account(user_id, user.oauth_provider, "discord")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
async def test_raises_error_if_unlinking_primary(self, db_session):
|
async def test_raises_error_if_unlinking_primary(self, db_session, user_service):
|
||||||
"""Test that unlinking primary provider raises error.
|
"""Test that unlinking primary provider raises error.
|
||||||
|
|
||||||
User cannot unlink their primary OAuth provider.
|
User cannot unlink their primary OAuth provider.
|
||||||
@ -658,7 +680,9 @@ class TestUnlinkOAuthAccount:
|
|||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(user)
|
await db_session.refresh(user)
|
||||||
|
|
||||||
|
user_id = UUID(user.id) if isinstance(user.id, str) else user.id
|
||||||
|
|
||||||
with pytest.raises(AccountLinkingError) as exc_info:
|
with pytest.raises(AccountLinkingError) as exc_info:
|
||||||
await user_service.unlink_oauth_account(db_session, user, "google")
|
await user_service.unlink_oauth_account(user_id, user.oauth_provider, "google")
|
||||||
|
|
||||||
assert "primary login provider" in str(exc_info.value)
|
assert "primary login provider" in str(exc_info.value)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user