- Add base_url config setting for OAuth callback URLs
- Change OAuth callbacks from relative to absolute URLs
- Add account linking OAuth flow (GET /auth/link/{provider})
- Add unlink endpoint (DELETE /users/me/link/{provider})
- Add AccountLinkingError and service methods for linking
- Add 14 new tests for linking functionality
- Update Phase 2 plan to mark complete (1072 tests passing)
440 lines
13 KiB
Python
440 lines
13 KiB
Python
"""User service for Mantimon TCG.
|
|
|
|
This module provides async CRUD operations for user accounts,
|
|
including OAuth-based user creation and premium status management.
|
|
|
|
All database operations use async SQLAlchemy sessions.
|
|
|
|
Example:
|
|
from app.services.user_service import user_service
|
|
|
|
# Get user by ID
|
|
user = await user_service.get_by_id(db, user_id)
|
|
|
|
# Create from OAuth
|
|
user = await user_service.create_from_oauth(db, oauth_info)
|
|
|
|
# Update premium status
|
|
user = await user_service.update_premium(db, user, premium_until)
|
|
"""
|
|
|
|
from datetime import UTC, datetime
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models.oauth_account import OAuthLinkedAccount
|
|
from app.db.models.user import User
|
|
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
|
|
|
|
|
class AccountLinkingError(Exception):
    """Raised when linking or unlinking an OAuth provider is not allowed.

    The exception message is a human-readable explanation (for example,
    that the provider account already belongs to another user, or that the
    provider is the user's primary login and cannot be unlinked).
    """
|
|
|
|
|
|
class UserService:
    """Service for user account operations.

    Provides async methods for user CRUD, OAuth-based creation and login,
    linking/unlinking of secondary OAuth providers, and premium subscription
    management. Every method that mutates data commits the session it is given.
    """

    async def get_by_id(self, db: AsyncSession, user_id: UUID) -> User | None:
        """Get a user by their ID.

        Args:
            db: Async database session.
            user_id: The user's UUID.

        Returns:
            User if found, None otherwise.

        Example:
            user = await user_service.get_by_id(db, user_id)
            if user:
                print(f"Found user: {user.display_name}")
        """
        result = await db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
        """Get a user by their email address.

        The comparison is an exact (case-sensitive) match on ``User.email``.

        Args:
            db: Async database session.
            email: The user's email address.

        Returns:
            User if found, None otherwise.

        Example:
            user = await user_service.get_by_email(db, "player@example.com")
        """
        result = await db.execute(select(User).where(User.email == email))
        return result.scalar_one_or_none()

    async def get_by_oauth(
        self,
        db: AsyncSession,
        provider: str,
        oauth_id: str,
    ) -> User | None:
        """Get a user by their *primary* OAuth provider and ID.

        Only the credentials stored on the user row (``User.oauth_provider`` /
        ``User.oauth_id``) are matched. Secondary providers linked through
        ``OAuthLinkedAccount`` are found with ``get_linked_account`` instead.

        Args:
            db: Async database session.
            provider: OAuth provider name (google, discord).
            oauth_id: Unique ID from the OAuth provider.

        Returns:
            User if found, None otherwise.

        Example:
            user = await user_service.get_by_oauth(db, "google", "123456789")
        """
        result = await db.execute(
            select(User).where(
                User.oauth_provider == provider,
                User.oauth_id == oauth_id,
            )
        )
        return result.scalar_one_or_none()

    async def create(self, db: AsyncSession, user_data: UserCreate) -> User:
        """Create a new user and commit it.

        Args:
            db: Async database session.
            user_data: User creation data.

        Returns:
            The created User instance, refreshed from the database.

        Example:
            user_data = UserCreate(
                email="player@example.com",
                display_name="Player1",
                oauth_provider="google",
                oauth_id="123456789"
            )
            user = await user_service.create(db, user_data)
        """
        user = User(
            email=user_data.email,
            display_name=user_data.display_name,
            avatar_url=user_data.avatar_url,
            oauth_provider=user_data.oauth_provider,
            oauth_id=user_data.oauth_id,
        )
        db.add(user)
        await db.commit()
        await db.refresh(user)
        return user

    async def create_from_oauth(
        self,
        db: AsyncSession,
        oauth_info: OAuthUserInfo,
    ) -> User:
        """Create a new user from OAuth provider info.

        Convenience method that converts OAuthUserInfo to UserCreate and
        delegates to ``create``.

        Args:
            db: Async database session.
            oauth_info: Normalized OAuth user information.

        Returns:
            The created User instance.

        Example:
            oauth_info = OAuthUserInfo(
                provider="google",
                oauth_id="123456789",
                email="player@example.com",
                name="Player One",
                avatar_url="https://..."
            )
            user = await user_service.create_from_oauth(db, oauth_info)
        """
        user_data = oauth_info.to_user_create()
        return await self.create(db, user_data)

    async def get_or_create_from_oauth(
        self,
        db: AsyncSession,
        oauth_info: OAuthUserInfo,
    ) -> tuple[User, bool]:
        """Get existing user or create new one from OAuth info.

        Lookup order:
            1. Primary OAuth credentials on the user row (provider + ID).
            2. Secondary accounts linked via ``link_oauth_account``.
            3. Email match. The matched user's primary OAuth credentials are
               replaced with the incoming provider's.
            4. Otherwise a new user is created.

        Args:
            db: Async database session.
            oauth_info: Normalized OAuth user information.

        Returns:
            Tuple of (User, created) where created is True if new user.

        Example:
            user, created = await user_service.get_or_create_from_oauth(db, oauth_info)
            if created:
                print("Welcome, new user!")
            else:
                print("Welcome back!")
        """
        # First, check by OAuth provider + ID (exact match on primary login)
        user = await self.get_by_oauth(db, oauth_info.provider, oauth_info.oauth_id)
        if user:
            return user, False

        # Then check secondary linked accounts. Without this, logging in with a
        # linked provider would either create a duplicate user or (on email
        # match below) overwrite the user's primary provider.
        linked = await self.get_linked_account(
            db, oauth_info.provider, oauth_info.oauth_id
        )
        if linked is not None:
            # str() round-trip accepts user_id stored as either UUID or string
            # (link_oauth_account writes it as str(user.id)).
            owner = await self.get_by_id(db, UUID(str(linked.user_id)))
            if owner is not None:
                return owner, False

        # Check by email for potential account linking.
        # If a user exists with the same email but different OAuth, update their OAuth.
        # NOTE(review): this replaces the previous primary provider, and it trusts
        # the provider-supplied email. Confirm that providers only return
        # verified emails.
        user = await self.get_by_email(db, oauth_info.email)
        if user:
            user.oauth_provider = oauth_info.provider
            user.oauth_id = oauth_info.oauth_id
            # Only fill in the avatar if the user does not already have one
            if not user.avatar_url and oauth_info.avatar_url:
                user.avatar_url = oauth_info.avatar_url
            await db.commit()
            await db.refresh(user)
            return user, False

        # Create new user
        user = await self.create_from_oauth(db, oauth_info)
        return user, True

    async def update(
        self,
        db: AsyncSession,
        user: User,
        update_data: UserUpdate,
    ) -> User:
        """Update user profile fields.

        Only updates fields that are provided (not None). Commits even when
        nothing changed.

        Args:
            db: Async database session.
            user: The user to update.
            update_data: Fields to update.

        Returns:
            The updated User instance.

        Example:
            update_data = UserUpdate(display_name="New Name")
            user = await user_service.update(db, user, update_data)
        """
        if update_data.display_name is not None:
            user.display_name = update_data.display_name
        if update_data.avatar_url is not None:
            user.avatar_url = update_data.avatar_url

        await db.commit()
        await db.refresh(user)
        return user

    async def update_last_login(self, db: AsyncSession, user: User) -> User:
        """Set the user's last login timestamp to the current UTC time.

        Args:
            db: Async database session.
            user: The user to update.

        Returns:
            The updated User instance.

        Example:
            user = await user_service.update_last_login(db, user)
        """
        user.last_login = datetime.now(UTC)
        await db.commit()
        await db.refresh(user)
        return user

    async def update_premium(
        self,
        db: AsyncSession,
        user: User,
        premium_until: datetime | None,
    ) -> User:
        """Update user's premium subscription status.

        Args:
            db: Async database session.
            user: The user to update.
            premium_until: When premium expires, or None to remove premium.

        Returns:
            The updated User instance.

        Example:
            # Grant 30 days of premium
            expires = datetime.now(UTC) + timedelta(days=30)
            user = await user_service.update_premium(db, user, expires)

            # Remove premium
            user = await user_service.update_premium(db, user, None)
        """
        # is_premium mirrors whether an expiry was supplied; None clears both.
        user.is_premium = premium_until is not None
        user.premium_until = premium_until

        await db.commit()
        await db.refresh(user)
        return user

    async def delete(self, db: AsyncSession, user: User) -> None:
        """Delete a user account and commit.

        Related data (decks, collection, etc.) is removed according to the
        cascade rules configured on the model relationships.

        Args:
            db: Async database session.
            user: The user to delete.

        Example:
            await user_service.delete(db, user)
        """
        await db.delete(user)
        await db.commit()

    async def get_linked_account(
        self,
        db: AsyncSession,
        provider: str,
        oauth_id: str,
    ) -> OAuthLinkedAccount | None:
        """Get a linked (secondary) account by provider and OAuth ID.

        Args:
            db: Async database session.
            provider: OAuth provider name (google, discord).
            oauth_id: Unique ID from the OAuth provider.

        Returns:
            OAuthLinkedAccount if found, None otherwise.
        """
        result = await db.execute(
            select(OAuthLinkedAccount).where(
                OAuthLinkedAccount.provider == provider,
                OAuthLinkedAccount.oauth_id == oauth_id,
            )
        )
        return result.scalar_one_or_none()

    async def link_oauth_account(
        self,
        db: AsyncSession,
        user: User,
        oauth_info: OAuthUserInfo,
    ) -> OAuthLinkedAccount:
        """Link an additional OAuth provider to a user account.

        Args:
            db: Async database session.
            user: The user to link the account to.
            oauth_info: OAuth information from the provider.

        Returns:
            The created OAuthLinkedAccount.

        Raises:
            AccountLinkingError: If this provider identity is already linked to
                this or another user (as a linked account or as another
                user's primary login), or if the user already has this
                provider as primary or linked.

        Example:
            linked = await user_service.link_oauth_account(db, user, discord_info)
        """
        provider_label = oauth_info.provider.title()

        # Check if this provider+oauth_id is already linked to any user
        existing = await self.get_linked_account(
            db, oauth_info.provider, oauth_info.oauth_id
        )
        if existing:
            if str(existing.user_id) == str(user.id):
                raise AccountLinkingError(
                    f"{oauth_info.provider.title()} account is already linked to your account"
                )
            raise AccountLinkingError(
                f"This {oauth_info.provider.title()} account is already linked to another user"
            )

        # Check if this provider identity is another user's *primary* login.
        # Otherwise one identity could end up attached to two users, and login
        # would always resolve to the primary owner.
        primary_owner = await self.get_by_oauth(
            db, oauth_info.provider, oauth_info.oauth_id
        )
        if primary_owner is not None and str(primary_owner.id) != str(user.id):
            raise AccountLinkingError(
                f"This {provider_label} account is already linked to another user"
            )

        # Check if this is the user's primary OAuth provider
        if user.oauth_provider == oauth_info.provider:
            raise AccountLinkingError(
                f"{provider_label} is your primary login provider"
            )

        # Check if user already has this provider linked (any identity).
        # NOTE(review): reading user.linked_accounts under AsyncSession requires the
        # relationship to be eagerly loaded. Confirm the model's lazy strategy.
        if any(acct.provider == oauth_info.provider for acct in user.linked_accounts):
            raise AccountLinkingError(
                f"You already have a {provider_label} account linked"
            )

        # Create the linked account
        linked_account = OAuthLinkedAccount(
            user_id=str(user.id),
            provider=oauth_info.provider,
            oauth_id=oauth_info.oauth_id,
            email=oauth_info.email,
            display_name=oauth_info.name,
            avatar_url=oauth_info.avatar_url,
        )
        db.add(linked_account)
        await db.commit()
        await db.refresh(linked_account)
        return linked_account

    async def unlink_oauth_account(
        self,
        db: AsyncSession,
        user: User,
        provider: str,
    ) -> bool:
        """Unlink an OAuth provider from a user account.

        Cannot unlink the primary OAuth provider.

        Args:
            db: Async database session.
            user: The user to unlink from.
            provider: OAuth provider name to unlink.

        Returns:
            True if unlinked, False if provider wasn't linked.

        Raises:
            AccountLinkingError: If trying to unlink the primary provider.

        Example:
            success = await user_service.unlink_oauth_account(db, user, "discord")
        """
        # Cannot unlink primary provider
        if user.oauth_provider == provider:
            raise AccountLinkingError(
                f"Cannot unlink {provider.title()} - it is your primary login provider"
            )

        # Find the (first) linked account for this provider, if any
        target = next(
            (acct for acct in user.linked_accounts if acct.provider == provider),
            None,
        )
        if target is None:
            return False

        await db.delete(target)
        await db.commit()
        return True
|
|
|
|
|
|
# Global service instance.
# Callers import this shared instance rather than constructing their own,
# e.g. `from app.services.user_service import user_service`.
user_service = UserService()
|