- ConnectionManager: Add redis_factory constructor parameter
- GameService: Add engine_factory constructor parameter
- AuthHandler: New class replacing standalone functions, with token_verifier and conn_manager injection
- Update all tests to use constructor DI instead of patch()
- Update CLAUDE.md with factory injection patterns
- Update services README with new patterns
- Add socketio README documenting AuthHandler and events

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
290 lines
8.7 KiB
Python
290 lines
8.7 KiB
Python
"""Socket.IO authentication middleware for WebSocket connections.
|
|
|
|
This module provides JWT-based authentication for Socket.IO connections.
|
|
It validates access tokens and attaches user information to the socket session.
|
|
|
|
Authentication Flow:
|
|
1. Client connects with `auth: { token: "JWT_ACCESS_TOKEN" }`
|
|
2. Server extracts and validates the JWT
|
|
3. If valid, the caller stores user_id in the socket session via setup_authenticated_session
|
|
4. If invalid, connection is rejected with appropriate error
|
|
|
|
Session Data:
|
|
After successful authentication, the socket session contains:
|
|
- user_id: str (UUID as string)
|
|
- authenticated_at: str (ISO timestamp)
|
|
|
|
Example:
|
|
# In connect handler:
|
|
auth_result = await auth_handler.authenticate_connection(sid, auth)
|
|
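if not auth_result.success:
    return False # Reject connection

# On success, persist user_id in the session and register the connection:
await auth_handler.setup_authenticated_session(sio, sid, auth_result.user_id)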
|
|
|
|
# Later, get user_id:
|
|
session = await sio.get_session(sid, namespace="/game")
|
|
user_id = session.get("user_id")
|
|
"""
|
|
|
|
import logging
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from uuid import UUID
|
|
|
|
from app.services.connection_manager import ConnectionManager, connection_manager
|
|
from app.services.jwt_service import verify_access_token
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Type alias for token verifier - takes token string, returns user_id or None
|
|
TokenVerifier = Callable[[str], UUID | None]
|
|
|
|
|
|
@dataclass
|
|
class AuthResult:
|
|
"""Result of authentication attempt.
|
|
|
|
Attributes:
|
|
success: Whether authentication succeeded.
|
|
user_id: User's UUID if successful, None otherwise.
|
|
error_code: Error code for client if failed.
|
|
error_message: Human-readable error message if failed.
|
|
"""
|
|
|
|
success: bool
|
|
user_id: UUID | None = None
|
|
error_code: str | None = None
|
|
error_message: str | None = None
|
|
|
|
|
|
class AuthHandler:
    """Authenticate Socket.IO connections using injectable dependencies.

    The authentication logic lives on an instance whose collaborators are
    supplied through the constructor. Tests can pass fakes instead of
    monkey patching module globals.

    Attributes:
        _token_verifier: Callable mapping a JWT string to a user UUID, or to
            None when the token is rejected.
        _connection_manager: ConnectionManager used to register and
            unregister socket connections.
    """

    def __init__(
        self,
        token_verifier: TokenVerifier | None = None,
        conn_manager: ConnectionManager | None = None,
    ) -> None:
        """Initialize AuthHandler with dependencies.

        Args:
            token_verifier: Function to verify tokens. When None, falls back to
                ``verify_access_token`` from jwt_service.
            conn_manager: ConnectionManager instance. When None, falls back to
                the module-level ``connection_manager``.
        """
        # Check for None explicitly instead of using ``or``. An injected
        # collaborator that happens to be falsy (e.g. an object defining
        # __len__ or __bool__) must not be silently swapped for the global.
        self._token_verifier: TokenVerifier = (
            verify_access_token if token_verifier is None else token_verifier
        )
        self._connection_manager: ConnectionManager = (
            connection_manager if conn_manager is None else conn_manager
        )

    async def authenticate_connection(
        self,
        sid: str,
        auth: dict[str, object] | None,
    ) -> AuthResult:
        """Authenticate a Socket.IO connection using JWT.

        Extracts the JWT from the auth payload and runs it through the token
        verifier. Neither the socket session nor the ConnectionManager is
        touched here; the caller is responsible for that (see
        ``setup_authenticated_session``).

        Args:
            sid: Socket session ID (used for logging only).
            auth: Authentication data from the connect event.

        Returns:
            AuthResult. On success ``user_id`` is set. On failure ``error_code``
            is ``"missing_token"`` (no token found in ``auth``) or
            ``"invalid_token"`` (the verifier returned None).
        """
        token = extract_token(auth)

        if token is None:
            logger.debug("Connection %s: No token provided", sid)
            return AuthResult(
                success=False,
                error_code="missing_token",
                error_message="Authentication token required",
            )

        user_id = self._token_verifier(token)

        if user_id is None:
            logger.debug("Connection %s: Invalid or expired token", sid)
            return AuthResult(
                success=False,
                error_code="invalid_token",
                error_message="Invalid or expired token",
            )

        logger.debug("Connection %s: Authenticated as user %s", sid, user_id)
        return AuthResult(
            success=True,
            user_id=user_id,
        )

    async def setup_authenticated_session(
        self,
        sio: object,
        sid: str,
        user_id: UUID,
        namespace: str = "/game",
    ) -> None:
        """Set up socket session with authenticated user data.

        Saves ``user_id`` (as a string) and an ISO-8601 UTC
        ``authenticated_at`` timestamp via ``sio.save_session``. It then
        registers the connection with the ConnectionManager.

        Args:
            sio: Socket.IO AsyncServer instance (required).
            sid: Socket session ID.
            user_id: Authenticated user's UUID.
            namespace: Socket.IO namespace.
        """
        session_data = {
            "user_id": str(user_id),
            "authenticated_at": datetime.now(UTC).isoformat(),
        }
        await sio.save_session(sid, session_data, namespace=namespace)

        await self._connection_manager.register_connection(sid, user_id)

        logger.info("Session established: sid=%s, user_id=%s", sid, user_id)

    async def cleanup_authenticated_session(
        self,
        sid: str,
        namespace: str = "/game",
    ) -> str | None:
        """Clean up connection tracking on disconnect.

        Unregisters the connection from the ConnectionManager. It returns the
        user_id recorded for it, so the caller can do any further per-user
        cleanup. The socket session itself is not modified here.

        Args:
            sid: Socket session ID.
            namespace: Socket.IO namespace (currently unused).

        Returns:
            The user_id of the unregistered connection if the ConnectionManager
            knew about ``sid``, None otherwise.
        """
        conn_info = await self._connection_manager.unregister_connection(sid)

        if conn_info:
            logger.info(
                "Session cleaned up: sid=%s, user_id=%s", sid, conn_info.user_id
            )
            # NOTE(review): annotated as str, but register_connection receives
            # a UUID; confirm the type ConnectionInfo.user_id actually holds.
            return conn_info.user_id

        logger.debug("No session to clean up for %s", sid)
        return None
|
|
|
|
|
|
# Shared module-level instance. Passing None for both dependencies selects the
# defaults: jwt_service's verify_access_token and the global connection_manager.
auth_handler = AuthHandler(token_verifier=None, conn_manager=None)
|
|
|
|
|
|
def extract_token(auth: dict[str, object] | None) -> str | None:
|
|
"""Extract JWT token from Socket.IO auth data.
|
|
|
|
Clients should send the token in the auth dict:
|
|
socket.connect({ auth: { token: "JWT_TOKEN" } })
|
|
|
|
Also supports:
|
|
- auth.authorization: "Bearer TOKEN"
|
|
- auth.access_token: "TOKEN"
|
|
|
|
Args:
|
|
auth: Authentication data from Socket.IO connect.
|
|
|
|
Returns:
|
|
JWT token string if found, None otherwise.
|
|
|
|
Example:
|
|
token = extract_token({"token": "eyJ..."})
|
|
token = extract_token({"authorization": "Bearer eyJ..."})
|
|
"""
|
|
if auth is None:
|
|
return None
|
|
|
|
# Primary: auth.token
|
|
token = auth.get("token")
|
|
if token and isinstance(token, str):
|
|
return token
|
|
|
|
# Alternative: auth.authorization (Bearer token)
|
|
authorization = auth.get("authorization")
|
|
if authorization and isinstance(authorization, str):
|
|
if authorization.lower().startswith("bearer "):
|
|
return authorization[7:]
|
|
return authorization
|
|
|
|
# Alternative: auth.access_token
|
|
access_token = auth.get("access_token")
|
|
if access_token and isinstance(access_token, str):
|
|
return access_token
|
|
|
|
return None
|
|
|
|
|
|
async def get_session_user_id(
|
|
sio: object,
|
|
sid: str,
|
|
namespace: str = "/game",
|
|
) -> str | None:
|
|
"""Get the authenticated user_id from a socket session.
|
|
|
|
Convenience function to extract user_id from session data.
|
|
|
|
Args:
|
|
sio: Socket.IO AsyncServer instance.
|
|
sid: Socket session ID.
|
|
namespace: Socket.IO namespace.
|
|
|
|
Returns:
|
|
user_id string if authenticated, None otherwise.
|
|
|
|
Example:
|
|
user_id = await get_session_user_id(sio, sid)
|
|
if not user_id:
|
|
await sio.emit("error", {"message": "Not authenticated"}, to=sid)
|
|
return
|
|
"""
|
|
try:
|
|
session = await sio.get_session(sid, namespace=namespace)
|
|
return session.get("user_id") if session else None
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
async def require_auth(
    sio: object,
    sid: str,
    namespace: str = "/game",
) -> str | None:
    """Look up the authenticated user for an event handler, warning if absent.

    This delegates to ``get_session_user_id``. When no user_id is found, a
    warning is logged so unauthenticated event traffic is visible. The event
    itself is not rejected here; the caller decides how to handle a None
    result.

    Args:
        sio: Socket.IO AsyncServer instance.
        sid: Socket session ID.
        namespace: Socket.IO namespace.

    Returns:
        user_id string if authenticated, None otherwise.

    Example:
        @sio.on("game:action", namespace="/game")
        async def on_action(sid, data):
            user_id = await require_auth(sio, sid)
            if not user_id:
                return {"error": "Not authenticated"}
            # ... handle action
    """
    found = await get_session_user_id(sio, sid, namespace=namespace)
    if found is not None:
        return found

    logger.warning(f"Unauthenticated event from {sid}")
    return None
|