- Add UserRepository and LinkedAccountRepository protocols to protocols.py - Add UserEntry and LinkedAccountEntry DTOs for service layer decoupling - Implement PostgresUserRepository and PostgresLinkedAccountRepository - Refactor UserService to use constructor-injected repositories - Add get_user_service factory and UserServiceDep to API deps - Update auth.py and users.py endpoints to use UserServiceDep - Rewrite tests to use FastAPI dependency overrides (no monkey patching) This follows the established repository pattern used by DeckService and CollectionService, enabling future offline fork support. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
244 lines
7.0 KiB
Python
244 lines
7.0 KiB
Python
"""PostgreSQL implementation of UserRepository.
|
|
|
|
This module provides the PostgreSQL-specific implementation of the
|
|
UserRepository protocol using SQLAlchemy async sessions.
|
|
|
|
Example:
|
|
async with get_db_session() as db:
|
|
repo = PostgresUserRepository(db)
|
|
user = await repo.get_by_id(user_id)
|
|
"""
|
|
|
|
from datetime import UTC, datetime
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models.user import User
|
|
from app.repositories.protocols import UNSET, UserEntry
|
|
|
|
|
|
def _to_dto(model: User) -> UserEntry:
    """Build a ``UserEntry`` DTO from a ``User`` ORM instance.

    Every DTO field is copied from the same-named attribute on the model,
    so callers of the repository receive a plain DTO rather than the ORM
    object itself.

    Args:
        model: The ORM row to convert.

    Returns:
        A ``UserEntry`` carrying the model's column values.
    """
    # Keep this tuple in sync with the fields declared on UserEntry.
    copied_fields = (
        "id",
        "email",
        "display_name",
        "avatar_url",
        "oauth_provider",
        "oauth_id",
        "is_premium",
        "premium_until",
        "last_login",
        "created_at",
        "updated_at",
    )
    return UserEntry(**{name: getattr(model, name) for name in copied_fields})
|
|
|
|
|
|
class PostgresUserRepository:
    """PostgreSQL implementation of UserRepository.

    Uses SQLAlchemy async sessions for database access. Every mutating
    method commits the session it was given before returning; if a commit
    fails, the session is rolled back and the original exception re-raised.

    Attributes:
        _db: The async database session.
    """

    def __init__(self, db: AsyncSession) -> None:
        """Initialize with database session.

        Args:
            db: SQLAlchemy async session.
        """
        self._db = db

    async def _get_model(self, user_id: UUID) -> User | None:
        """Load the ``User`` ORM row with the given ID, or None if absent."""
        result = await self._db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def _commit(self) -> None:
        """Commit the session, rolling it back if the commit fails.

        A failed commit (e.g. a unique-constraint violation) otherwise leaves
        the session rejecting further statements until someone explicitly
        rolls it back, so clean up here before propagating the error.
        """
        try:
            await self._db.commit()
        except Exception:
            await self._db.rollback()
            raise

    async def _commit_and_load(self, user: User) -> UserEntry:
        """Commit pending changes, reload ``user`` from the DB, and return its DTO.

        The refresh picks up values populated by the database on write
        (e.g. generated IDs or timestamps) before conversion.
        """
        await self._commit()
        await self._db.refresh(user)
        return _to_dto(user)

    async def get_by_id(self, user_id: UUID) -> UserEntry | None:
        """Get a user by their ID.

        Args:
            user_id: The user's UUID.

        Returns:
            UserEntry if found, None otherwise.
        """
        model = await self._get_model(user_id)
        return _to_dto(model) if model is not None else None

    async def get_by_email(self, email: str) -> UserEntry | None:
        """Get a user by their email address.

        The comparison is a plain SQL equality on the stored value; no
        normalization (e.g. case folding) is applied here.

        Args:
            email: The user's email address.

        Returns:
            UserEntry if found, None otherwise.
        """
        result = await self._db.execute(select(User).where(User.email == email))
        model = result.scalar_one_or_none()
        return _to_dto(model) if model is not None else None

    async def get_by_oauth(self, provider: str, oauth_id: str) -> UserEntry | None:
        """Get a user by their OAuth provider and ID.

        Args:
            provider: OAuth provider name (google, discord).
            oauth_id: Unique ID from the OAuth provider.

        Returns:
            UserEntry if found, None otherwise.
        """
        result = await self._db.execute(
            select(User).where(
                User.oauth_provider == provider,
                User.oauth_id == oauth_id,
            )
        )
        model = result.scalar_one_or_none()
        return _to_dto(model) if model is not None else None

    async def create(
        self,
        email: str,
        display_name: str,
        oauth_provider: str,
        oauth_id: str,
        avatar_url: str | None = None,
    ) -> UserEntry:
        """Create a new user and commit it.

        Args:
            email: User's email address.
            display_name: Public display name.
            oauth_provider: OAuth provider name.
            oauth_id: Unique ID from the OAuth provider.
            avatar_url: Optional avatar URL.

        Returns:
            The created UserEntry, refreshed from the database.

        Raises:
            Any exception raised by the commit (e.g. a database integrity
            error) propagates after the session has been rolled back.
        """
        user = User(
            email=email,
            display_name=display_name,
            oauth_provider=oauth_provider,
            oauth_id=oauth_id,
            avatar_url=avatar_url,
        )
        self._db.add(user)
        return await self._commit_and_load(user)

    async def update(
        self,
        user_id: UUID,
        display_name: str | None = None,
        avatar_url: str | None = UNSET,  # type: ignore[assignment]
        oauth_provider: str | None = None,
        oauth_id: str | None = None,
    ) -> UserEntry | None:
        """Update user profile fields.

        ``display_name``, ``oauth_provider`` and ``oauth_id`` are written only
        when not None. ``avatar_url`` is written whenever it is not UNSET, so
        passing None explicitly clears it. The session is committed even if
        no field was changed.

        Args:
            user_id: The user's UUID.
            display_name: New display name (None keeps existing).
            avatar_url: New avatar URL (UNSET=keep, None=clear, str=set).
            oauth_provider: New OAuth provider (None keeps existing).
            oauth_id: New OAuth ID (None keeps existing).

        Returns:
            Updated UserEntry, or None if user not found.
        """
        user = await self._get_model(user_id)
        if user is None:
            return None

        if display_name is not None:
            user.display_name = display_name
        # UNSET (not None) is the "leave unchanged" marker here, because
        # None is a meaningful value: it clears the avatar.
        if avatar_url is not UNSET:
            user.avatar_url = avatar_url
        if oauth_provider is not None:
            user.oauth_provider = oauth_provider
        if oauth_id is not None:
            user.oauth_id = oauth_id

        return await self._commit_and_load(user)

    async def update_last_login(self, user_id: UUID) -> UserEntry | None:
        """Set the user's last login timestamp to the current UTC time.

        Args:
            user_id: The user's UUID.

        Returns:
            Updated UserEntry, or None if user not found.
        """
        user = await self._get_model(user_id)
        if user is None:
            return None

        user.last_login = datetime.now(UTC)
        return await self._commit_and_load(user)

    async def update_premium(
        self,
        user_id: UUID,
        is_premium: bool,
        premium_until: datetime | None,
    ) -> UserEntry | None:
        """Update user's premium subscription status.

        Both fields are always overwritten with the given values.

        Args:
            user_id: The user's UUID.
            is_premium: Whether user has premium.
            premium_until: When premium expires, or None if not premium.

        Returns:
            Updated UserEntry, or None if user not found.
        """
        user = await self._get_model(user_id)
        if user is None:
            return None

        user.is_premium = is_premium
        user.premium_until = premium_until
        return await self._commit_and_load(user)

    async def delete(self, user_id: UUID) -> bool:
        """Delete a user account and commit.

        Removal of related data (decks, collection, etc.) depends on the
        FK / ORM relationship cascade configuration, which is not defined
        here. NOTE(review): confirm those cascades are set up as intended.

        Args:
            user_id: The user's UUID.

        Returns:
            True if deleted, False if not found.
        """
        user = await self._get_model(user_id)
        if user is None:
            return False

        await self._db.delete(user)
        await self._commit()
        return True
|