API authentication: - Add X-API-Secret shared-secret header validation on /chat and /stats - /health remains public for monitoring - Auth is a no-op when API_SECRET is empty (dev mode) Rate limiting: - Add per-user sliding-window rate limiter on /chat (10 req/60s default) - Returns 429 with clear message when exceeded - Self-cleaning memory (prunes expired entries on each check) Exception sanitization: - Discord bot no longer exposes raw exception text to users - Error embeds show generic "Something went wrong" message - Full exception details logged server-side with context - query_chat_api RuntimeError no longer includes response body Async correctness: - Wrap synchronous RuleRepository.search() in run_in_executor() to prevent blocking the event loop during SentenceTransformer inference - Port contract stays synchronous; service owns the async boundary Test coverage: 101 passed, 1 skipped (11 new tests for auth + rate limiting) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
252 lines
8.6 KiB
Python
252 lines
8.6 KiB
Python
"""FastAPI inbound adapter — thin HTTP layer over ChatService.
|
||
|
||
This module contains only routing / serialisation logic. All business rules
|
||
live in domain.services.ChatService; all storage / LLM calls live in outbound
|
||
adapters. The router reads ChatService and RuleRepository from app.state so
|
||
that the container (config/container.py) remains the single wiring point and
|
||
tests can substitute fakes without monkey-patching.
|
||
"""
|
||
|
||
import logging
import secrets
import time
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field

from domain.ports import RuleRepository
from domain.services import ChatService
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Rate limiter
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class RateLimiter:
    """Sliding-window in-memory rate limiter keyed by user_id.

    Tracks the timestamps of recent requests for each user. A request is
    allowed when the number of timestamps within the current window is below
    max_requests. Old timestamps are pruned on each check so memory does not
    grow without bound.

    Args:
        max_requests: Maximum number of requests allowed per user per window.
        window_seconds: Length of the sliding window in seconds.
    """

    def __init__(self, max_requests: int = 10, window_seconds: float = 60.0) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # user_id -> monotonic timestamps of that user's recent requests.
        self._timestamps: dict[str, list[float]] = {}

    def check(self, user_id: str) -> bool:
        """Return True if the request is allowed, False if rate limited.

        Prunes stale timestamps for the caller on every invocation, so each
        user's bucket naturally shrinks back down between bursts.
        """
        now = time.monotonic()
        cutoff = now - self.window_seconds

        # Single uniform path: treat a missing bucket as empty. The previous
        # first-request fast path returned True without consulting
        # max_requests, so a zero (or negative) quota still admitted one
        # request per user; this form rejects correctly in that case too.
        pruned = [ts for ts in self._timestamps.get(user_id, ()) if ts > cutoff]

        if len(pruned) >= self.max_requests:
            # Over quota: persist the pruned window but do not record this
            # rejected attempt (rejections must not extend the lockout).
            self._timestamps[user_id] = pruned
            return False

        pruned.append(now)
        self._timestamps[user_id] = pruned
        return True
|
||
|
||
|
||
# One limiter instance per process; every request handler shares its state.
_rate_limiter = RateLimiter()

router = APIRouter()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Request / response Pydantic models
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class ChatRequest(BaseModel):
    """Payload accepted by POST /chat."""

    # The question text itself; bounded to keep prompts a sane size.
    message: str = Field(
        ...,
        min_length=1,
        max_length=4000,
        description="The user's question (1–4000 characters).",
    )
    # Caller identity, used for rate limiting; opaque to this service.
    user_id: str = Field(
        ...,
        min_length=1,
        max_length=64,
        description="Opaque caller identifier, e.g. Discord snowflake.",
    )
    # Originating channel; opaque to this service.
    channel_id: str = Field(
        ...,
        min_length=1,
        max_length=64,
        description="Opaque channel identifier, e.g. Discord channel snowflake.",
    )
    # Both optional: omit conversation_id to start a fresh conversation.
    conversation_id: Optional[str] = Field(
        default=None,
        description="Continue an existing conversation; omit to start a new one.",
    )
    parent_message_id: Optional[str] = Field(
        default=None,
        description="Thread parent message ID for Discord thread replies.",
    )
|
||
|
||
|
||
class ChatResponse(BaseModel):
    """Payload returned by POST /chat."""

    response: str                             # the generated answer text
    conversation_id: str                      # conversation this reply belongs to
    message_id: str                           # ID assigned to this answer
    parent_message_id: Optional[str] = None   # set for threaded replies
    cited_rules: list[str]                    # rule identifiers referenced
    confidence: float                         # model confidence score
    needs_human: bool                         # True when escalation is advised
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Dependency helpers — read from app.state set by the container
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _get_chat_service(request: Request) -> ChatService:
    """Return the container-wired ChatService stored on app.state."""
    state = request.app.state
    return state.chat_service
|
||
|
||
|
||
def _get_rule_repository(request: Request) -> RuleRepository:
    """Return the container-wired RuleRepository stored on app.state."""
    state = request.app.state
    return state.rule_repository
|
||
|
||
|
||
def _check_rate_limit(body: ChatRequest) -> None:
    """Raise HTTP 429 if the caller has exceeded their per-window request quota.

    user_id comes from the parsed request body so the check behaves the same
    no matter how the Discord bot identifies its users. State lives in the
    module-level _rate_limiter singleton, shared across requests.
    """
    if _rate_limiter.check(body.user_id):
        return
    raise HTTPException(
        status_code=429,
        detail="Rate limit exceeded. Please wait before sending another message.",
    )
|
||
|
||
|
||
def _verify_api_secret(request: Request) -> None:
    """Enforce shared-secret authentication when API_SECRET is configured.

    When api_secret on app.state is an empty string the check is skipped
    entirely, preserving the existing open-access behaviour for local
    development. Once a secret is set, the caller must supply a matching
    X-API-Secret header or receive HTTP 401.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    secret: str = getattr(request.app.state, "api_secret", "")
    if not secret:
        return  # Dev mode: no secret configured, auth is a no-op.
    header_value = request.headers.get("X-API-Secret", "")
    # Constant-time comparison: a plain `!=` short-circuits on the first
    # differing character, which can leak the secret via response timing.
    if not secrets.compare_digest(header_value, secret):
        raise HTTPException(status_code=401, detail="Invalid or missing API secret.")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Endpoints
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@router.post("/chat", response_model=ChatResponse)
async def chat(
    body: ChatRequest,
    service: Annotated[ChatService, Depends(_get_chat_service)],
    rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
    _auth: Annotated[None, Depends(_verify_api_secret)],
    _rate: Annotated[None, Depends(_check_rate_limit)],
) -> ChatResponse:
    """Handle a rules Q&A request.

    Delegates entirely to ChatService.answer_question — no business logic here.
    Returns HTTP 503 when the LLM adapter cannot be constructed (missing API key)
    rather than producing a fake success response, so callers can distinguish
    genuine answers from configuration errors.

    Raises:
        HTTPException: 503 when the LLM is unavailable; 500 with a generic
            message on any unexpected failure (details are logged server-side).
    """
    # The container raises at startup if the API key is required but absent;
    # however if the service was created without a real LLM (e.g. missing key
    # detected at request time), surface a clear service-unavailable rather than
    # leaking a misleading 200 OK.
    if not hasattr(service, "llm") or service.llm is None:
        raise HTTPException(
            status_code=503,
            detail="LLM service is not available — check OPENROUTER_API_KEY configuration.",
        )

    try:
        result = await service.answer_question(
            message=body.message,
            user_id=body.user_id,
            channel_id=body.channel_id,
            conversation_id=body.conversation_id,
            parent_message_id=body.parent_message_id,
        )
    except HTTPException:
        # Deliberate HTTP errors pass through with their original status.
        raise
    except Exception as exc:
        # Log the full traceback server-side, but never echo raw exception
        # text to the client — it can leak internals (paths, keys, prompts).
        logger.exception("Unhandled error in ChatService.answer_question")
        raise HTTPException(status_code=500, detail="Internal server error.") from exc

    return ChatResponse(
        response=result.response,
        conversation_id=result.conversation_id,
        message_id=result.message_id,
        parent_message_id=result.parent_message_id,
        cited_rules=result.cited_rules,
        confidence=result.confidence,
        needs_human=result.needs_human,
    )
|
||
|
||
|
||
@router.get("/health")
async def health(
    rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
) -> dict:
    """Return service health and a summary of the loaded knowledge base."""
    kb = rules.get_stats()
    return {
        "status": "healthy",
        "rules_count": kb.get("total_rules", 0),
        "sections": kb.get("sections", {}),
    }
|
||
|
||
|
||
@router.get("/stats")
async def stats(
    rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
    request: Request,
    _auth: Annotated[None, Depends(_verify_api_secret)],
) -> dict:
    """Return extended statistics about the knowledge base and configuration."""
    # Optional config snapshot is populated on app.state by the container;
    # default to an empty dict when it was never set.
    snapshot: dict = getattr(request.app.state, "config_snapshot", {})
    return {
        "knowledge_base": rules.get_stats(),
        "config": snapshot,
    }
|