mantimon-tcg/backend/tests/services/test_jwt_service.py
Cal Corum 996c43fbd9 Implement Phase 2: Authentication system
Complete OAuth-based authentication with JWT session management:

Core Services:
- JWT service for access/refresh token creation and verification
- Token store with Redis-backed refresh token revocation
- User service for CRUD operations and OAuth-based creation
- Google and Discord OAuth services with full flow support

API Endpoints:
- GET /api/auth/{google,discord} - Start OAuth flows
- GET /api/auth/{google,discord}/callback - Handle OAuth callbacks
- POST /api/auth/refresh - Exchange refresh token for new access token
- POST /api/auth/logout - Revoke single refresh token
- POST /api/auth/logout-all - Revoke all user sessions
- GET/PATCH /api/users/me - User profile management
- GET /api/users/me/linked-accounts - List OAuth providers
- GET /api/users/me/sessions - Count active sessions

Infrastructure:
- Pydantic schemas for auth/user request/response models
- FastAPI dependencies (get_current_user, get_current_premium_user)
- OAuthLinkedAccount model for multi-provider support
- Alembic migration for oauth_linked_accounts table

Dependencies added: email-validator, fakeredis (dev), respx (dev)

84 new tests, 1058 total passing
2026-01-27 21:49:59 -06:00

371 lines
12 KiB
Python

"""Tests for JWT service.
Tests the JWT token creation and verification functions used for
authentication throughout the application.
"""
import uuid
from datetime import UTC, datetime, timedelta
from jose import jwt
from app.config import settings
from app.schemas.auth import TokenType
from app.services.jwt_service import (
create_access_token,
create_refresh_token,
decode_token,
get_refresh_token_expiration,
get_token_expiration_seconds,
verify_access_token,
verify_refresh_token,
)
class TestCreateAccessToken:
    """Tests for create_access_token function."""

    @staticmethod
    def _decode_raw(encoded):
        """Decode ``encoded`` with python-jose using the app's key and algorithm.

        Deliberately bypasses the service's own decode helpers so these tests
        inspect the raw claims independently of the code under test.
        """
        return jwt.decode(
            encoded,
            settings.secret_key.get_secret_value(),
            algorithms=[settings.jwt_algorithm],
        )

    def test_creates_valid_jwt(self):
        """create_access_token should produce a decodable compact JWT.

        The string must have three dot-separated segments and carry the
        user's ID as ``sub`` and the access token type as ``type``.
        """
        subject = uuid.uuid4()
        encoded = create_access_token(subject)

        # Compact serialization is header.payload.signature -> three parts.
        assert isinstance(encoded, str)
        assert len(encoded.split(".")) == 3

        claims = self._decode_raw(encoded)
        assert claims["sub"] == str(subject)
        assert claims["type"] == TokenType.ACCESS.value

    def test_sets_correct_expiration(self):
        """The ``exp`` claim should sit jwt_expire_minutes after creation.

        JWT NumericDate values are whole seconds, so the accepted window is
        widened by one second on each side.
        """
        lifetime = timedelta(minutes=settings.jwt_expire_minutes)
        slack = timedelta(seconds=1)
        subject = uuid.uuid4()

        window_start = datetime.now(UTC)
        encoded = create_access_token(subject)
        window_end = datetime.now(UTC)

        expires_at = datetime.fromtimestamp(self._decode_raw(encoded)["exp"], tz=UTC)
        earliest = window_start + lifetime - slack
        latest = window_end + lifetime + slack
        assert earliest <= expires_at <= latest

    def test_includes_issued_at(self):
        """The ``iat`` claim should record (roughly) the moment of creation.

        JWT NumericDate values are whole seconds, so the accepted window is
        widened by one second on each side.
        """
        slack = timedelta(seconds=1)
        subject = uuid.uuid4()

        window_start = datetime.now(UTC)
        encoded = create_access_token(subject)
        window_end = datetime.now(UTC)

        issued_at = datetime.fromtimestamp(self._decode_raw(encoded)["iat"], tz=UTC)
        assert window_start - slack <= issued_at <= window_end + slack
class TestCreateRefreshToken:
    """Tests for create_refresh_token function."""

    @staticmethod
    def _decode_raw(token):
        """Decode ``token`` with python-jose using the app's key and algorithm.

        Deliberately bypasses the service's own decode helpers so these tests
        inspect the raw claims independently of the code under test.
        """
        return jwt.decode(
            token,
            settings.secret_key.get_secret_value(),
            algorithms=[settings.jwt_algorithm],
        )

    def test_creates_valid_jwt_with_jti(self):
        """Test that create_refresh_token returns a valid JWT and JTI.

        The function should return a tuple of (token, jti). The token must
        carry the user's ID as ``sub``, the refresh token type, and the same
        JTI that was returned, so the token can be tracked for revocation.
        """
        user_id = uuid.uuid4()
        token, jti = create_refresh_token(user_id)

        # Should return token and jti
        assert isinstance(token, str)
        assert isinstance(jti, str)
        assert token.count(".") == 2

        # JTI should be a valid UUID
        uuid.UUID(jti)  # Raises ValueError if malformed

        payload = self._decode_raw(token)
        # Subject must identify the user, mirroring the access-token test;
        # without this a refresh token minted for the wrong user would pass.
        assert payload["sub"] == str(user_id)
        assert payload["jti"] == jti
        assert payload["type"] == TokenType.REFRESH.value

    def test_sets_correct_expiration(self):
        """Test that refresh token expiration matches configured setting.

        The token should expire approximately jwt_refresh_expire_days from
        now. JWT timestamps have second precision, so we allow a 1 second
        tolerance on each side.
        """
        lifetime = timedelta(days=settings.jwt_refresh_expire_days)
        tolerance = timedelta(seconds=1)
        user_id = uuid.uuid4()

        before = datetime.now(UTC)
        token, _ = create_refresh_token(user_id)
        after = datetime.now(UTC)

        exp = datetime.fromtimestamp(self._decode_raw(token)["exp"], tz=UTC)
        expected_min = before + lifetime - tolerance
        expected_max = after + lifetime + tolerance
        assert expected_min <= exp <= expected_max

    def test_generates_unique_jti(self):
        """Test that each refresh token gets a unique JTI.

        Multiple calls should generate different JTIs to ensure each token
        can be individually revoked.
        """
        user_id = uuid.uuid4()
        _, jti1 = create_refresh_token(user_id)
        _, jti2 = create_refresh_token(user_id)
        assert jti1 != jti2
class TestDecodeToken:
    """Tests for decode_token function."""

    def test_decodes_valid_token(self):
        """Test that decode_token returns TokenPayload for valid tokens.

        A valid token should be decoded into a TokenPayload with all expected
        fields populated.
        """
        user_id = uuid.uuid4()
        token = create_access_token(user_id)

        payload = decode_token(token)

        assert payload is not None
        assert payload.sub == str(user_id)
        assert payload.type == TokenType.ACCESS
        assert payload.exp is not None
        assert payload.iat is not None

    def test_returns_none_for_invalid_token(self):
        """Test that decode_token returns None for malformed tokens.

        Invalid JWT strings should not raise exceptions but return None.
        """
        result = decode_token("invalid.token.here")
        assert result is None

    def test_returns_none_for_wrong_signature(self):
        """Test that decode_token returns None for tokens with wrong signature.

        The token is well-formed, unexpired, and signed with the configured
        algorithm; only the key differs, so rejection must come from the
        signature check.
        """
        now = datetime.now(UTC)
        payload = {
            "sub": str(uuid.uuid4()),
            "exp": now + timedelta(hours=1),
            "iat": now,
            "type": "access",
        }
        # Use the configured algorithm rather than a hard-coded "HS256": if
        # jwt_algorithm ever changed, an algorithm mismatch would reject the
        # token on its own and hide a missing signature check.
        # NOTE(review): assumes jwt_algorithm is an HMAC algorithm, consistent
        # with settings.secret_key being used for both encode and decode here.
        token = jwt.encode(payload, "wrong-secret", algorithm=settings.jwt_algorithm)

        result = decode_token(token)
        assert result is None

    def test_returns_none_for_expired_token(self):
        """Test that decode_token returns None for expired tokens.

        Tokens past their expiration should be rejected even when correctly
        signed.
        """
        now = datetime.now(UTC)
        payload = {
            "sub": str(uuid.uuid4()),
            "exp": now - timedelta(hours=1),  # Already expired
            "iat": now - timedelta(hours=2),
            "type": "access",
        }
        token = jwt.encode(
            payload,
            settings.secret_key.get_secret_value(),
            algorithm=settings.jwt_algorithm,
        )

        result = decode_token(token)
        assert result is None
class TestVerifyAccessToken:
    """Tests for verify_access_token function."""

    def test_returns_user_id_for_valid_access_token(self):
        """A well-formed access token should verify to the user's UUID."""
        expected_user = uuid.uuid4()
        assert verify_access_token(create_access_token(expected_user)) == expected_user

    def test_returns_none_for_refresh_token(self):
        """Refresh tokens must not be accepted where an access token is expected.

        Rejecting the other token type, even when its signature is valid,
        guards against token-type confusion.
        """
        refresh, _jti = create_refresh_token(uuid.uuid4())
        assert verify_access_token(refresh) is None

    def test_returns_none_for_invalid_token(self):
        """Test that verify_access_token returns None for invalid tokens."""
        assert verify_access_token("invalid.token.here") is None

    def test_returns_none_for_invalid_uuid_subject(self):
        """A correctly signed token whose ``sub`` is not a UUID should be rejected."""
        issued = datetime.now(UTC)
        claims = {
            "sub": "not-a-uuid",
            "exp": issued + timedelta(hours=1),
            "iat": issued,
            "type": "access",
        }
        signed = jwt.encode(
            claims,
            settings.secret_key.get_secret_value(),
            algorithm=settings.jwt_algorithm,
        )
        assert verify_access_token(signed) is None
class TestVerifyRefreshToken:
    """Tests for verify_refresh_token function."""

    def test_returns_user_id_and_jti_for_valid_refresh_token(self):
        """A valid refresh token should yield both the user ID and its JTI.

        Both values are needed by callers to check revocation status.
        """
        owner = uuid.uuid4()
        refresh, issued_jti = create_refresh_token(owner)

        verified = verify_refresh_token(refresh)
        assert verified is not None

        got_owner, got_jti = verified
        assert (got_owner, got_jti) == (owner, issued_jti)

    def test_returns_none_for_access_token(self):
        """Access tokens must not verify as refresh tokens, even when valid."""
        access = create_access_token(uuid.uuid4())
        assert verify_refresh_token(access) is None

    def test_returns_none_for_token_without_jti(self):
        """A refresh-typed token lacking a ``jti`` claim should be rejected.

        Without a JTI the token could never be individually revoked.
        """
        issued = datetime.now(UTC)
        claims = {
            "sub": str(uuid.uuid4()),
            "exp": issued + timedelta(days=7),
            "iat": issued,
            "type": "refresh",
            # Intentionally no "jti" claim.
        }
        signed = jwt.encode(
            claims,
            settings.secret_key.get_secret_value(),
            algorithm=settings.jwt_algorithm,
        )
        assert verify_refresh_token(signed) is None

    def test_returns_none_for_invalid_token(self):
        """Test that verify_refresh_token returns None for invalid tokens."""
        assert verify_refresh_token("invalid.token.here") is None
class TestHelperFunctions:
    """Tests for helper functions."""

    def test_get_token_expiration_seconds(self):
        """The access-token lifetime should be reported in seconds.

        This is simply jwt_expire_minutes scaled by 60.
        """
        seconds = get_token_expiration_seconds()
        assert seconds == settings.jwt_expire_minutes * 60

    def test_get_refresh_token_expiration(self):
        """The refresh expiry should be jwt_refresh_expire_days after "now".

        The returned datetime must fall between the two bracketing
        timestamps, each shifted forward by the configured number of days.
        """
        window_start = datetime.now(UTC)
        expires_at = get_refresh_token_expiration()
        window_end = datetime.now(UTC)

        horizon = timedelta(days=settings.jwt_refresh_expire_days)
        assert window_start + horizon <= expires_at <= window_end + horizon