strat-chatbot/adapters/inbound/api.py
Cal Corum 43d36ce439 fix: resolve HIGH-severity issues from code review
API authentication:
- Add X-API-Secret shared-secret header validation on /chat and /stats
- /health remains public for monitoring
- Auth is a no-op when API_SECRET is empty (dev mode)

Rate limiting:
- Add per-user sliding-window rate limiter on /chat (10 req/60s default)
- Returns 429 with clear message when exceeded
- Self-cleaning memory (prunes expired entries on each check)

Exception sanitization:
- Discord bot no longer exposes raw exception text to users
- Error embeds show generic "Something went wrong" message
- Full exception details logged server-side with context
- query_chat_api RuntimeError no longer includes response body

Async correctness:
- Wrap synchronous RuleRepository.search() in run_in_executor()
  to prevent blocking the event loop during SentenceTransformer inference
- Port contract stays synchronous; service owns the async boundary

Test coverage: 101 passed, 1 skipped (11 new tests for auth + rate limiting)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 16:00:26 -05:00

252 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""FastAPI inbound adapter — thin HTTP layer over ChatService.
This module contains only routing / serialisation logic. All business rules
live in domain.services.ChatService; all storage / LLM calls live in outbound
adapters. The router reads ChatService and RuleRepository from app.state so
that the container (config/container.py) remains the single wiring point and
tests can substitute fakes without monkey-patching.
"""
import hmac
import logging
import time
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field

from domain.ports import RuleRepository
from domain.services import ChatService
# Module-level logger named after this module; the /chat handler uses it to
# record unhandled service errors together with their traceback.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Rate limiter
# ---------------------------------------------------------------------------
class RateLimiter:
    """Sliding-window in-memory rate limiter keyed by user_id.

    Tracks the timestamps of recent requests for each user. A request is
    allowed when the number of timestamps within the current window is below
    max_requests.

    Memory is bounded in two ways:

    * the caller's own bucket is pruned on every check, and removed entirely
      once it holds no timestamp inside the window;
    * at most once per window, every bucket is inspected and users whose
      newest request has left the window are dropped, so callers that never
      come back do not accumulate.

    The limiter keeps plain in-process state and performs no locking.

    Args:
        max_requests: Maximum number of requests allowed per user per window.
            A value of 0 (or less) rejects every request.
        window_seconds: Length of the sliding window in seconds.
    """

    def __init__(self, max_requests: int = 10, window_seconds: float = 60.0) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # user_id -> request timestamps (time.monotonic()), oldest first.
        self._timestamps: dict[str, list[float]] = {}
        # Monotonic time of the last full sweep over all users.
        self._last_sweep: float = time.monotonic()

    def _sweep(self, cutoff: float) -> None:
        """Remove every user whose newest timestamp is at or before *cutoff*."""
        # Timestamps are appended in increasing order, so the last element is
        # the newest; collect keys first to avoid mutating while iterating.
        idle_users = [
            user_id
            for user_id, stamps in self._timestamps.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for user_id in idle_users:
            del self._timestamps[user_id]

    def check(self, user_id: str) -> bool:
        """Return True if the request is allowed, False if rate limited.

        An allowed request is recorded against *user_id*; a rejected one is
        not, so rejected attempts do not extend the caller's wait.

        Args:
            user_id: Opaque key identifying the caller.

        Returns:
            True when the caller is under max_requests for the current
            window, False otherwise.
        """
        now = time.monotonic()
        cutoff = now - self.window_seconds

        # Amortised global cleanup: at most one full pass per window.
        if now - self._last_sweep >= self.window_seconds:
            self._sweep(cutoff)
            self._last_sweep = now

        # Same path for new and known users, so the limit is always applied.
        recent = [ts for ts in self._timestamps.get(user_id, ()) if ts > cutoff]
        allowed = len(recent) < self.max_requests
        if allowed:
            recent.append(now)

        if recent:
            self._timestamps[user_id] = recent
        else:
            # Nothing left inside the window: do not keep an empty bucket.
            self._timestamps.pop(user_id, None)
        return allowed
# Module-level singleton — shared across all requests in a single process.
# Built with RateLimiter's default arguments. State lives only in this
# process's memory; NOTE(review): with several worker processes each would
# presumably keep its own counters — confirm against the deployment setup.
_rate_limiter = RateLimiter()

# Router that the endpoint decorators in this module register on; it is
# presumably included into the FastAPI app elsewhere (not visible here).
router = APIRouter()
# ---------------------------------------------------------------------------
# Request / response Pydantic models
# ---------------------------------------------------------------------------
class ChatRequest(BaseModel):
    """Payload accepted by POST /chat.

    The length bounds declared on ``message``, ``user_id`` and ``channel_id``
    are enforced by pydantic validation before the handler runs.

    Attributes:
        message: The user's question, 1 to 4000 characters.
        user_id: Caller identifier, 1 to 64 characters; also used as the
            rate-limiting key.
        channel_id: Channel identifier, 1 to 64 characters.
        conversation_id: Existing conversation to continue, or None to start
            a new one.
        parent_message_id: Thread parent message ID, or None.
    """

    message: str = Field(
        ...,
        min_length=1,
        max_length=4000,
        # Range spelled with a plain ASCII hyphen so it matches the bounds
        # above and cannot be mangled into "14000".
        description="The user's question (1-4000 characters).",
    )
    user_id: str = Field(
        ...,
        min_length=1,
        max_length=64,
        description="Opaque caller identifier, e.g. Discord snowflake.",
    )
    channel_id: str = Field(
        ...,
        min_length=1,
        max_length=64,
        description="Opaque channel identifier, e.g. Discord channel snowflake.",
    )
    # NOTE(review): unlike the identifiers above, the two optional IDs carry
    # no length bounds — confirm whether that is intentional.
    conversation_id: Optional[str] = Field(
        default=None,
        description="Continue an existing conversation; omit to start a new one.",
    )
    parent_message_id: Optional[str] = Field(
        default=None,
        description="Thread parent message ID for Discord thread replies.",
    )
class ChatResponse(BaseModel):
    """Payload returned by POST /chat.

    The /chat handler fills every field from the same-named attribute of the
    object returned by ChatService.answer_question.
    """

    # Answer text (copied from result.response).
    response: str
    # Identifier of the conversation this answer belongs to.
    conversation_id: str
    # Identifier of the message carrying this answer.
    message_id: str
    # Thread parent message ID; the only field with a default (None).
    parent_message_id: Optional[str] = None
    # Presumably identifiers of the rules cited in the answer — TODO confirm
    # the exact format against ChatService.
    cited_rules: list[str]
    # NOTE(review): no range is enforced here; presumably 0.0-1.0 — confirm.
    confidence: float
    # Flag copied from result.needs_human; presumably signals that the answer
    # should be escalated to a person — verify against ChatService.
    needs_human: bool
# ---------------------------------------------------------------------------
# Dependency helpers — read from app.state set by the container
# ---------------------------------------------------------------------------
def _get_chat_service(request: Request) -> ChatService:
    """Return the ChatService instance stored on the application state.

    Args:
        request: Incoming request; only ``request.app.state`` is read.

    Returns:
        The object found at ``app.state.chat_service``.
    """
    app_state = request.app.state
    return app_state.chat_service
def _get_rule_repository(request: Request) -> RuleRepository:
    """Return the RuleRepository instance stored on the application state.

    Args:
        request: Incoming request; only ``request.app.state`` is read.

    Returns:
        The object found at ``app.state.rule_repository``.
    """
    app_state = request.app.state
    return app_state.rule_repository
def _check_rate_limit(body: ChatRequest) -> None:
    """Reject the request with HTTP 429 once the caller's quota is used up.

    The quota is tracked per ``body.user_id`` by the module-level
    ``_rate_limiter`` singleton, so counts are shared by every request served
    from this process.

    Args:
        body: Parsed /chat payload; only ``user_id`` is read.

    Raises:
        HTTPException: With status 429 when the limiter refuses the request.
    """
    within_quota = _rate_limiter.check(body.user_id)
    if within_quota:
        return

    raise HTTPException(
        status_code=429,
        detail="Rate limit exceeded. Please wait before sending another message.",
    )
def _verify_api_secret(request: Request) -> None:
    """Enforce shared-secret authentication when API_SECRET is configured.

    When ``api_secret`` on app.state is missing or empty the check is skipped
    entirely, preserving open access for local development. Once a secret is
    set, the caller must supply a matching X-API-Secret header.

    Args:
        request: Incoming request; reads ``app.state.api_secret`` and the
            ``X-API-Secret`` header.

    Raises:
        HTTPException: With status 401 when a secret is configured and the
            header is absent or does not match.
    """
    secret: str = getattr(request.app.state, "api_secret", "")
    if not secret:
        return

    supplied = request.headers.get("X-API-Secret", "")
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a guess were correct. Both sides are encoded to
    # bytes because compare_digest rejects non-ASCII str arguments, and the
    # header value is attacker-controlled.
    if not hmac.compare_digest(supplied.encode("utf-8"), secret.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid or missing API secret.")
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/chat", response_model=ChatResponse)
async def chat(
    body: ChatRequest,
    service: Annotated[ChatService, Depends(_get_chat_service)],
    rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
    _auth: Annotated[None, Depends(_verify_api_secret)],
    _rate: Annotated[None, Depends(_check_rate_limit)],
) -> ChatResponse:
    """Handle a rules Q&A request.

    Delegates entirely to ChatService.answer_question — no business logic
    here.

    Args:
        body: Validated request payload.
        service: ChatService taken from app.state.
        rules: RuleRepository taken from app.state; not used in the body but
            kept so the dependency is still resolved for this route.
        _auth: Placeholder for the shared-secret check dependency.
        _rate: Placeholder for the per-user rate-limit dependency.

    Returns:
        A ChatResponse built field-for-field from the service result.

    Raises:
        HTTPException: 503 when the service has no ``llm`` attribute or it is
            None, so callers can tell configuration errors from answers;
            500 with a generic message when the service fails unexpectedly.
            An HTTPException raised by the service itself is passed through
            unchanged.
    """
    # The container raises at startup if the API key is required but absent;
    # however if the service was created without a real LLM, surface a clear
    # service-unavailable rather than a misleading 200 OK.
    if getattr(service, "llm", None) is None:
        raise HTTPException(
            status_code=503,
            detail="LLM service is not available — check OPENROUTER_API_KEY configuration.",
        )

    try:
        result = await service.answer_question(
            message=body.message,
            user_id=body.user_id,
            channel_id=body.channel_id,
            conversation_id=body.conversation_id,
            parent_message_id=body.parent_message_id,
        )
    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as exc:
        # Full details go to the server log only; the client gets a generic
        # message so internal exception text is never exposed.
        logger.exception("Unhandled error in ChatService.answer_question")
        raise HTTPException(status_code=500, detail="Internal server error.") from exc

    return ChatResponse(
        response=result.response,
        conversation_id=result.conversation_id,
        message_id=result.message_id,
        parent_message_id=result.parent_message_id,
        cited_rules=result.cited_rules,
        confidence=result.confidence,
        needs_human=result.needs_human,
    )
@router.get("/health")
async def health(
    rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
) -> dict:
    """Report liveness plus a short knowledge-base summary.

    Args:
        rules: RuleRepository taken from app.state.

    Returns:
        A dict with a fixed ``"status"`` of ``"healthy"``, ``"rules_count"``
        (the repository's ``total_rules`` stat, 0 when absent) and
        ``"sections"`` (the ``sections`` stat, ``{}`` when absent).
    """
    kb_summary = rules.get_stats()

    payload: dict = {"status": "healthy"}
    payload["rules_count"] = kb_summary.get("total_rules", 0)
    payload["sections"] = kb_summary.get("sections", {})
    return payload
@router.get("/stats")
async def stats(
    rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
    request: Request,
    _auth: Annotated[None, Depends(_verify_api_secret)],
) -> dict:
    """Expose the full knowledge-base statistics and the config snapshot.

    Args:
        rules: RuleRepository taken from app.state.
        request: Incoming request; used to reach ``app.state``.
        _auth: Placeholder for the shared-secret check dependency.

    Returns:
        A dict with ``"knowledge_base"`` (whatever ``rules.get_stats()``
        returns) and ``"config"`` (``app.state.config_snapshot``, or ``{}``
        when the attribute is not set).
    """
    return {
        "knowledge_base": rules.get_stats(),
        # Optional snapshot; absent attribute falls back to an empty dict.
        "config": getattr(request.app.state, "config_snapshot", {}),
    }