- Add UserRepository and LinkedAccountRepository protocols to protocols.py
- Add UserEntry and LinkedAccountEntry DTOs for service-layer decoupling
- Implement PostgresUserRepository and PostgresLinkedAccountRepository
- Refactor UserService to use constructor-injected repositories
- Add get_user_service factory and UserServiceDep to API deps
- Update auth.py and users.py endpoints to use UserServiceDep
- Rewrite tests to use FastAPI dependency overrides (no monkey patching)

This follows the established repository pattern used by DeckService and CollectionService, enabling future offline fork support.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
150 lines · 4.4 KiB · Python
"""PostgreSQL implementation of LinkedAccountRepository.

This module provides the PostgreSQL-specific implementation of the
LinkedAccountRepository protocol using SQLAlchemy async sessions.
ORM rows are converted to ``LinkedAccountEntry`` DTOs before being
returned, so callers never handle ``OAuthLinkedAccount`` models directly.
The write operations (``create`` and ``delete``) commit the session
they are given.

Example:
    async with get_db_session() as db:
        repo = PostgresLinkedAccountRepository(db)
        accounts = await repo.get_by_user(user_id)
"""
|
|
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models.oauth_account import OAuthLinkedAccount
|
|
from app.repositories.protocols import LinkedAccountEntry
|
|
|
|
|
|
def _to_dto(model: OAuthLinkedAccount) -> LinkedAccountEntry:
    """Map an ``OAuthLinkedAccount`` ORM row onto a ``LinkedAccountEntry`` DTO.

    The ``user_id`` attribute is normalised to a ``UUID``: a string value is
    parsed with ``UUID(...)``, any other value is passed through unchanged.
    All remaining fields are copied verbatim.

    Args:
        model: The ORM instance to convert.

    Returns:
        A ``LinkedAccountEntry`` carrying the same field values.
    """
    owner_id = model.user_id
    if isinstance(owner_id, str):
        # The column may hand back the id as text; the DTO expects a UUID.
        owner_id = UUID(owner_id)

    return LinkedAccountEntry(
        id=model.id,
        user_id=owner_id,
        provider=model.provider,
        oauth_id=model.oauth_id,
        email=model.email,
        display_name=model.display_name,
        avatar_url=model.avatar_url,
        linked_at=model.linked_at,
    )
|
|
|
|
|
|
class PostgresLinkedAccountRepository:
    """PostgreSQL implementation of LinkedAccountRepository.

    Uses SQLAlchemy async sessions for database access. Query results are
    converted to ``LinkedAccountEntry`` DTOs via ``_to_dto``. The write
    methods (``create`` and ``delete``) call ``commit()`` on the injected
    session, so their changes are persisted before they return.

    Attributes:
        _db: The async database session.
    """

    def __init__(self, db: AsyncSession) -> None:
        """Initialize with database session.

        Args:
            db: SQLAlchemy async session.
        """
        self._db = db

    async def get_by_provider(
        self,
        provider: str,
        oauth_id: str,
    ) -> LinkedAccountEntry | None:
        """Get a linked account by provider and OAuth ID.

        Args:
            provider: OAuth provider name (google, discord).
            oauth_id: Unique ID from the OAuth provider.

        Returns:
            LinkedAccountEntry if found, None otherwise.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
                NOTE(review): presumably prevented by a unique constraint on
                (provider, oauth_id) — confirm in the model/migrations.
        """
        result = await self._db.execute(
            select(OAuthLinkedAccount).where(
                OAuthLinkedAccount.provider == provider,
                OAuthLinkedAccount.oauth_id == oauth_id,
            )
        )
        model = result.scalar_one_or_none()
        # Explicit None check rather than truthiness: "row not found" is
        # signalled by None, and an ORM instance's truthiness is not a
        # reliable proxy for that. Matches the check used in ``delete``.
        if model is None:
            return None
        return _to_dto(model)

    async def get_by_user(self, user_id: UUID) -> list[LinkedAccountEntry]:
        """Get all linked accounts for a user.

        Args:
            user_id: The user's UUID.

        Returns:
            List of LinkedAccountEntry, ordered by provider. Empty if the
            user has no linked accounts.
        """
        result = await self._db.execute(
            select(OAuthLinkedAccount)
            # user_id is compared in string form, the same form ``create``
            # writes it in.
            .where(OAuthLinkedAccount.user_id == str(user_id))
            .order_by(OAuthLinkedAccount.provider)
        )
        return [_to_dto(model) for model in result.scalars().all()]

    async def create(
        self,
        user_id: UUID,
        provider: str,
        oauth_id: str,
        email: str | None = None,
        display_name: str | None = None,
        avatar_url: str | None = None,
    ) -> LinkedAccountEntry:
        """Link an OAuth provider to a user account.

        The new row is added and committed immediately.

        Args:
            user_id: The user's UUID.
            provider: OAuth provider name.
            oauth_id: Unique ID from the OAuth provider.
            email: Email from the OAuth provider.
            display_name: Display name from the OAuth provider.
            avatar_url: Avatar URL from the OAuth provider.

        Returns:
            The created LinkedAccountEntry.
        """
        linked_account = OAuthLinkedAccount(
            user_id=str(user_id),
            provider=provider,
            oauth_id=oauth_id,
            email=email,
            display_name=display_name,
            avatar_url=avatar_url,
        )
        self._db.add(linked_account)
        await self._db.commit()
        # ``id`` and ``linked_at`` are not set above but are read by
        # ``_to_dto``; refresh reloads the row so they are populated
        # (presumably defaulted by the model/DB — confirm).
        await self._db.refresh(linked_account)
        return _to_dto(linked_account)

    async def delete(self, user_id: UUID, provider: str) -> bool:
        """Unlink an OAuth provider from a user account.

        The matching row is loaded, deleted through the ORM session, and the
        deletion is committed immediately.

        Args:
            user_id: The user's UUID.
            provider: OAuth provider name to unlink.

        Returns:
            True if deleted, False if not found.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If the user has more than
                one account for ``provider``. NOTE(review): presumably
                prevented by a uniqueness rule on (user_id, provider) —
                confirm.
        """
        result = await self._db.execute(
            select(OAuthLinkedAccount).where(
                OAuthLinkedAccount.user_id == str(user_id),
                OAuthLinkedAccount.provider == provider,
            )
        )
        linked_account = result.scalar_one_or_none()

        if linked_account is None:
            return False

        await self._db.delete(linked_account)
        await self._db.commit()
        return True
|