Complete OAuth-based authentication with JWT session management:
Core Services:
- JWT service for access/refresh token creation and verification
- Token store with Redis-backed refresh token revocation
- User service for CRUD operations and OAuth-based creation
- Google and Discord OAuth services with full flow support
API Endpoints:
- GET /api/auth/{google,discord} - Start OAuth flows
- GET /api/auth/{google,discord}/callback - Handle OAuth callbacks
- POST /api/auth/refresh - Exchange refresh token for new access token
- POST /api/auth/logout - Revoke single refresh token
- POST /api/auth/logout-all - Revoke all user sessions
- GET/PATCH /api/users/me - User profile management
- GET /api/users/me/linked-accounts - List OAuth providers
- GET /api/users/me/sessions - Count active sessions
Infrastructure:
- Pydantic schemas for auth/user request/response models
- FastAPI dependencies (get_current_user, get_current_premium_user)
- OAuthLinkedAccount model for multi-provider support
- Alembic migration for oauth_linked_accounts table
Dependencies added: email-validator, fakeredis (dev), respx (dev)
84 new tests, 1058 total passing
408 lines · 13 KiB · Python
"""Tests for UserService.
|
|
|
|
Tests the user service CRUD operations and OAuth-based user creation.
|
|
Uses real Postgres via the db_session fixture from conftest.
|
|
"""
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytest
|
|
|
|
from app.db.models import User
|
|
from app.schemas.user import OAuthUserInfo, UserCreate, UserUpdate
|
|
from app.services.user_service import user_service
|
|
|
|
# The db_session fixture comes from conftest; pytest discovers conftest
# fixtures automatically, so no explicit import is needed in this module.
# Apply the asyncio marker to every test in this module.
pytestmark = pytest.mark.asyncio


class TestGetById:
    """Exercise ``user_service.get_by_id`` lookups by primary key."""

    async def test_returns_user_when_found(self, db_session):
        """A persisted user can be fetched back by its ID.

        The row is inserted directly through the session, then looked up
        via the service.
        """
        from uuid import UUID

        stored = User(
            email="test@example.com",
            display_name="Test User",
            oauth_provider="google",
            oauth_id="123456",
        )
        db_session.add(stored)
        await db_session.commit()

        # The ID may be exposed as either a str or a UUID; normalize to UUID.
        lookup_id = stored.id if not isinstance(stored.id, str) else UUID(stored.id)
        found = await user_service.get_by_id(db_session, lookup_id)

        assert found is not None
        assert found.email == "test@example.com"

    async def test_returns_none_when_not_found(self, db_session):
        """An ID that was never stored yields None."""
        from uuid import uuid4

        missing = await user_service.get_by_id(db_session, uuid4())
        assert missing is None


class TestGetByEmail:
    """Exercise ``user_service.get_by_email`` lookups."""

    async def test_returns_user_when_found(self, db_session):
        """A stored email resolves to its user."""
        db_session.add(
            User(
                email="findme@example.com",
                display_name="Find Me",
                oauth_provider="discord",
                oauth_id="discord123",
            )
        )
        await db_session.commit()

        match = await user_service.get_by_email(db_session, "findme@example.com")

        assert match is not None
        assert match.display_name == "Find Me"

    async def test_returns_none_when_not_found(self, db_session):
        """An email with no matching user yields None."""
        match = await user_service.get_by_email(db_session, "nobody@example.com")
        assert match is None


class TestGetByOAuth:
    """Exercise ``user_service.get_by_oauth`` (provider + OAuth ID lookup)."""

    @staticmethod
    async def _persist_user(session, **fields):
        """Build a User from *fields*, add it to *session* and commit."""
        account = User(**fields)
        session.add(account)
        await session.commit()
        return account

    async def test_returns_user_when_found(self, db_session):
        """A matching provider and OAuth ID resolve to the stored user."""
        await self._persist_user(
            db_session,
            email="oauth@example.com",
            display_name="OAuth User",
            oauth_provider="google",
            oauth_id="google-unique-id",
        )

        hit = await user_service.get_by_oauth(db_session, "google", "google-unique-id")

        assert hit is not None
        assert hit.email == "oauth@example.com"

    async def test_returns_none_for_wrong_provider(self, db_session):
        """The OAuth ID alone is not enough; the provider must match too."""
        await self._persist_user(
            db_session,
            email="oauth2@example.com",
            display_name="OAuth User 2",
            oauth_provider="google",
            oauth_id="google-id-2",
        )

        # Reuse the stored OAuth ID but ask for a different provider.
        hit = await user_service.get_by_oauth(db_session, "discord", "google-id-2")
        assert hit is None

    async def test_returns_none_when_not_found(self, db_session):
        """An unknown OAuth identity yields None."""
        hit = await user_service.get_by_oauth(db_session, "google", "nonexistent")
        assert hit is None


class TestCreate:
    """Exercise ``user_service.create`` with a ``UserCreate`` payload."""

    async def test_creates_user_with_all_fields(self, db_session):
        """Every supplied field is persisted and premium starts switched off."""
        payload = {
            "email": "new@example.com",
            "display_name": "New User",
            "avatar_url": "https://example.com/avatar.jpg",
            "oauth_provider": "discord",
            "oauth_id": "discord-new-id",
        }

        created = await user_service.create(db_session, UserCreate(**payload))

        assert created.id is not None
        # Each payload field must round-trip onto the created user unchanged.
        for field_name, expected in payload.items():
            assert getattr(created, field_name) == expected, field_name
        # A fresh account has no premium access.
        assert created.is_premium is False
        assert created.premium_until is None

    async def test_creates_user_without_avatar(self, db_session):
        """Omitting the optional avatar_url leaves it unset (None)."""
        created = await user_service.create(
            db_session,
            UserCreate(
                email="noavatar@example.com",
                display_name="No Avatar",
                oauth_provider="google",
                oauth_id="google-no-avatar",
            ),
        )

        assert created.avatar_url is None


class TestCreateFromOAuth:
    """Tests for ``user_service.create_from_oauth``."""

    async def test_creates_user_from_oauth_info(self, db_session):
        """create_from_oauth maps every OAuthUserInfo field onto a new User.

        ``provider``/``oauth_id`` become ``oauth_provider``/``oauth_id``,
        ``name`` becomes ``display_name``, and ``email``/``avatar_url`` are
        expected to carry over unchanged.
        """
        oauth_info = OAuthUserInfo(
            provider="google",
            oauth_id="google-oauth-123",
            email="oauthcreate@example.com",
            name="OAuth Created User",
            avatar_url="https://google.com/avatar.jpg",
        )

        result = await user_service.create_from_oauth(db_session, oauth_info)

        # The user must actually have been persisted with an identifier.
        assert result.id is not None
        assert result.email == "oauthcreate@example.com"
        assert result.display_name == "OAuth Created User"
        assert result.oauth_provider == "google"
        assert result.oauth_id == "google-oauth-123"
        # avatar_url is supplied above, so its mapping must be checked too;
        # otherwise a dropped avatar would go unnoticed.
        assert result.avatar_url == "https://google.com/avatar.jpg"


class TestGetOrCreateFromOAuth:
    """Tests for ``user_service.get_or_create_from_oauth``.

    The method returns a ``(user, created)`` tuple. The tests below cover
    three paths: a match on OAuth identity, a match on email, and creation
    of a brand-new user.
    """

    async def test_returns_existing_user_by_oauth(self, db_session):
        """An exact provider + oauth_id match returns ``(existing_user, False)``."""
        existing = User(
            email="existing@example.com",
            display_name="Existing",
            oauth_provider="google",
            oauth_id="existing-oauth-id",
        )
        db_session.add(existing)
        await db_session.commit()

        # Same provider and OAuth ID as the stored user.
        oauth_info = OAuthUserInfo(
            provider="google",
            oauth_id="existing-oauth-id",
            email="existing@example.com",
            name="Existing",
        )

        result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)

        assert created is False
        assert result.id == existing.id

    async def test_links_existing_user_by_email(self, db_session):
        """A login via a new provider with a known email reuses that account.

        The user first exists with Google. A Discord login with the same email
        must return the same user (not create one) and switch its OAuth
        identity to Discord.
        """
        existing = User(
            email="link@example.com",
            display_name="Link Me",
            oauth_provider="google",
            oauth_id="google-link-id",
        )
        db_session.add(existing)
        await db_session.commit()

        # Discord login sharing the Google account's email.
        oauth_info = OAuthUserInfo(
            provider="discord",
            oauth_id="discord-link-id",
            email="link@example.com",
            name="Link Me",
            avatar_url="https://discord.com/avatar.jpg",
        )

        result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)

        assert created is False
        assert result.id == existing.id
        # OAuth identity should now point at Discord.
        assert result.oauth_provider == "discord"
        assert result.oauth_id == "discord-link-id"

    async def test_creates_new_user_when_not_found(self, db_session):
        """With no OAuth or email match, a new user is created.

        Verifies the method returns ``(user, True)`` and that the new user
        carries the OAuth identity and profile data it was built from.
        """
        oauth_info = OAuthUserInfo(
            provider="discord",
            oauth_id="brand-new-id",
            email="brandnew@example.com",
            name="Brand New",
        )

        result, created = await user_service.get_or_create_from_oauth(db_session, oauth_info)

        assert created is True
        assert result.id is not None
        assert result.email == "brandnew@example.com"
        # Pin the OAuth identity and name mapping as well, not just the email,
        # so a partially populated user cannot pass this test.
        assert result.display_name == "Brand New"
        assert result.oauth_provider == "discord"
        assert result.oauth_id == "brand-new-id"


class TestUpdate:
    """Exercise ``user_service.update`` with partial ``UserUpdate`` payloads."""

    @staticmethod
    async def _seed_user(session, **fields):
        """Build a User from *fields*, add it to *session* and commit."""
        account = User(**fields)
        session.add(account)
        await session.commit()
        return account

    async def test_updates_display_name(self, db_session):
        """A supplied display_name replaces the stored one."""
        account = await self._seed_user(
            db_session,
            email="update@example.com",
            display_name="Old Name",
            oauth_provider="google",
            oauth_id="update-id",
        )

        patched = await user_service.update(
            db_session, account, UserUpdate(display_name="New Name")
        )

        assert patched.display_name == "New Name"

    async def test_updates_avatar_url(self, db_session):
        """A supplied avatar_url replaces the stored one."""
        account = await self._seed_user(
            db_session,
            email="avatar@example.com",
            display_name="Avatar User",
            oauth_provider="google",
            oauth_id="avatar-id",
        )

        patched = await user_service.update(
            db_session, account, UserUpdate(avatar_url="https://new-avatar.com/img.jpg")
        )

        assert patched.avatar_url == "https://new-avatar.com/img.jpg"

    async def test_ignores_none_values(self, db_session):
        """Fields omitted from the payload (left as None) keep their stored values.

        Only the fields actually supplied in the update should change.
        """
        account = await self._seed_user(
            db_session,
            email="keep@example.com",
            display_name="Keep Me",
            avatar_url="https://keep.com/avatar.jpg",
            oauth_provider="google",
            oauth_id="keep-id",
        )

        # Supply only display_name; avatar_url is not part of the payload.
        patched = await user_service.update(
            db_session, account, UserUpdate(display_name="Changed")
        )

        assert patched.display_name == "Changed"
        assert patched.avatar_url == "https://keep.com/avatar.jpg"


class TestUpdateLastLogin:
    """Exercise ``user_service.update_last_login``."""

    async def test_updates_last_login_timestamp(self, db_session):
        """update_last_login stamps the user with the current UTC time."""
        visitor = User(
            email="login@example.com",
            display_name="Login User",
            oauth_provider="google",
            oauth_id="login-id",
        )
        db_session.add(visitor)
        await db_session.commit()

        # A freshly created user has never logged in.
        assert visitor.last_login is None

        window_start = datetime.now(UTC)
        stamped = await user_service.update_last_login(db_session, visitor)
        window_end = datetime.now(UTC)

        assert stamped.last_login is not None
        # Allow one second of slack on each side of the call window.
        slack = timedelta(seconds=1)
        assert window_start - slack <= stamped.last_login <= window_end + slack


class TestUpdatePremium:
    """Exercise ``user_service.update_premium`` for granting and revoking."""

    async def test_grants_premium(self, db_session):
        """Passing an expiry enables premium and stores that expiry."""
        member = User(
            email="premium@example.com",
            display_name="Premium User",
            oauth_provider="google",
            oauth_id="premium-id",
        )
        db_session.add(member)
        await db_session.commit()

        # New users start without premium.
        assert member.is_premium is False

        thirty_days_out = datetime.now(UTC) + timedelta(days=30)
        upgraded = await user_service.update_premium(db_session, member, thirty_days_out)

        assert upgraded.is_premium is True
        assert upgraded.premium_until == thirty_days_out

    async def test_removes_premium(self, db_session):
        """Passing None as the expiry revokes premium and clears the expiry."""
        member = User(
            email="unpremium@example.com",
            display_name="Unpremium User",
            oauth_provider="google",
            oauth_id="unpremium-id",
            is_premium=True,
            premium_until=datetime.now(UTC) + timedelta(days=30),
        )
        db_session.add(member)
        await db_session.commit()

        downgraded = await user_service.update_premium(db_session, member, None)

        assert downgraded.is_premium is False
        assert downgraded.premium_until is None


class TestDelete:
    """Exercise ``user_service.delete``."""

    async def test_deletes_user(self, db_session):
        """After delete, the user can no longer be looked up by ID."""
        from uuid import UUID

        doomed = User(
            email="delete@example.com",
            display_name="Delete Me",
            oauth_provider="google",
            oauth_id="delete-id",
        )
        db_session.add(doomed)
        await db_session.commit()

        # Read the ID before deleting so the deleted instance is not touched afterwards.
        doomed_id = doomed.id
        await user_service.delete(db_session, doomed)

        # The ID may be a str or a UUID; normalize before looking it up again.
        if isinstance(doomed_id, str):
            doomed_id = UUID(doomed_id)
        leftover = await user_service.get_by_id(db_session, doomed_id)
        assert leftover is None