"""Tests for the users API endpoints.

Covers the user profile management endpoints under ``/api/users/me``:
reading and updating the current user's profile, listing linked OAuth
accounts, and counting active sessions.

The ``client``, ``test_user``, ``access_token`` and ``mock_get_redis``
fixtures are not defined in this module (presumably in ``conftest.py``).
"""

import asyncio
import copy
from unittest.mock import AsyncMock, patch
from uuid import UUID

from fastapi import status
from fastapi.testclient import TestClient


def _auth_headers(token):
    """Return request headers carrying ``token`` as a Bearer credential."""
    return {"Authorization": f"Bearer {token}"}


class TestGetCurrentUser:
    """Tests for GET /api/users/me endpoint."""

    def test_returns_user_profile(self, client: TestClient, test_user, access_token):
        """Test that endpoint returns user profile for authenticated user.

        Should return the user's profile information.
        """
        # Stub the user lookup performed in app.api.deps so the authenticated
        # request resolves to ``test_user`` without touching real storage.
        with patch("app.api.deps.user_service") as mock_user_service:
            mock_user_service.get_by_id = AsyncMock(return_value=test_user)
            response = client.get(
                "/api/users/me",
                headers=_auth_headers(access_token),
            )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["email"] == test_user.email
        assert data["display_name"] == test_user.display_name
        assert data["avatar_url"] == test_user.avatar_url
        assert data["is_premium"] == test_user.is_premium

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.get("/api/users/me")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_returns_401_for_invalid_token(self, client: TestClient):
        """Test that endpoint returns 401 for invalid access token."""
        response = client.get(
            "/api/users/me",
            headers={"Authorization": "Bearer invalid.token.here"},
        )
        assert response.status_code == status.HTTP_401_UNAUTHORIZED


class TestUpdateCurrentUser:
    """Tests for PATCH /api/users/me endpoint."""

    def test_updates_display_name(self, client: TestClient, test_user, access_token):
        """Test that endpoint updates display_name when provided."""
        # Copy rather than alias ``test_user``: the auth dependency must keep
        # returning the unmodified user, so the response can only contain the
        # new value if the endpoint uses what ``user_service.update`` returns.
        # This also avoids mutating the shared fixture object.
        updated_user = copy.copy(test_user)
        updated_user.display_name = "New Name"

        with patch("app.api.deps.user_service") as mock_deps_service:
            mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
            with patch("app.api.users.user_service") as mock_user_service:
                mock_user_service.update = AsyncMock(return_value=updated_user)
                response = client.patch(
                    "/api/users/me",
                    headers=_auth_headers(access_token),
                    json={"display_name": "New Name"},
                )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["display_name"] == "New Name"

    def test_updates_avatar_url(self, client: TestClient, test_user, access_token):
        """Test that endpoint updates avatar_url when provided."""
        # Copy rather than alias ``test_user`` (see test_updates_display_name
        # rationale): the stubbed auth lookup must still return the original.
        updated_user = copy.copy(test_user)
        updated_user.avatar_url = "https://new-avatar.com/img.jpg"

        with patch("app.api.deps.user_service") as mock_deps_service:
            mock_deps_service.get_by_id = AsyncMock(return_value=test_user)
            with patch("app.api.users.user_service") as mock_user_service:
                mock_user_service.update = AsyncMock(return_value=updated_user)
                response = client.patch(
                    "/api/users/me",
                    headers=_auth_headers(access_token),
                    json={"avatar_url": "https://new-avatar.com/img.jpg"},
                )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["avatar_url"] == "https://new-avatar.com/img.jpg"

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.patch(
            "/api/users/me",
            json={"display_name": "New Name"},
        )
        assert response.status_code == status.HTTP_401_UNAUTHORIZED


class TestGetLinkedAccounts:
    """Tests for GET /api/users/me/linked-accounts endpoint."""

    def test_returns_linked_accounts(self, client: TestClient, test_user, access_token):
        """Test that endpoint returns list of linked OAuth accounts.

        Should include the primary provider and any linked accounts.
        """
        with patch("app.api.deps.user_service") as mock_user_service:
            mock_user_service.get_by_id = AsyncMock(return_value=test_user)
            response = client.get(
                "/api/users/me/linked-accounts",
                headers=_auth_headers(access_token),
            )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert isinstance(data, list)
        assert len(data) >= 1  # At least primary account
        # The primary (sign-in) provider is expected first in the list.
        assert data[0]["provider"] == test_user.oauth_provider

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.get("/api/users/me/linked-accounts")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED


class TestGetActiveSessions:
    """Tests for GET /api/users/me/sessions endpoint."""

    def test_returns_session_count(
        self, client: TestClient, test_user, access_token, mock_get_redis
    ):
        """Test that endpoint returns count of active sessions.

        Should return the number of valid refresh tokens.
        """
        # The fixture's id may be a str or already a UUID; normalise so the
        # Redis keys match the format the token store builds.
        user_id = UUID(test_user.id) if isinstance(test_user.id, str) else test_user.id

        async def setup_tokens():
            # Store two refresh-token keys for this user, each with an
            # 86400-second (24 h) TTL.
            async with mock_get_redis() as redis:
                await redis.setex(f"refresh_token:{user_id}:jti-1", 86400, "1")
                await redis.setex(f"refresh_token:{user_id}:jti-2", 86400, "1")

        # asyncio.run creates, runs and closes its own loop. Calling
        # asyncio.get_event_loop() with no running loop is deprecated and
        # errors on newer Python versions.
        asyncio.run(setup_tokens())

        with patch("app.api.deps.user_service") as mock_user_service:
            mock_user_service.get_by_id = AsyncMock(return_value=test_user)
            # Route the token store's Redis access to the same fake that was
            # seeded above.
            with patch("app.services.token_store.get_redis", mock_get_redis):
                response = client.get(
                    "/api/users/me/sessions",
                    headers=_auth_headers(access_token),
                )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "active_sessions" in data
        assert data["active_sessions"] == 2

    def test_requires_authentication(self, client: TestClient):
        """Test that endpoint returns 401 without authentication."""
        response = client.get("/api/users/me/sessions")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED