- Add base_url config setting for OAuth callback URLs
- Change OAuth callbacks from relative to absolute URLs
- Add account linking OAuth flow (GET /auth/link/{provider})
- Add unlink endpoint (DELETE /users/me/link/{provider})
- Add AccountLinkingError and service methods for linking
- Add 14 new tests for linking functionality
- Update Phase 2 plan to mark complete (1072 tests passing)
260 lines
10 KiB
Python
260 lines
10 KiB
Python
"""Tests for users API endpoints.
|
|
|
|
Tests the user profile management endpoints.
|
|
"""
|
|
|
|
import asyncio
from unittest.mock import AsyncMock, patch
from uuid import UUID

from fastapi import status
from fastapi.testclient import TestClient

from app.services.user_service import AccountLinkingError
|
|
|
|
|
|
class TestGetCurrentUser:
    """Exercise the GET /api/users/me profile endpoint."""

    def test_returns_user_profile(self, client: TestClient, test_user, access_token):
        """An authenticated caller receives its own profile fields.

        The user lookup used by ``app.api.deps`` is mocked to return
        ``test_user``; the response body must mirror that user.
        """
        auth_headers = {"Authorization": f"Bearer {access_token}"}

        with patch("app.api.deps.user_service") as deps_service:
            deps_service.get_by_id = AsyncMock(return_value=test_user)
            resp = client.get("/api/users/me", headers=auth_headers)

        assert resp.status_code == status.HTTP_200_OK
        body = resp.json()
        # Each exposed profile field should match the fixture user's attribute.
        for field in ("email", "display_name", "avatar_url", "is_premium"):
            assert body[field] == getattr(test_user, field)

    def test_requires_authentication(self, client: TestClient):
        """A request without an Authorization header is rejected with 401."""
        resp = client.get("/api/users/me")
        assert resp.status_code == status.HTTP_401_UNAUTHORIZED

    def test_returns_401_for_invalid_token(self, client: TestClient):
        """A malformed bearer token is rejected with 401."""
        bogus_headers = {"Authorization": "Bearer invalid.token.here"}
        resp = client.get("/api/users/me", headers=bogus_headers)
        assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestUpdateCurrentUser:
    """Tests for PATCH /api/users/me endpoint."""

    @staticmethod
    def _patch_me(client, test_user, access_token, changes):
        """Send PATCH /api/users/me with ``changes`` as the JSON body.

        The ``app.api.deps`` user lookup returns ``test_user`` as-is. The
        mocked ``app.api.users.user_service.update`` applies ``changes`` to
        ``test_user`` only when it is actually awaited. The response therefore
        reflects the new values only if the endpoint goes through ``update``.

        Returns:
            A ``(response, update_mock)`` tuple so callers can also assert
            that ``update`` was used.
        """

        def apply_changes(*_args, **_kwargs):
            # Mutate the user only at update time; an endpoint that skipped
            # ``update`` would then serialize the unchanged fixture values.
            for attr, value in changes.items():
                setattr(test_user, attr, value)
            return test_user

        with patch("app.api.deps.user_service") as mock_deps_service:
            mock_deps_service.get_by_id = AsyncMock(return_value=test_user)

            with patch("app.api.users.user_service") as mock_user_service:
                mock_user_service.update = AsyncMock(side_effect=apply_changes)

                response = client.patch(
                    "/api/users/me",
                    headers={"Authorization": f"Bearer {access_token}"},
                    json=changes,
                )

        return response, mock_user_service.update

    def test_updates_display_name(self, client: TestClient, test_user, access_token):
        """Test that endpoint updates display_name when provided."""
        response, update = self._patch_me(
            client, test_user, access_token, {"display_name": "New Name"}
        )

        assert response.status_code == status.HTTP_200_OK
        update.assert_awaited_once()
        data = response.json()
        assert data["display_name"] == "New Name"

    def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
        """Test that endpoint updates avatar_url when provided."""
        response, update = self._patch_me(
            client,
            test_user,
            access_token,
            {"avatar_url": "https://new-avatar.com/img.jpg"},
        )

        assert response.status_code == status.HTTP_200_OK
        update.assert_awaited_once()
        data = response.json()
        assert data["avatar_url"] == "https://new-avatar.com/img.jpg"

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.patch(
            "/api/users/me",
            json={"display_name": "New Name"},
        )
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestGetLinkedAccounts:
    """Exercise the GET /api/users/me/linked-accounts endpoint."""

    def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
        """The response is a list whose first entry is the primary provider.

        Any additionally linked OAuth accounts may follow that first entry.
        """
        auth_headers = {"Authorization": f"Bearer {access_token}"}

        with patch("app.api.deps.user_service") as deps_service:
            deps_service.get_by_id = AsyncMock(return_value=test_user)
            resp = client.get("/api/users/me/linked-accounts", headers=auth_headers)

        assert resp.status_code == status.HTTP_200_OK
        accounts = resp.json()
        assert isinstance(accounts, list)
        # The primary account must always be present, and it is listed first.
        assert len(accounts) >= 1
        primary, *_ = accounts
        assert primary["provider"] == test_user.oauth_provider

    def test_requires_authentication(self, client: TestClient):
        """A request without an Authorization header is rejected with 401."""
        resp = client.get("/api/users/me/linked-accounts")
        assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestGetActiveSessions:
    """Tests for GET /api/users/me/sessions endpoint."""

    def test_returns_session_count(
        self, client: TestClient, test_user, access_token, mock_get_redis
    ):
        """Test that endpoint returns count of active sessions.

        Two refresh-token keys are seeded for the user in the mocked Redis
        before the request, and the endpoint is expected to report both.
        """
        # test_user.id may be a str or a UUID depending on the fixture; build
        # the key from a UUID in both cases.
        user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id

        async def setup_tokens():
            async with mock_get_redis() as redis:
                # 86400 s TTL: the keys stay valid for the whole test.
                await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
                await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")

        # asyncio.run() creates, runs and closes its own loop. Calling
        # asyncio.get_event_loop() with no running loop is deprecated and
        # raises RuntimeError on Python 3.14.
        asyncio.run(setup_tokens())

        with patch("app.api.deps.user_service") as mock_user_service:
            mock_user_service.get_by_id = AsyncMock(return_value=test_user)

            with patch("app.services.token_store.get_redis", mock_get_redis):
                response = client.get(
                    "/api/users/me/sessions",
                    headers={"Authorization": f"Bearer {access_token}"},
                )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "active_sessions" in data
        assert data["active_sessions"] == 2

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.get("/api/users/me/sessions")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
class TestUnlinkOAuthAccount:
    """Exercise the DELETE /api/users/me/link/{provider} endpoint."""

    @staticmethod
    def _delete(client, user, token, provider, unlink_mock=None):
        """Issue DELETE /api/users/me/link/{provider} authenticated as ``user``.

        The ``app.api.deps`` user lookup is mocked to return ``user``. When
        ``unlink_mock`` is given, it replaces
        ``app.api.users.user_service.unlink_oauth_account`` for the duration of
        the request. Otherwise ``app.api.users`` is left unpatched.
        """
        url = f"/api/users/me/link/{provider}"
        auth_headers = {"Authorization": f"Bearer {token}"}

        with patch("app.api.deps.user_service") as deps_service:
            deps_service.get_by_id = AsyncMock(return_value=user)

            if unlink_mock is None:
                return client.delete(url, headers=auth_headers)

            with patch("app.api.users.user_service") as users_service:
                users_service.unlink_oauth_account = unlink_mock
                return client.delete(url, headers=auth_headers)

    def test_unlinks_provider_successfully(self, client: TestClient, test_user, access_token):
        """Unlinking a linked provider answers 204 No Content."""
        resp = self._delete(
            client,
            test_user,
            access_token,
            "discord",
            unlink_mock=AsyncMock(return_value=True),
        )
        assert resp.status_code == status.HTTP_204_NO_CONTENT

    def test_returns_404_if_not_linked(self, client: TestClient, test_user, access_token):
        """A provider that is not linked to the account yields 404."""
        resp = self._delete(
            client,
            test_user,
            access_token,
            "discord",
            unlink_mock=AsyncMock(return_value=False),
        )
        assert resp.status_code == status.HTTP_404_NOT_FOUND
        assert "not linked" in resp.json()["detail"].lower()

    def test_returns_400_for_primary_provider(self, client: TestClient, test_user, access_token):
        """The provider the account was created with cannot be unlinked (400)."""
        refusal = AccountLinkingError(
            "Cannot unlink Google - it is your primary login provider"
        )
        resp = self._delete(
            client,
            test_user,
            access_token,
            "google",
            unlink_mock=AsyncMock(side_effect=refusal),
        )
        assert resp.status_code == status.HTTP_400_BAD_REQUEST
        assert "primary" in resp.json()["detail"].lower()

    def test_returns_400_for_unknown_provider(self, client: TestClient, test_user, access_token):
        """Providers other than 'google' and 'discord' are rejected with 400."""
        resp = self._delete(client, test_user, access_token, "twitter")
        assert resp.status_code == status.HTTP_400_BAD_REQUEST
        assert "unknown provider" in resp.json()["detail"].lower()

    def test_requires_authentication(self, client: TestClient):
        """A request without an Authorization header is rejected with 401."""
        resp = client.delete("/api/users/me/link/discord")
        assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|