The frontend routing guard checks `has_starter_deck` to decide whether to redirect users to starter selection. The field was missing from the API response, so authenticated users who already had a starter deck were incorrectly redirected to /starter on page refresh.

- Add `has_starter_deck` computed property to the User model
- Add `has_starter_deck` field to the UserResponse schema
- Add unit tests for User model properties
- Add API tests for `has_starter_deck` in the profile response

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
463 lines
17 KiB
Python
463 lines
17 KiB
Python
"""Tests for users API endpoints.

Tests the user profile management endpoints.
Uses FastAPI's dependency override pattern for proper dependency injection testing.
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from uuid import UUID
|
|
|
|
import pytest
|
|
from fastapi import status
|
|
from fastapi.testclient import TestClient
|
|
|
|
from app.api.deps import get_user_service
|
|
from app.services.user_service import AccountLinkingError
|
|
|
|
|
|
@pytest.fixture
def mock_user_service_instance():
    """Provide a stand-in UserService for dependency-override tests.

    Each service method the tests exercise is replaced with an ``AsyncMock``
    so the endpoint can await it; any other attribute falls back to the
    default ``MagicMock`` behaviour.
    """
    service = MagicMock()
    for method_name in (
        "get_by_id",
        "get_by_email",
        "get_by_oauth",
        "update",
        "unlink_oauth_account",
    ):
        setattr(service, method_name, AsyncMock())
    return service
|
|
|
|
|
|
class TestGetCurrentUser:
    """Tests for GET /api/users/me endpoint."""

    @staticmethod
    def _authenticate_as(mock_db_session, user):
        """Make the mocked DB session resolve every query to ``user``.

        Every ``execute`` call returns a result whose ``scalar_one_or_none()``
        yields ``user``. That is what the authenticated request reads back.
        """
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = user
        mock_db_session.execute = AsyncMock(return_value=mock_result)

    @staticmethod
    def _get_me(client, access_token):
        """Send an authenticated GET /api/users/me request."""
        return client.get(
            "/api/users/me",
            headers={"Authorization": f"Bearer {access_token}"},
        )

    def test_returns_user_profile(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """Test that endpoint returns user profile for authenticated user.

        Should return the user's profile information including has_starter_deck.
        """
        self._authenticate_as(mock_db_session, test_user)

        response = self._get_me(client, access_token)

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["email"] == test_user.email
        assert data["display_name"] == test_user.display_name
        assert data["avatar_url"] == test_user.avatar_url
        assert data["is_premium"] == test_user.is_premium
        # The frontend routing guard reads this on every refresh. Its absence was
        # the original bug, so it must be present and a real boolean.
        assert isinstance(data["has_starter_deck"], bool)

    def test_returns_has_starter_deck_false_for_new_user(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """Test that has_starter_deck is false for users without a starter deck.

        New users who haven't selected a starter deck should have has_starter_deck=false.
        This is used by the frontend to redirect to the starter selection page.
        """
        # User has empty decks list (no starter selected)
        test_user.decks = []
        self._authenticate_as(mock_db_session, test_user)

        response = self._get_me(client, access_token)

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["has_starter_deck"] is False

    def test_returns_has_starter_deck_true_when_starter_selected(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """Test that has_starter_deck is true after selecting a starter deck.

        Users who have selected a starter deck should have has_starter_deck=true.
        This allows the frontend to navigate to the dashboard instead of starter selection.
        """
        # Create a mock starter deck
        starter_deck = MagicMock()
        starter_deck.is_starter = True
        starter_deck.starter_type = "grass"
        test_user.decks = [starter_deck]
        self._authenticate_as(mock_db_session, test_user)

        response = self._get_me(client, access_token)

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["has_starter_deck"] is True

    def test_returns_has_starter_deck_false_for_non_starter_decks(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """Test that has_starter_deck is false when user only has regular decks.

        Users can have custom decks without having selected a starter.
        The has_starter_deck field should only be true if a starter deck exists.
        """
        # Create a mock regular deck (not a starter)
        regular_deck = MagicMock()
        regular_deck.is_starter = False
        regular_deck.starter_type = None
        test_user.decks = [regular_deck]
        self._authenticate_as(mock_db_session, test_user)

        response = self._get_me(client, access_token)

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["has_starter_deck"] is False

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.get("/api/users/me")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_returns_401_for_invalid_token(self, client: TestClient):
        """Test that endpoint returns 401 for invalid access token."""
        response = client.get(
            "/api/users/me",
            headers={"Authorization": "Bearer invalid.token.here"},
        )
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestUpdateCurrentUser:
    """Tests for PATCH /api/users/me endpoint."""

    @staticmethod
    def _authenticate_as(mock_db_session, user):
        """Make the mocked DB session resolve every query to ``user``."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = user
        mock_db_session.execute = AsyncMock(return_value=mock_result)

    @staticmethod
    def _make_updated_user(source, **overrides):
        """Build a MagicMock user copying ``source``'s profile fields, then apply ``overrides``.

        Every field is set explicitly, because an attribute left unset on a
        MagicMock evaluates to another MagicMock rather than a real value.
        ``decks`` and ``has_starter_deck`` are pinned as well. Presumably the
        response schema now serializes ``has_starter_deck``, and an
        auto-generated MagicMock there would not be a bool.
        """
        updated_user = MagicMock()
        updated_user.id = source.id
        updated_user.email = source.email
        updated_user.display_name = source.display_name
        updated_user.avatar_url = source.avatar_url
        updated_user.is_premium = source.is_premium
        updated_user.premium_until = source.premium_until
        updated_user.created_at = source.created_at
        updated_user.decks = []
        updated_user.has_starter_deck = False
        for field_name, value in overrides.items():
            setattr(updated_user, field_name, value)
        return updated_user

    def test_updates_display_name(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """Test that endpoint updates display_name when provided."""
        updated_user = self._make_updated_user(test_user, display_name="New Name")

        # Set up db session to return test user for authentication
        self._authenticate_as(mock_db_session, test_user)

        # Set up user service mock
        mock_user_service_instance.update.return_value = updated_user

        # Override the dependency
        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance

        try:
            response = client.patch(
                "/api/users/me",
                headers={"Authorization": f"Bearer {access_token}"},
                json={"display_name": "New Name"},
            )

            assert response.status_code == status.HTTP_200_OK
            data = response.json()
            assert data["display_name"] == "New Name"
        finally:
            app.dependency_overrides.pop(get_user_service, None)

    def test_updates_avatar_url(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """Test that endpoint updates avatar_url when provided."""
        updated_user = self._make_updated_user(
            test_user, avatar_url="https://new-avatar.com/img.jpg"
        )

        # Set up db session to return test user for authentication
        self._authenticate_as(mock_db_session, test_user)

        # Set up user service mock
        mock_user_service_instance.update.return_value = updated_user

        # Override the dependency
        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance

        try:
            response = client.patch(
                "/api/users/me",
                headers={"Authorization": f"Bearer {access_token}"},
                json={"avatar_url": "https://new-avatar.com/img.jpg"},
            )

            assert response.status_code == status.HTTP_200_OK
            data = response.json()
            assert data["avatar_url"] == "https://new-avatar.com/img.jpg"
        finally:
            app.dependency_overrides.pop(get_user_service, None)

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.patch(
            "/api/users/me",
            json={"display_name": "New Name"},
        )
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestGetLinkedAccounts:
    """Tests for GET /api/users/me/linked-accounts endpoint."""

    def test_returns_linked_accounts(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """Test that the endpoint lists the user's OAuth accounts.

        The primary provider is expected first, followed by any linked ones.
        """
        # Authentication reads the current user back through the DB session.
        lookup = MagicMock()
        lookup.scalar_one_or_none.return_value = test_user
        mock_db_session.execute = AsyncMock(return_value=lookup)

        auth_headers = {"Authorization": f"Bearer {access_token}"}
        response = client.get("/api/users/me/linked-accounts", headers=auth_headers)

        assert response.status_code == status.HTTP_200_OK
        accounts = response.json()
        assert isinstance(accounts, list)
        assert len(accounts) >= 1  # At least primary account
        primary_account = accounts[0]
        assert primary_account["provider"] == test_user.oauth_provider

    def test_requires_authentication(self, client: TestClient):
        """Test that an anonymous request is rejected with 401."""
        anonymous_response = client.get("/api/users/me/linked-accounts")
        assert anonymous_response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestGetActiveSessions:
    """Tests for GET /api/users/me/sessions endpoint."""

    def test_returns_session_count(
        self, app, client: TestClient, test_user, access_token, mock_get_redis, mock_db_session
    ):
        """Test that endpoint returns count of active sessions.

        Should return the number of valid refresh tokens.
        """
        import asyncio

        # The fixture may expose the id as a str or a UUID; normalise to UUID
        # before embedding it in the refresh-token keys.
        user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id

        async def setup_tokens():
            """Store two refresh tokens for the user in the mocked Redis."""
            async with mock_get_redis() as redis:
                await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
                await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")

        # asyncio.run() creates and closes its own loop. Calling
        # get_event_loop() when no loop is set is deprecated since Python 3.12.
        asyncio.run(setup_tokens())

        # Set up db session to return test user for authentication
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = test_user
        mock_db_session.execute = AsyncMock(return_value=mock_result)

        with patch("app.services.token_store.get_redis", mock_get_redis):
            response = client.get(
                "/api/users/me/sessions",
                headers={"Authorization": f"Bearer {access_token}"},
            )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "active_sessions" in data
        assert data["active_sessions"] == 2

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.get("/api/users/me/sessions")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestUnlinkOAuthAccount:
    """Tests for DELETE /api/users/me/link/{provider} endpoint."""

    @staticmethod
    def _authenticate_as(mock_db_session, user):
        """Have the mocked DB session hand back ``user`` for the auth lookup."""
        lookup = MagicMock()
        lookup.scalar_one_or_none.return_value = user
        mock_db_session.execute = AsyncMock(return_value=lookup)

    @staticmethod
    def _delete(client, path, access_token=None):
        """Send DELETE to ``path``, with a bearer token only when one is given."""
        if access_token is None:
            return client.delete(path)
        return client.delete(path, headers={"Authorization": f"Bearer {access_token}"})

    def test_unlinks_provider_successfully(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """Test that a linked provider can be removed.

        The endpoint answers 204 when the service reports a successful unlink.
        """
        self._authenticate_as(mock_db_session, test_user)
        mock_user_service_instance.unlink_oauth_account.return_value = True
        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance

        try:
            outcome = self._delete(client, "/api/users/me/link/discord", access_token)
            assert outcome.status_code == status.HTTP_204_NO_CONTENT
        finally:
            app.dependency_overrides.pop(get_user_service, None)

    def test_returns_404_if_not_linked(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """Test that unlinking a provider that was never linked yields 404.

        The service signals "nothing to unlink" by returning False.
        """
        self._authenticate_as(mock_db_session, test_user)
        mock_user_service_instance.unlink_oauth_account.return_value = False
        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance

        try:
            outcome = self._delete(client, "/api/users/me/link/discord", access_token)
            assert outcome.status_code == status.HTTP_404_NOT_FOUND
            detail = outcome.json()["detail"]
            assert "not linked" in detail.lower()
        finally:
            app.dependency_overrides.pop(get_user_service, None)

    def test_returns_400_for_primary_provider(
        self,
        app,
        client: TestClient,
        test_user,
        access_token,
        mock_db_session,
        mock_user_service_instance,
    ):
        """Test that the account's primary provider cannot be unlinked.

        The service raises AccountLinkingError, which the endpoint maps to 400.
        """
        self._authenticate_as(mock_db_session, test_user)
        linking_error = AccountLinkingError(
            "Cannot unlink Google - it is your primary login provider"
        )
        mock_user_service_instance.unlink_oauth_account.side_effect = linking_error
        app.dependency_overrides[get_user_service] = lambda: mock_user_service_instance

        try:
            outcome = self._delete(client, "/api/users/me/link/google", access_token)
            assert outcome.status_code == status.HTTP_400_BAD_REQUEST
            detail = outcome.json()["detail"]
            assert "primary" in detail.lower()
        finally:
            app.dependency_overrides.pop(get_user_service, None)

    def test_returns_400_for_unknown_provider(
        self, app, client: TestClient, test_user, access_token, mock_db_session
    ):
        """Test that a provider outside the supported set is rejected with 400.

        Only 'google' and 'discord' are valid providers.
        """
        self._authenticate_as(mock_db_session, test_user)

        outcome = self._delete(client, "/api/users/me/link/twitter", access_token)

        assert outcome.status_code == status.HTTP_400_BAD_REQUEST
        detail = outcome.json()["detail"]
        assert "unknown provider" in detail.lower()

    def test_requires_authentication(self, client: TestClient):
        """Test that an anonymous unlink request is rejected with 401."""
        outcome = self._delete(client, "/api/users/me/link/discord")
        assert outcome.status_code == status.HTTP_401_UNAUTHORIZED
|