Complete OAuth-based authentication with JWT session management:
Core Services:
- JWT service for access/refresh token creation and verification
- Token store with Redis-backed refresh token revocation
- User service for CRUD operations and OAuth-based creation
- Google and Discord OAuth services with full flow support
API Endpoints:
- GET /api/auth/{google,discord} - Start OAuth flows
- GET /api/auth/{google,discord}/callback - Handle OAuth callbacks
- POST /api/auth/refresh - Exchange refresh token for new access token
- POST /api/auth/logout - Revoke single refresh token
- POST /api/auth/logout-all - Revoke all user sessions
- GET/PATCH /api/users/me - User profile management
- GET /api/users/me/linked-accounts - List OAuth providers
- GET /api/users/me/sessions - Count active sessions
Infrastructure:
- Pydantic schemas for auth/user request/response models
- FastAPI dependencies (get_current_user, get_current_premium_user)
- OAuthLinkedAccount model for multi-provider support
- Alembic migration for oauth_linked_accounts table
Dependencies added: email-validator, fakeredis (dev), respx (dev)
84 new tests, 1058 total passing
371 lines · 12 KiB · Python
"""Tests for JWT service.
|
|
|
|
Tests the JWT token creation and verification functions used for
|
|
authentication throughout the application.
|
|
"""
|
|
|
|
import uuid
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
from jose import jwt
|
|
|
|
from app.config import settings
|
|
from app.schemas.auth import TokenType
|
|
from app.services.jwt_service import (
|
|
create_access_token,
|
|
create_refresh_token,
|
|
decode_token,
|
|
get_refresh_token_expiration,
|
|
get_token_expiration_seconds,
|
|
verify_access_token,
|
|
verify_refresh_token,
|
|
)
|
|
|
|
|
|
class TestCreateAccessToken:
    """Tests for create_access_token function."""

    @staticmethod
    def _decode(token):
        """Decode *token* using the application's secret key and algorithm."""
        return jwt.decode(
            token,
            settings.secret_key.get_secret_value(),
            algorithms=[settings.jwt_algorithm],
        )

    def test_creates_valid_jwt(self):
        """create_access_token should produce a decodable JWT string.

        The token must have the compact three-segment form and carry the
        user's ID as ``sub`` and the access token type as ``type``.
        """
        subject = uuid.uuid4()
        encoded = create_access_token(subject)

        # Compact serialization: header.payload.signature
        assert isinstance(encoded, str)
        assert encoded.count(".") == 2

        claims = self._decode(encoded)
        assert claims["sub"] == str(subject)
        assert claims["type"] == TokenType.ACCESS.value

    def test_sets_correct_expiration(self):
        """The ``exp`` claim should be jwt_expire_minutes past issuance.

        JWT timestamps are whole seconds, so the accepted window is widened
        by one second on each side.
        """
        lifetime = timedelta(minutes=settings.jwt_expire_minutes)
        slack = timedelta(seconds=1)
        subject = uuid.uuid4()

        start = datetime.now(UTC)
        encoded = create_access_token(subject)
        end = datetime.now(UTC)

        expires_at = datetime.fromtimestamp(self._decode(encoded)["exp"], tz=UTC)

        assert start + lifetime - slack <= expires_at <= end + lifetime + slack

    def test_includes_issued_at(self):
        """The ``iat`` claim should record (approximately) the issue time.

        JWT timestamps are whole seconds, so one second of slack is allowed
        on either side of the measured window.
        """
        slack = timedelta(seconds=1)
        subject = uuid.uuid4()

        start = datetime.now(UTC)
        encoded = create_access_token(subject)
        end = datetime.now(UTC)

        issued_at = datetime.fromtimestamp(self._decode(encoded)["iat"], tz=UTC)

        assert start - slack <= issued_at <= end + slack
|
|
|
|
|
|
class TestCreateRefreshToken:
    """Tests for create_refresh_token function."""

    @staticmethod
    def _decode(token):
        """Decode *token* using the application's secret key and algorithm."""
        return jwt.decode(
            token,
            settings.secret_key.get_secret_value(),
            algorithms=[settings.jwt_algorithm],
        )

    def test_creates_valid_jwt_with_jti(self):
        """create_refresh_token should return a ``(token, jti)`` pair.

        The JTI must be a UUID string and must be embedded in the token so
        that individual refresh tokens can be revoked.
        """
        encoded, token_id = create_refresh_token(uuid.uuid4())

        assert isinstance(encoded, str)
        assert isinstance(token_id, str)
        assert encoded.count(".") == 2

        # uuid.UUID raises ValueError if the JTI is not a well-formed UUID.
        uuid.UUID(token_id)

        claims = self._decode(encoded)
        assert claims["jti"] == token_id
        assert claims["type"] == TokenType.REFRESH.value

    def test_sets_correct_expiration(self):
        """The ``exp`` claim should be jwt_refresh_expire_days past issuance.

        JWT timestamps are whole seconds, so the accepted window is widened
        by one second on each side.
        """
        lifetime = timedelta(days=settings.jwt_refresh_expire_days)
        slack = timedelta(seconds=1)
        subject = uuid.uuid4()

        start = datetime.now(UTC)
        encoded, _ = create_refresh_token(subject)
        end = datetime.now(UTC)

        expires_at = datetime.fromtimestamp(self._decode(encoded)["exp"], tz=UTC)

        assert start + lifetime - slack <= expires_at <= end + lifetime + slack

    def test_generates_unique_jti(self):
        """Each refresh token should carry its own JTI.

        Distinct JTIs are what allow tokens to be revoked one at a time.
        """
        subject = uuid.uuid4()
        first_jti = create_refresh_token(subject)[1]
        second_jti = create_refresh_token(subject)[1]

        assert first_jti != second_jti
|
|
|
|
|
|
class TestDecodeToken:
    """Tests for decode_token function.

    Negative cases start from the claims of a genuine access token and alter
    exactly one property (signing key or expiry). A rejection can therefore
    only be caused by the property under test, not by a hand-built claim set
    that differs from what the service actually issues.
    """

    @staticmethod
    def _genuine_access_claims():
        """Return the decoded claims of a freshly issued access token."""
        return jwt.decode(
            create_access_token(uuid.uuid4()),
            settings.secret_key.get_secret_value(),
            algorithms=[settings.jwt_algorithm],
        )

    def test_decodes_valid_token(self):
        """Test that decode_token returns TokenPayload for valid tokens.

        A valid token should be decoded into a TokenPayload with
        all expected fields populated.
        """
        user_id = uuid.uuid4()
        token = create_access_token(user_id)

        payload = decode_token(token)

        assert payload is not None
        assert payload.sub == str(user_id)
        assert payload.type == TokenType.ACCESS
        assert payload.exp is not None
        assert payload.iat is not None

    def test_returns_none_for_invalid_token(self):
        """Test that decode_token returns None for malformed tokens.

        Invalid JWT strings should not raise exceptions but return None.
        """
        result = decode_token("invalid.token.here")
        assert result is None

    def test_returns_none_for_wrong_signature(self):
        """Test that decode_token returns None for tokens with wrong signature.

        Tokens signed with a different key should be rejected. The forged
        token reuses genuine claims and the *configured* algorithm. A
        hard-coded algorithm could instead be rejected for an algorithm
        mismatch, letting the test pass without checking the signature.
        """
        claims = self._genuine_access_claims()
        # Appending to the real key guarantees the two keys differ.
        wrong_key = settings.secret_key.get_secret_value() + "-wrong"
        token = jwt.encode(claims, wrong_key, algorithm=settings.jwt_algorithm)

        result = decode_token(token)
        assert result is None

    def test_returns_none_for_expired_token(self):
        """Test that decode_token returns None for expired tokens.

        Tokens past their expiration should be rejected. Only the time
        claims differ from a genuine token, and the token is correctly
        signed, so expiry is the sole reason for rejection.
        """
        now = datetime.now(UTC)
        claims = {
            **self._genuine_access_claims(),
            "exp": now - timedelta(hours=1),  # Already expired
            "iat": now - timedelta(hours=2),
        }
        token = jwt.encode(
            claims,
            settings.secret_key.get_secret_value(),
            algorithm=settings.jwt_algorithm,
        )

        result = decode_token(token)
        assert result is None
|
|
|
|
|
|
class TestVerifyAccessToken:
    """Tests for verify_access_token function."""

    def test_returns_user_id_for_valid_access_token(self):
        """Test that verify_access_token returns user ID for valid tokens.

        A valid access token should return the UUID of the user.
        """
        user_id = uuid.uuid4()
        token = create_access_token(user_id)

        result = verify_access_token(token)

        assert result == user_id

    def test_returns_none_for_refresh_token(self):
        """Test that verify_access_token rejects refresh tokens.

        Even valid refresh tokens should be rejected when verifying
        as access tokens to prevent token type confusion.
        """
        user_id = uuid.uuid4()
        token, _ = create_refresh_token(user_id)

        result = verify_access_token(token)

        assert result is None

    def test_returns_none_for_invalid_token(self):
        """Test that verify_access_token returns None for invalid tokens."""
        result = verify_access_token("invalid.token.here")
        assert result is None

    def test_returns_none_for_invalid_uuid_subject(self):
        """Test that verify_access_token returns None for non-UUID subject.

        If the subject claim is not a valid UUID, the token should be rejected.
        The claims are copied from a genuine access token with only ``sub``
        replaced. A hand-built claim set, e.g. a literal ``type`` string,
        could be rejected for an unrelated reason and pass vacuously.
        """
        key = settings.secret_key.get_secret_value()
        genuine = jwt.decode(
            create_access_token(uuid.uuid4()),
            key,
            algorithms=[settings.jwt_algorithm],
        )
        token = jwt.encode(
            {**genuine, "sub": "not-a-uuid"},
            key,
            algorithm=settings.jwt_algorithm,
        )

        result = verify_access_token(token)
        assert result is None
|
|
|
|
|
|
class TestVerifyRefreshToken:
    """Tests for verify_refresh_token function."""

    def test_returns_user_id_and_jti_for_valid_refresh_token(self):
        """Test that verify_refresh_token returns user ID and JTI.

        A valid refresh token should return both values needed for
        revocation checking.
        """
        user_id = uuid.uuid4()
        token, jti = create_refresh_token(user_id)

        result = verify_refresh_token(token)

        assert result is not None
        result_user_id, result_jti = result
        assert result_user_id == user_id
        assert result_jti == jti

    def test_returns_none_for_access_token(self):
        """Test that verify_refresh_token rejects access tokens.

        Even valid access tokens should be rejected when verifying
        as refresh tokens.
        """
        user_id = uuid.uuid4()
        token = create_access_token(user_id)

        result = verify_refresh_token(token)

        assert result is None

    def test_returns_none_for_token_without_jti(self):
        """Test that verify_refresh_token rejects tokens missing JTI.

        Refresh tokens must have a JTI for revocation tracking. The claims
        are copied from a genuine refresh token with only ``jti`` removed, so
        the missing JTI is the sole reason for rejection. A hand-built claim
        set, e.g. a literal ``type`` string, could be rejected for another
        reason.
        """
        key = settings.secret_key.get_secret_value()
        claims = jwt.decode(
            create_refresh_token(uuid.uuid4())[0],
            key,
            algorithms=[settings.jwt_algorithm],
        )
        # KeyError here would mean the genuine token lacked a JTI altogether.
        del claims["jti"]
        token = jwt.encode(claims, key, algorithm=settings.jwt_algorithm)

        result = verify_refresh_token(token)
        assert result is None

    def test_returns_none_for_invalid_token(self):
        """Test that verify_refresh_token returns None for invalid tokens."""
        result = verify_refresh_token("invalid.token.here")
        assert result is None
|
|
|
|
|
|
class TestHelperFunctions:
    """Tests for helper functions."""

    def test_get_token_expiration_seconds(self):
        """get_token_expiration_seconds should equal the access lifetime.

        The configured jwt_expire_minutes, expressed in seconds.
        """
        expected_seconds = 60 * settings.jwt_expire_minutes
        assert get_token_expiration_seconds() == expected_seconds

    def test_get_refresh_token_expiration(self):
        """get_refresh_token_expiration should lie one refresh lifetime ahead.

        The result must fall between "now + jwt_refresh_expire_days" measured
        immediately before and immediately after the call.
        """
        lifetime = timedelta(days=settings.jwt_refresh_expire_days)

        start = datetime.now(UTC)
        expires_at = get_refresh_token_expiration()
        end = datetime.now(UTC)

        assert start + lifetime <= expires_at <= end + lifetime
|