Add complete OAuth-based authentication with JWT session management.
Core Services:
- JWT service for access/refresh token creation and verification
- Token store with Redis-backed refresh token revocation
- User service for CRUD operations and OAuth-based creation
- Google and Discord OAuth services with full flow support
API Endpoints:
- GET /api/auth/{google,discord} - Start OAuth flows
- GET /api/auth/{google,discord}/callback - Handle OAuth callbacks
- POST /api/auth/refresh - Exchange refresh token for new access token
- POST /api/auth/logout - Revoke single refresh token
- POST /api/auth/logout-all - Revoke all user sessions
- GET/PATCH /api/users/me - User profile management
- GET /api/users/me/linked-accounts - List the user's linked OAuth provider accounts
- GET /api/users/me/sessions - Return the number of active sessions
Infrastructure:
- Pydantic schemas for auth/user request/response models
- FastAPI dependencies (get_current_user, get_current_premium_user)
- OAuthLinkedAccount model for multi-provider support
- Alembic migration for oauth_linked_accounts table
Dependencies added: email-validator, fakeredis (dev), respx (dev)
Tests: 84 new; 1058 passing in total
173 lines · 6.6 KiB · Python
"""Tests for users API endpoints.
|
|
|
|
Tests the user profile management endpoints.
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
from uuid import UUID
|
|
|
|
from fastapi import status
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
class TestGetCurrentUser:
    """Tests covering the GET /api/users/me profile endpoint."""

    def test_returns_user_profile(self, client: TestClient, test_user, access_token):
        """An authenticated caller receives their own profile.

        The user lookup done by the auth dependency is patched so that the
        bearer token resolves to ``test_user``; the response body must echo
        that user's public profile fields.
        """
        auth_headers = {"Authorization": f"Bearer {access_token}"}

        with patch("app.api.deps.user_service") as deps_user_service:
            deps_user_service.get_by_id = AsyncMock(return_value=test_user)
            response = client.get("/api/users/me", headers=auth_headers)

        assert response.status_code == status.HTTP_200_OK
        profile = response.json()
        # Compare each exposed field against the fixture, in a fixed order.
        for field in ("email", "display_name", "avatar_url", "is_premium"):
            assert profile[field] == getattr(test_user, field)

    def test_requires_authentication(self, client: TestClient):
        """A request without an Authorization header is answered with 401."""
        anonymous = client.get("/api/users/me")
        assert anonymous.status_code == status.HTTP_401_UNAUTHORIZED

    def test_returns_401_for_invalid_token(self, client: TestClient):
        """A bogus bearer token is rejected with 401."""
        bogus_headers = {"Authorization": "Bearer invalid.token.here"}
        rejected = client.get("/api/users/me", headers=bogus_headers)
        assert rejected.status_code == status.HTTP_401_UNAUTHORIZED


class TestUpdateCurrentUser:
    """Tests for PATCH /api/users/me endpoint."""

    def test_updates_display_name(self, client: TestClient, test_user, access_token):
        """Test that endpoint updates display_name when provided.

        ``test_user`` keeps its original display name while the auth
        dependency resolves it. Only the mocked ``user_service.update`` applies
        the new value, and only when it is awaited. The response therefore
        carries the new name only if the endpoint really goes through the
        update service.
        """
        new_name = "New Name"
        # Guard against a vacuous pass if the fixture already holds the value.
        assert test_user.display_name != new_name

        def apply_update(*args, **kwargs):
            # Stand-in for the service: apply the change, return the user.
            test_user.display_name = new_name
            return test_user

        with patch("app.api.deps.user_service") as mock_deps_service:
            mock_deps_service.get_by_id = AsyncMock(return_value=test_user)

            with patch("app.api.users.user_service") as mock_user_service:
                mock_user_service.update = AsyncMock(side_effect=apply_update)

                response = client.patch(
                    "/api/users/me",
                    headers={"Authorization": f"Bearer {access_token}"},
                    json={"display_name": new_name},
                )

        assert response.status_code == status.HTTP_200_OK
        mock_user_service.update.assert_awaited_once()
        data = response.json()
        assert data["display_name"] == new_name

    def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
        """Test that endpoint updates avatar_url when provided.

        As above, the new URL is applied only by the mocked
        ``user_service.update``. The assertion on the response therefore fails
        if the endpoint skips the update service.
        """
        new_avatar = "https://new-avatar.com/img.jpg"
        # Guard against a vacuous pass if the fixture already holds the value.
        assert test_user.avatar_url != new_avatar

        def apply_update(*args, **kwargs):
            # Stand-in for the service: apply the change, return the user.
            test_user.avatar_url = new_avatar
            return test_user

        with patch("app.api.deps.user_service") as mock_deps_service:
            mock_deps_service.get_by_id = AsyncMock(return_value=test_user)

            with patch("app.api.users.user_service") as mock_user_service:
                mock_user_service.update = AsyncMock(side_effect=apply_update)

                response = client.patch(
                    "/api/users/me",
                    headers={"Authorization": f"Bearer {access_token}"},
                    json={"avatar_url": new_avatar},
                )

        assert response.status_code == status.HTTP_200_OK
        mock_user_service.update.assert_awaited_once()
        data = response.json()
        assert data["avatar_url"] == new_avatar

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.patch(
            "/api/users/me",
            json={"display_name": "New Name"},
        )
        assert response.status_code == status.HTTP_401_UNAUTHORIZED


class TestGetLinkedAccounts:
    """Tests covering GET /api/users/me/linked-accounts."""

    def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
        """The caller gets a list of OAuth accounts, led by the primary provider.

        The response must be a non-empty list whose first entry's ``provider``
        equals ``test_user.oauth_provider``.
        """
        with patch("app.api.deps.user_service") as deps_service:
            deps_service.get_by_id = AsyncMock(return_value=test_user)
            resp = client.get(
                "/api/users/me/linked-accounts",
                headers={"Authorization": f"Bearer {access_token}"},
            )

        assert resp.status_code == status.HTTP_200_OK
        accounts = resp.json()
        assert isinstance(accounts, list)
        assert len(accounts) >= 1  # the primary account is always present
        primary, *_ = accounts
        assert primary["provider"] == test_user.oauth_provider

    def test_requires_authentication(self, client: TestClient):
        """Anonymous requests are answered with 401."""
        anonymous = client.get("/api/users/me/linked-accounts")
        assert anonymous.status_code == status.HTTP_401_UNAUTHORIZED


class TestGetActiveSessions:
    """Tests for GET /api/users/me/sessions endpoint."""

    def test_returns_session_count(
        self, client: TestClient, test_user, access_token, mock_get_redis
    ):
        """Test that endpoint returns count of active sessions.

        Two refresh-token keys for ``test_user`` are seeded through
        ``mock_get_redis``. With ``app.services.token_store.get_redis`` patched
        to that same fixture, the endpoint should report both.
        """
        import asyncio

        user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id

        async def setup_tokens():
            # NOTE(review): key layout is assumed to match what
            # app.services.token_store writes for refresh tokens; keep in sync.
            async with mock_get_redis() as redis:
                # SETEX TTL is in seconds: 86400 = one day.
                await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
                await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")

        # Run the async setup on a private loop instead of
        # asyncio.get_event_loop(). That call is deprecated when no loop is
        # current and raises RuntimeError on Python 3.14+. A private loop also
        # avoids asyncio.run(), which resets the thread's current loop and
        # could break other sync tests that still rely on get_event_loop().
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(setup_tokens())
        finally:
            loop.close()

        with patch("app.api.deps.user_service") as mock_user_service:
            mock_user_service.get_by_id = AsyncMock(return_value=test_user)

            with patch("app.services.token_store.get_redis", mock_get_redis):
                response = client.get(
                    "/api/users/me/sessions",
                    headers={"Authorization": f"Bearer {access_token}"},
                )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "active_sessions" in data
        assert data["active_sessions"] == 2

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.get("/api/users/me/sessions")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED