"""FastAPI inbound adapter — thin HTTP layer over ChatService. This module contains only routing / serialisation logic. All business rules live in domain.services.ChatService; all storage / LLM calls live in outbound adapters. The router reads ChatService and RuleRepository from app.state so that the container (config/container.py) remains the single wiring point and tests can substitute fakes without monkey-patching. """ import logging import time from typing import Annotated, Optional from fastapi import APIRouter, Depends, HTTPException, Request from pydantic import BaseModel, Field from domain.ports import RuleRepository from domain.services import ChatService logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Rate limiter # --------------------------------------------------------------------------- class RateLimiter: """Sliding-window in-memory rate limiter keyed by user_id. Tracks the timestamps of recent requests for each user. A request is allowed when the number of timestamps within the current window is below max_requests. Old timestamps are pruned on each check so memory does not grow without bound. Args: max_requests: Maximum number of requests allowed per user per window. window_seconds: Length of the sliding window in seconds. """ def __init__(self, max_requests: int = 10, window_seconds: float = 60.0) -> None: self.max_requests = max_requests self.window_seconds = window_seconds self._timestamps: dict[str, list[float]] = {} def check(self, user_id: str) -> bool: """Return True if the request is allowed, False if rate limited. Prunes stale timestamps for the caller on every invocation, so the dict entry naturally shrinks back to zero entries between bursts. """ now = time.monotonic() cutoff = now - self.window_seconds bucket = self._timestamps.get(user_id) if bucket is None: self._timestamps[user_id] = [now] return True # Drop timestamps outside the window, then check the count. pruned = [ts for ts in bucket if ts > cutoff] if len(pruned) >= self.max_requests: self._timestamps[user_id] = pruned return False pruned.append(now) self._timestamps[user_id] = pruned return True # Module-level singleton — shared across all requests in a single process. _rate_limiter = RateLimiter() router = APIRouter() # --------------------------------------------------------------------------- # Request / response Pydantic models # --------------------------------------------------------------------------- class ChatRequest(BaseModel): """Payload accepted by POST /chat.""" message: str = Field( ..., min_length=1, max_length=4000, description="The user's question (1–4000 characters).", ) user_id: str = Field( ..., min_length=1, max_length=64, description="Opaque caller identifier, e.g. Discord snowflake.", ) channel_id: str = Field( ..., min_length=1, max_length=64, description="Opaque channel identifier, e.g. Discord channel snowflake.", ) conversation_id: Optional[str] = Field( default=None, description="Continue an existing conversation; omit to start a new one.", ) parent_message_id: Optional[str] = Field( default=None, description="Thread parent message ID for Discord thread replies.", ) class ChatResponse(BaseModel): """Payload returned by POST /chat.""" response: str conversation_id: str message_id: str parent_message_id: Optional[str] = None cited_rules: list[str] confidence: float needs_human: bool # --------------------------------------------------------------------------- # Dependency helpers — read from app.state set by the container # --------------------------------------------------------------------------- def _get_chat_service(request: Request) -> ChatService: """Extract the ChatService wired by the container from app.state.""" return request.app.state.chat_service def _get_rule_repository(request: Request) -> RuleRepository: """Extract the RuleRepository wired by the container from app.state.""" return request.app.state.rule_repository def _check_rate_limit(body: ChatRequest) -> None: """Raise HTTP 429 if the caller has exceeded their per-window request quota. user_id is taken from the parsed request body so it works consistently regardless of how the Discord bot identifies its users. The check uses the module-level _rate_limiter singleton so state is shared across requests. """ if not _rate_limiter.check(body.user_id): raise HTTPException( status_code=429, detail="Rate limit exceeded. Please wait before sending another message.", ) def _verify_api_secret(request: Request) -> None: """Enforce shared-secret authentication when API_SECRET is configured. When api_secret on app.state is an empty string the check is skipped entirely, preserving the existing open-access behaviour for local development. Once a secret is set, the caller must supply a matching X-API-Secret header or receive HTTP 401. """ secret: str = getattr(request.app.state, "api_secret", "") if not secret: return header_value = request.headers.get("X-API-Secret", "") if header_value != secret: raise HTTPException(status_code=401, detail="Invalid or missing API secret.") # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @router.post("/chat", response_model=ChatResponse) async def chat( body: ChatRequest, service: Annotated[ChatService, Depends(_get_chat_service)], rules: Annotated[RuleRepository, Depends(_get_rule_repository)], _auth: Annotated[None, Depends(_verify_api_secret)], _rate: Annotated[None, Depends(_check_rate_limit)], ) -> ChatResponse: """Handle a rules Q&A request. Delegates entirely to ChatService.answer_question — no business logic here. Returns HTTP 503 when the LLM adapter cannot be constructed (missing API key) rather than producing a fake success response, so callers can distinguish genuine answers from configuration errors. """ # The container raises at startup if the API key is required but absent; # however if the service was created without a real LLM (e.g. missing key # detected at request time), surface a clear service-unavailable rather than # leaking a misleading 200 OK. if not hasattr(service, "llm") or service.llm is None: raise HTTPException( status_code=503, detail="LLM service is not available — check OPENROUTER_API_KEY configuration.", ) try: result = await service.answer_question( message=body.message, user_id=body.user_id, channel_id=body.channel_id, conversation_id=body.conversation_id, parent_message_id=body.parent_message_id, ) except Exception as exc: logger.exception("Unhandled error in ChatService.answer_question") raise HTTPException(status_code=500, detail=str(exc)) from exc return ChatResponse( response=result.response, conversation_id=result.conversation_id, message_id=result.message_id, parent_message_id=result.parent_message_id, cited_rules=result.cited_rules, confidence=result.confidence, needs_human=result.needs_human, ) @router.get("/health") async def health( rules: Annotated[RuleRepository, Depends(_get_rule_repository)], ) -> dict: """Return service health and a summary of the loaded knowledge base.""" stats = rules.get_stats() return { "status": "healthy", "rules_count": stats.get("total_rules", 0), "sections": stats.get("sections", {}), } @router.get("/stats") async def stats( rules: Annotated[RuleRepository, Depends(_get_rule_repository)], request: Request, _auth: Annotated[None, Depends(_verify_api_secret)], ) -> dict: """Return extended statistics about the knowledge base and configuration.""" kb_stats = rules.get_stats() # Pull optional config snapshot from app.state (set by container). config_snapshot: dict = getattr(request.app.state, "config_snapshot", {}) return { "knowledge_base": kb_stats, "config": config_snapshot, }