mantimon-tcg/backend/app/socketio/auth.py
Cal Corum f512c7b2b3 Refactor to dependency injection pattern - no monkey patching
- ConnectionManager: Add redis_factory constructor parameter
- GameService: Add engine_factory constructor parameter
- AuthHandler: New class replacing standalone functions with
  token_verifier and conn_manager injection
- Update all tests to use constructor DI instead of patch()
- Update CLAUDE.md with factory injection patterns
- Update services README with new patterns
- Add socketio README documenting AuthHandler and events

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 22:54:57 -06:00

290 lines
8.7 KiB
Python

"""Socket.IO authentication middleware for WebSocket connections.

This module provides JWT-based authentication for Socket.IO connections.
It validates access tokens and attaches user information to the socket session.

Authentication Flow:
    1. Client connects with `auth: { token: "JWT_ACCESS_TOKEN" }`
    2. Server extracts and validates the JWT
    3. If valid, user_id is stored in socket session
    4. If invalid, connection is rejected with appropriate error

Session Data:
    After successful authentication, the socket session contains:
    - user_id: str (UUID as string)
    - authenticated_at: str (ISO timestamp)

Example:
    # In connect handler:
    auth_result = await auth_handler.authenticate_connection(sid, auth)
    if not auth_result.success:
        return False  # Reject connection

    # Later, get user_id:
    session = await sio.get_session(sid, namespace="/game")
    user_id = session.get("user_id")
"""
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from uuid import UUID
from app.services.connection_manager import ConnectionManager, connection_manager
from app.services.jwt_service import verify_access_token
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

# Type alias for the token verifier: a callable that takes the raw token
# string and returns the user's UUID, or None when it yields no user_id.
TokenVerifier = Callable[[str], UUID | None]
@dataclass
class AuthResult:
"""Result of authentication attempt.
Attributes:
success: Whether authentication succeeded.
user_id: User's UUID if successful, None otherwise.
error_code: Error code for client if failed.
error_message: Human-readable error message if failed.
"""
success: bool
user_id: UUID | None = None
error_code: str | None = None
error_message: str | None = None
class AuthHandler:
    """Handler for Socket.IO authentication with injectable dependencies.

    This class encapsulates authentication logic with dependencies injected
    via constructor, making it testable without monkey patching.

    Attributes:
        _token_verifier: Function to verify JWT tokens.
        _connection_manager: ConnectionManager for tracking connections.
    """

    def __init__(
        self,
        token_verifier: TokenVerifier | None = None,
        conn_manager: ConnectionManager | None = None,
    ) -> None:
        """Initialize AuthHandler with dependencies.

        Args:
            token_verifier: Function to verify tokens. Uses jwt_service's
                verify_access_token if not provided.
            conn_manager: ConnectionManager instance. Uses the global
                connection_manager if not provided.
        """
        # Compare against None explicitly rather than using `or`: an injected
        # dependency that happens to be falsy (for example a test double that
        # defines __bool__ or __len__) must not be silently replaced by the
        # module-level default.
        self._token_verifier = (
            token_verifier if token_verifier is not None else verify_access_token
        )
        self._connection_manager = (
            conn_manager if conn_manager is not None else connection_manager
        )

    async def authenticate_connection(
        self,
        sid: str,
        auth: dict[str, object] | None,
    ) -> AuthResult:
        """Authenticate a Socket.IO connection using JWT.

        Extracts the JWT from auth data, validates it, and returns the result.
        Does NOT modify socket session - caller should handle that.

        Args:
            sid: Socket session ID (for logging).
            auth: Authentication data from connect event.

        Returns:
            AuthResult with success status and user_id or error details.
            Failures use error_code "missing_token" when no token could be
            extracted and "invalid_token" when the verifier returned None.
        """
        token = extract_token(auth)
        if token is None:
            logger.debug(f"Connection {sid}: No token provided")
            return AuthResult(
                success=False,
                error_code="missing_token",
                error_message="Authentication token required",
            )

        # The verifier signals rejection by returning None.
        user_id = self._token_verifier(token)
        if user_id is None:
            logger.debug(f"Connection {sid}: Invalid or expired token")
            return AuthResult(
                success=False,
                error_code="invalid_token",
                error_message="Invalid or expired token",
            )

        logger.debug(f"Connection {sid}: Authenticated as user {user_id}")
        return AuthResult(
            success=True,
            user_id=user_id,
        )

    async def setup_authenticated_session(
        self,
        sio: object,
        sid: str,
        user_id: UUID,
        namespace: str = "/game",
    ) -> None:
        """Set up socket session with authenticated user data.

        Saves user_id and authentication timestamp to the socket session,
        and registers the connection with ConnectionManager.

        Args:
            sio: Socket.IO AsyncServer instance (required).
            sid: Socket session ID.
            user_id: Authenticated user's UUID.
            namespace: Socket.IO namespace.
        """
        # The session stores the UUID as a string and a timezone-aware UTC
        # timestamp in ISO format.
        session_data = {
            "user_id": str(user_id),
            "authenticated_at": datetime.now(UTC).isoformat(),
        }
        await sio.save_session(sid, session_data, namespace=namespace)

        # The connection manager receives the UUID itself, not its string form.
        await self._connection_manager.register_connection(sid, user_id)
        logger.info(f"Session established: sid={sid}, user_id={user_id}")

    async def cleanup_authenticated_session(
        self,
        sid: str,
        namespace: str = "/game",
    ) -> str | None:
        """Clean up session data on disconnect.

        Unregisters the connection from ConnectionManager and returns
        the user_id for any additional cleanup needed.

        Args:
            sid: Socket session ID.
            namespace: Socket.IO namespace. Not read by this method.

        Returns:
            user_id if session was authenticated, None otherwise.
        """
        conn_info = await self._connection_manager.unregister_connection(sid)
        if conn_info:
            logger.info(f"Session cleaned up: sid={sid}, user_id={conn_info.user_id}")
            # NOTE(review): the annotation promises str, but the type of
            # conn_info.user_id is decided by ConnectionManager - confirm.
            return conn_info.user_id

        logger.debug(f"No session to clean up for {sid}")
        return None
# Global singleton instance, built with no arguments so it uses the default
# token verifier and connection manager selected in AuthHandler.__init__.
auth_handler = AuthHandler()
def extract_token(auth: dict[str, object] | None) -> str | None:
"""Extract JWT token from Socket.IO auth data.
Clients should send the token in the auth dict:
socket.connect({ auth: { token: "JWT_TOKEN" } })
Also supports:
- auth.authorization: "Bearer TOKEN"
- auth.access_token: "TOKEN"
Args:
auth: Authentication data from Socket.IO connect.
Returns:
JWT token string if found, None otherwise.
Example:
token = extract_token({"token": "eyJ..."})
token = extract_token({"authorization": "Bearer eyJ..."})
"""
if auth is None:
return None
# Primary: auth.token
token = auth.get("token")
if token and isinstance(token, str):
return token
# Alternative: auth.authorization (Bearer token)
authorization = auth.get("authorization")
if authorization and isinstance(authorization, str):
if authorization.lower().startswith("bearer "):
return authorization[7:]
return authorization
# Alternative: auth.access_token
access_token = auth.get("access_token")
if access_token and isinstance(access_token, str):
return access_token
return None
async def get_session_user_id(
sio: object,
sid: str,
namespace: str = "/game",
) -> str | None:
"""Get the authenticated user_id from a socket session.
Convenience function to extract user_id from session data.
Args:
sio: Socket.IO AsyncServer instance.
sid: Socket session ID.
namespace: Socket.IO namespace.
Returns:
user_id string if authenticated, None otherwise.
Example:
user_id = await get_session_user_id(sio, sid)
if not user_id:
await sio.emit("error", {"message": "Not authenticated"}, to=sid)
return
"""
try:
session = await sio.get_session(sid, namespace=namespace)
return session.get("user_id") if session else None
except Exception:
return None
async def require_auth(
    sio: object,
    sid: str,
    namespace: str = "/game",
) -> str | None:
    """Look up the caller's user_id for an event handler that needs auth.

    Delegates to get_session_user_id and emits a warning log entry when the
    lookup yields None, so unauthenticated events leave a trace.

    Args:
        sio: Socket.IO AsyncServer instance.
        sid: Socket session ID.
        namespace: Socket.IO namespace.

    Returns:
        user_id string if authenticated, None otherwise.

    Example:
        @sio.on("game:action", namespace="/game")
        async def on_action(sid, data):
            user_id = await require_auth(sio, sid)
            if not user_id:
                return {"error": "Not authenticated"}
            # ... handle action
    """
    session_user = await get_session_user_id(sio, sid, namespace)

    # Happy path first: hand back whatever the session holds.
    if session_user is not None:
        return session_user

    logger.warning(f"Unauthenticated event from {sid}")
    return None