"""Tests for JWT service. Tests the JWT token creation and verification functions used for authentication throughout the application. """ import uuid from datetime import UTC, datetime, timedelta from jose import jwt from app.config import settings from app.schemas.auth import TokenType from app.services.jwt_service import ( create_access_token, create_refresh_token, decode_token, get_refresh_token_expiration, get_token_expiration_seconds, verify_access_token, verify_refresh_token, ) class TestCreateAccessToken: """Tests for create_access_token function.""" def test_creates_valid_jwt(self): """Test that create_access_token returns a valid JWT string. The returned token should be decodable and contain the expected claims including subject, expiration, and token type. """ user_id = uuid.uuid4() token = create_access_token(user_id) # Should be a valid JWT (three dot-separated parts) assert isinstance(token, str) assert token.count(".") == 2 # Should be decodable payload = jwt.decode( token, settings.secret_key.get_secret_value(), algorithms=[settings.jwt_algorithm], ) assert payload["sub"] == str(user_id) assert payload["type"] == TokenType.ACCESS.value def test_sets_correct_expiration(self): """Test that access token expiration matches configured setting. The token should expire approximately jwt_expire_minutes from now. JWT timestamps have second precision, so we allow 1 second tolerance. """ user_id = uuid.uuid4() before = datetime.now(UTC) token = create_access_token(user_id) after = datetime.now(UTC) payload = jwt.decode( token, settings.secret_key.get_secret_value(), algorithms=[settings.jwt_algorithm], ) exp = datetime.fromtimestamp(payload["exp"], tz=UTC) expected_min = ( before + timedelta(minutes=settings.jwt_expire_minutes) - timedelta(seconds=1) ) expected_max = after + timedelta(minutes=settings.jwt_expire_minutes) + timedelta(seconds=1) assert expected_min <= exp <= expected_max def test_includes_issued_at(self): """Test that access token includes iat (issued at) claim. The iat claim should be set to approximately the current time. JWT timestamps have second precision, so we allow 1 second tolerance. """ user_id = uuid.uuid4() before = datetime.now(UTC) token = create_access_token(user_id) after = datetime.now(UTC) payload = jwt.decode( token, settings.secret_key.get_secret_value(), algorithms=[settings.jwt_algorithm], ) iat = datetime.fromtimestamp(payload["iat"], tz=UTC) assert before - timedelta(seconds=1) <= iat <= after + timedelta(seconds=1) class TestCreateRefreshToken: """Tests for create_refresh_token function.""" def test_creates_valid_jwt_with_jti(self): """Test that create_refresh_token returns a valid JWT and JTI. The function should return a tuple of (token, jti) where the token contains the JTI for revocation tracking. """ user_id = uuid.uuid4() token, jti = create_refresh_token(user_id) # Should return token and jti assert isinstance(token, str) assert isinstance(jti, str) assert token.count(".") == 2 # JTI should be a valid UUID uuid.UUID(jti) # Will raise if invalid # Token should contain the JTI payload = jwt.decode( token, settings.secret_key.get_secret_value(), algorithms=[settings.jwt_algorithm], ) assert payload["jti"] == jti assert payload["type"] == TokenType.REFRESH.value def test_sets_correct_expiration(self): """Test that refresh token expiration matches configured setting. The token should expire approximately jwt_refresh_expire_days from now. JWT timestamps have second precision, so we allow 1 second tolerance. """ user_id = uuid.uuid4() before = datetime.now(UTC) token, _ = create_refresh_token(user_id) after = datetime.now(UTC) payload = jwt.decode( token, settings.secret_key.get_secret_value(), algorithms=[settings.jwt_algorithm], ) exp = datetime.fromtimestamp(payload["exp"], tz=UTC) expected_min = ( before + timedelta(days=settings.jwt_refresh_expire_days) - timedelta(seconds=1) ) expected_max = ( after + timedelta(days=settings.jwt_refresh_expire_days) + timedelta(seconds=1) ) assert expected_min <= exp <= expected_max def test_generates_unique_jti(self): """Test that each refresh token gets a unique JTI. Multiple calls should generate different JTIs to ensure each token can be individually revoked. """ user_id = uuid.uuid4() _, jti1 = create_refresh_token(user_id) _, jti2 = create_refresh_token(user_id) assert jti1 != jti2 class TestDecodeToken: """Tests for decode_token function.""" def test_decodes_valid_token(self): """Test that decode_token returns TokenPayload for valid tokens. A valid token should be decoded into a TokenPayload with all expected fields populated. """ user_id = uuid.uuid4() token = create_access_token(user_id) payload = decode_token(token) assert payload is not None assert payload.sub == str(user_id) assert payload.type == TokenType.ACCESS assert payload.exp is not None assert payload.iat is not None def test_returns_none_for_invalid_token(self): """Test that decode_token returns None for malformed tokens. Invalid JWT strings should not raise exceptions but return None. """ result = decode_token("invalid.token.here") assert result is None def test_returns_none_for_wrong_signature(self): """Test that decode_token returns None for tokens with wrong signature. Tokens signed with a different key should be rejected. """ user_id = uuid.uuid4() # Create token with different secret payload = { "sub": str(user_id), "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC), "type": "access", } token = jwt.encode(payload, "wrong-secret", algorithm="HS256") result = decode_token(token) assert result is None def test_returns_none_for_expired_token(self): """Test that decode_token returns None for expired tokens. Tokens past their expiration should be rejected. """ user_id = uuid.uuid4() payload = { "sub": str(user_id), "exp": datetime.now(UTC) - timedelta(hours=1), # Already expired "iat": datetime.now(UTC) - timedelta(hours=2), "type": "access", } token = jwt.encode( payload, settings.secret_key.get_secret_value(), algorithm=settings.jwt_algorithm, ) result = decode_token(token) assert result is None class TestVerifyAccessToken: """Tests for verify_access_token function.""" def test_returns_user_id_for_valid_access_token(self): """Test that verify_access_token returns user ID for valid tokens. A valid access token should return the UUID of the user. """ user_id = uuid.uuid4() token = create_access_token(user_id) result = verify_access_token(token) assert result == user_id def test_returns_none_for_refresh_token(self): """Test that verify_access_token rejects refresh tokens. Even valid refresh tokens should be rejected when verifying as access tokens to prevent token type confusion. """ user_id = uuid.uuid4() token, _ = create_refresh_token(user_id) result = verify_access_token(token) assert result is None def test_returns_none_for_invalid_token(self): """Test that verify_access_token returns None for invalid tokens.""" result = verify_access_token("invalid.token.here") assert result is None def test_returns_none_for_invalid_uuid_subject(self): """Test that verify_access_token returns None for non-UUID subject. If the subject claim is not a valid UUID, the token should be rejected. """ payload = { "sub": "not-a-uuid", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC), "type": "access", } token = jwt.encode( payload, settings.secret_key.get_secret_value(), algorithm=settings.jwt_algorithm, ) result = verify_access_token(token) assert result is None class TestVerifyRefreshToken: """Tests for verify_refresh_token function.""" def test_returns_user_id_and_jti_for_valid_refresh_token(self): """Test that verify_refresh_token returns user ID and JTI. A valid refresh token should return both values needed for revocation checking. """ user_id = uuid.uuid4() token, jti = create_refresh_token(user_id) result = verify_refresh_token(token) assert result is not None result_user_id, result_jti = result assert result_user_id == user_id assert result_jti == jti def test_returns_none_for_access_token(self): """Test that verify_refresh_token rejects access tokens. Even valid access tokens should be rejected when verifying as refresh tokens. """ user_id = uuid.uuid4() token = create_access_token(user_id) result = verify_refresh_token(token) assert result is None def test_returns_none_for_token_without_jti(self): """Test that verify_refresh_token rejects tokens missing JTI. Refresh tokens must have a JTI for revocation tracking. """ payload = { "sub": str(uuid.uuid4()), "exp": datetime.now(UTC) + timedelta(days=7), "iat": datetime.now(UTC), "type": "refresh", # No jti } token = jwt.encode( payload, settings.secret_key.get_secret_value(), algorithm=settings.jwt_algorithm, ) result = verify_refresh_token(token) assert result is None def test_returns_none_for_invalid_token(self): """Test that verify_refresh_token returns None for invalid tokens.""" result = verify_refresh_token("invalid.token.here") assert result is None class TestHelperFunctions: """Tests for helper functions.""" def test_get_token_expiration_seconds(self): """Test that get_token_expiration_seconds returns correct value. Should return jwt_expire_minutes converted to seconds. """ result = get_token_expiration_seconds() assert result == settings.jwt_expire_minutes * 60 def test_get_refresh_token_expiration(self): """Test that get_refresh_token_expiration returns future datetime. Should return a datetime approximately jwt_refresh_expire_days in the future. """ before = datetime.now(UTC) result = get_refresh_token_expiration() after = datetime.now(UTC) expected_min = before + timedelta(days=settings.jwt_refresh_expire_days) expected_max = after + timedelta(days=settings.jwt_refresh_expire_days) assert expected_min <= result <= expected_max