- Add UserRepository and LinkedAccountRepository protocols to protocols.py
- Add UserEntry and LinkedAccountEntry DTOs for service layer decoupling
- Implement PostgresUserRepository and PostgresLinkedAccountRepository
- Refactor UserService to use constructor-injected repositories
- Add get_user_service factory and UserServiceDep to API deps
- Update auth.py and users.py endpoints to use UserServiceDep
- Rewrite tests to use FastAPI dependency overrides (no monkey patching)

This follows the established repository pattern used by DeckService and CollectionService, enabling future offline fork support.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
385 lines
14 KiB
Python
385 lines
14 KiB
Python
"""Tests for users API endpoints.
|
|
|
|
Tests the user profile management endpoints.
|
|
|
|
Uses FastAPI's dependency override pattern for proper dependency injection testing.
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from uuid import UUID
|
|
|
|
import pytest
|
|
from fastapi import status
|
|
from fastapi.testclient import TestClient
|
|
|
|
from app.api.deps import get_user_service
|
|
from app.services.user_service import AccountLinkingError
|
|
|
|
|
|
@pytest.fixture
def mock_user_service_instance():
    """Build a stand-in UserService suitable for dependency injection.

    Returns:
        A ``MagicMock`` whose ``get_by_id``, ``get_by_email``,
        ``get_by_oauth``, ``update`` and ``unlink_oauth_account`` attributes
        are ``AsyncMock`` instances, so each of them can be awaited.
    """
    async_method_names = (
        "get_by_id",
        "get_by_email",
        "get_by_oauth",
        "update",
        "unlink_oauth_account",
    )
    service = MagicMock()
    for method_name in async_method_names:
        # Each method gets its own AsyncMock so return values stay independent.
        setattr(service, method_name, AsyncMock())
    return service
|
|
|
|
|
|
class TestGetCurrentUser:
    """Tests for GET /api/users/me endpoint."""

    def test_returns_user_profile(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """An authenticated request receives the caller's profile.

        The response body is expected to mirror the test user's email,
        display name, avatar URL and premium flag.
        """
        # Any query executed on the mocked session resolves to the test user.
        query_result = MagicMock(**{"scalar_one_or_none.return_value": test_user})
        mock_db_session.execute = AsyncMock(return_value=query_result)

        auth_headers = {"Authorization": f"Bearer {access_token}"}
        response = client.get("/api/users/me", headers=auth_headers)

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        for field in ("email", "display_name", "avatar_url", "is_premium"):
            assert payload[field] == getattr(test_user, field)

    def test_requires_authentication(self, client: TestClient):
        """A request without credentials is rejected with 401."""
        anonymous_response = client.get("/api/users/me")
        assert anonymous_response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_returns_401_for_invalid_token(self, client: TestClient):
        """A request carrying an invalid bearer token is rejected with 401."""
        bad_headers = {"Authorization": "Bearer invalid.token.here"}
        response = client.get("/api/users/me", headers=bad_headers)
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestUpdateCurrentUser:
    """Tests for PATCH /api/users/me endpoint."""

    def test_updates_display_name(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """A PATCH carrying ``display_name`` responds with the new name."""
        # The user object the mocked service reports back after the update.
        patched_user = MagicMock(
            id=test_user.id,
            email=test_user.email,
            display_name="New Name",
            avatar_url=test_user.avatar_url,
            is_premium=test_user.is_premium,
            premium_until=test_user.premium_until,
            created_at=test_user.created_at,
        )

        # The mocked session resolves to the test user for authentication.
        query_result = MagicMock(**{"scalar_one_or_none.return_value": test_user})
        mock_db_session.execute = AsyncMock(return_value=query_result)

        mock_user_service_instance.update.return_value = patched_user

        # Swap the real service for the mock only for the duration of the call.
        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
        try:
            response = client.patch(
                "/api/users/me",
                headers={"Authorization": f"Bearer {access_token}"},
                json={"display_name": "New Name"},
            )
        finally:
            app.dependency_overrides.pop(get_user_service, None)

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        assert payload["display_name"] == "New Name"

    def test_updates_avatar_url(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """A PATCH carrying ``avatar_url`` responds with the new URL."""
        # The user object the mocked service reports back after the update.
        patched_user = MagicMock(
            id=test_user.id,
            email=test_user.email,
            display_name=test_user.display_name,
            avatar_url="https://new-avatar.com/img.jpg",
            is_premium=test_user.is_premium,
            premium_until=test_user.premium_until,
            created_at=test_user.created_at,
        )

        # The mocked session resolves to the test user for authentication.
        query_result = MagicMock(**{"scalar_one_or_none.return_value": test_user})
        mock_db_session.execute = AsyncMock(return_value=query_result)

        mock_user_service_instance.update.return_value = patched_user

        # Swap the real service for the mock only for the duration of the call.
        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
        try:
            response = client.patch(
                "/api/users/me",
                headers={"Authorization": f"Bearer {access_token}"},
                json={"avatar_url": "https://new-avatar.com/img.jpg"},
            )
        finally:
            app.dependency_overrides.pop(get_user_service, None)

        assert response.status_code == status.HTTP_200_OK
        payload = response.json()
        assert payload["avatar_url"] == "https://new-avatar.com/img.jpg"

    def test_requires_authentication(self, client: TestClient):
        """A PATCH without credentials is rejected with 401."""
        anonymous_response = client.patch(
            "/api/users/me", json={"display_name": "New Name"}
        )
        assert anonymous_response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestGetLinkedAccounts:
    """Tests for GET /api/users/me/linked-accounts endpoint."""

    def test_returns_linked_accounts(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """The endpoint lists the caller's linked OAuth accounts.

        The list is expected to be non-empty and to start with the test
        user's primary provider.
        """
        # The mocked session resolves to the test user for authentication.
        query_result = MagicMock(**{"scalar_one_or_none.return_value": test_user})
        mock_db_session.execute = AsyncMock(return_value=query_result)

        auth_headers = {"Authorization": f"Bearer {access_token}"}
        response = client.get("/api/users/me/linked-accounts", headers=auth_headers)

        assert response.status_code == status.HTTP_200_OK
        accounts = response.json()
        assert isinstance(accounts, list)
        assert len(accounts) > 0  # At least primary account
        first_account = accounts[0]
        assert first_account["provider"] == test_user.oauth_provider

    def test_requires_authentication(self, client: TestClient):
        """A request without credentials is rejected with 401."""
        anonymous_response = client.get("/api/users/me/linked-accounts")
        assert anonymous_response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestGetActiveSessions:
    """Tests for GET /api/users/me/sessions endpoint."""

    def test_returns_session_count(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_get_redis,
        mock_db_session,
    ):
        """Test that endpoint returns count of active sessions.

        Two refresh-token keys are seeded for the test user, and the endpoint
        is expected to report ``active_sessions == 2``.
        """
        import asyncio

        user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id

        async def setup_tokens():
            """Store two refresh-token entries for the test user."""
            async with mock_get_redis() as redis:
                await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
                await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")

        # Drive the coroutine on a dedicated loop instead of
        # asyncio.get_event_loop(): relying on an implicit current loop is
        # deprecated since Python 3.12 and raises RuntimeError on 3.14 when
        # none is set. A private loop also leaves global loop state untouched.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(setup_tokens())
        finally:
            loop.close()

        # Set up db session to return test user for authentication
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = test_user
        mock_db_session.execute = AsyncMock(return_value=mock_result)

        # Point the token store at the mocked Redis for the request.
        with patch("app.services.token_store.get_redis", mock_get_redis):
            response = client.get(
                "/api/users/me/sessions",
                headers={"Authorization": f"Bearer {access_token}"},
            )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "active_sessions" in data
        assert data["active_sessions"] == 2

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.get("/api/users/me/sessions")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestUnlinkOAuthAccount:
    """Tests for DELETE /api/users/me/link/{provider} endpoint."""

    def test_unlinks_provider_successfully(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """Unlinking succeeds with 204 when the service reports success."""
        # The mocked session resolves to the test user for authentication.
        query_result = MagicMock(**{"scalar_one_or_none.return_value": test_user})
        mock_db_session.execute = AsyncMock(return_value=query_result)

        # The service reports that the provider was unlinked.
        mock_user_service_instance.unlink_oauth_account.return_value = True

        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
        try:
            response = client.delete(
                "/api/users/me/link/discord",
                headers={"Authorization": f"Bearer {access_token}"},
            )
        finally:
            app.dependency_overrides.pop(get_user_service, None)

        assert response.status_code == status.HTTP_204_NO_CONTENT

    def test_returns_404_if_not_linked(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """Unlinking a provider that is not linked yields 404.

        The error detail is expected to mention that the provider is
        "not linked".
        """
        # The mocked session resolves to the test user for authentication.
        query_result = MagicMock(**{"scalar_one_or_none.return_value": test_user})
        mock_db_session.execute = AsyncMock(return_value=query_result)

        # The service reports that nothing was unlinked.
        mock_user_service_instance.unlink_oauth_account.return_value = False

        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
        try:
            response = client.delete(
                "/api/users/me/link/discord",
                headers={"Authorization": f"Bearer {access_token}"},
            )
        finally:
            app.dependency_overrides.pop(get_user_service, None)

        assert response.status_code == status.HTTP_404_NOT_FOUND
        detail = response.json()["detail"]
        assert "not linked" in detail.lower()

    def test_returns_400_for_primary_provider(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """Unlinking the primary provider yields 400.

        The service raises ``AccountLinkingError`` and the error detail is
        expected to mention "primary".
        """
        # The mocked session resolves to the test user for authentication.
        query_result = MagicMock(**{"scalar_one_or_none.return_value": test_user})
        mock_db_session.execute = AsyncMock(return_value=query_result)

        # The service refuses to unlink by raising AccountLinkingError.
        linking_error = AccountLinkingError(
            "Cannot unlink Google - it is your primary login provider"
        )
        mock_user_service_instance.unlink_oauth_account.side_effect = linking_error

        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance
        try:
            response = client.delete(
                "/api/users/me/link/google",
                headers={"Authorization": f"Bearer {access_token}"},
            )
        finally:
            app.dependency_overrides.pop(get_user_service, None)

        assert response.status_code == status.HTTP_400_BAD_REQUEST
        detail = response.json()["detail"]
        assert "primary" in detail.lower()

    def test_returns_400_for_unknown_provider(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """An unrecognised provider name yields 400.

        Only 'google' and 'discord' are valid providers; the error detail is
        expected to mention "unknown provider".
        """
        # The mocked session resolves to the test user for authentication.
        query_result = MagicMock(**{"scalar_one_or_none.return_value": test_user})
        mock_db_session.execute = AsyncMock(return_value=query_result)

        auth_headers = {"Authorization": f"Bearer {access_token}"}
        response = client.delete("/api/users/me/link/twitter", headers=auth_headers)

        assert response.status_code == status.HTTP_400_BAD_REQUEST
        detail = response.json()["detail"]
        assert "unknown provider" in detail.lower()

    def test_requires_authentication(self, client: TestClient):
        """A DELETE without credentials is rejected with 401."""
        anonymous_response = client.delete("/api/users/me/link/discord")
        assert anonymous_response.status_code == status.HTTP_401_UNAUTHORIZED
|