From 43d36ce439c6fd7f1bf3a3e93772f17190206d9c Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Sun, 8 Mar 2026 16:00:26 -0500 Subject: [PATCH] fix: resolve HIGH-severity issues from code review API authentication: - Add X-API-Secret shared-secret header validation on /chat and /stats - /health remains public for monitoring - Auth is a no-op when API_SECRET is empty (dev mode) Rate limiting: - Add per-user sliding-window rate limiter on /chat (10 req/60s default) - Returns 429 with clear message when exceeded - Self-cleaning memory (prunes expired entries on each check) Exception sanitization: - Discord bot no longer exposes raw exception text to users - Error embeds show generic "Something went wrong" message - Full exception details logged server-side with context - query_chat_api RuntimeError no longer includes response body Async correctness: - Wrap synchronous RuleRepository.search() in run_in_executor() to prevent blocking the event loop during SentenceTransformer inference - Port contract stays synchronous; service owns the async boundary Test coverage: 101 passed, 1 skipped (11 new tests for auth + rate limiting) Co-Authored-By: Claude Opus 4.6 --- adapters/inbound/api.py | 86 +++++++++++++++ app/discord_bot.py | 39 +++++-- config/container.py | 1 + config/settings.py | 5 + domain/services.py | 11 +- tests/adapters/test_api.py | 212 ++++++++++++++++++++++++++++++++++++- 6 files changed, 343 insertions(+), 11 deletions(-) diff --git a/adapters/inbound/api.py b/adapters/inbound/api.py index 9049039..af1bace 100644 --- a/adapters/inbound/api.py +++ b/adapters/inbound/api.py @@ -8,6 +8,7 @@ tests can substitute fakes without monkey-patching. """ import logging +import time from typing import Annotated, Optional from fastapi import APIRouter, Depends, HTTPException, Request @@ -18,6 +19,58 @@ from domain.services import ChatService logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Rate limiter +# --------------------------------------------------------------------------- + + +class RateLimiter: + """Sliding-window in-memory rate limiter keyed by user_id. + + Tracks the timestamps of recent requests for each user. A request is + allowed when the number of timestamps within the current window is below + max_requests. Old timestamps are pruned on each check so memory does not + grow without bound. + + Args: + max_requests: Maximum number of requests allowed per user per window. + window_seconds: Length of the sliding window in seconds. + """ + + def __init__(self, max_requests: int = 10, window_seconds: float = 60.0) -> None: + self.max_requests = max_requests + self.window_seconds = window_seconds + self._timestamps: dict[str, list[float]] = {} + + def check(self, user_id: str) -> bool: + """Return True if the request is allowed, False if rate limited. + + Prunes stale timestamps for the caller on every invocation, so the + dict entry naturally shrinks back to zero entries between bursts. + """ + now = time.monotonic() + cutoff = now - self.window_seconds + + bucket = self._timestamps.get(user_id) + if bucket is None: + self._timestamps[user_id] = [now] + return True + + # Drop timestamps outside the window, then check the count. + pruned = [ts for ts in bucket if ts > cutoff] + if len(pruned) >= self.max_requests: + self._timestamps[user_id] = pruned + return False + + pruned.append(now) + self._timestamps[user_id] = pruned + return True + + +# Module-level singleton — shared across all requests in a single process. +_rate_limiter = RateLimiter() + router = APIRouter() @@ -84,6 +137,36 @@ def _get_rule_repository(request: Request) -> RuleRepository: return request.app.state.rule_repository +def _check_rate_limit(body: ChatRequest) -> None: + """Raise HTTP 429 if the caller has exceeded their per-window request quota. + + user_id is taken from the parsed request body so it works consistently + regardless of how the Discord bot identifies its users. The check uses the + module-level _rate_limiter singleton so state is shared across requests. + """ + if not _rate_limiter.check(body.user_id): + raise HTTPException( + status_code=429, + detail="Rate limit exceeded. Please wait before sending another message.", + ) + + +def _verify_api_secret(request: Request) -> None: + """Enforce shared-secret authentication when API_SECRET is configured. + + When api_secret on app.state is an empty string the check is skipped + entirely, preserving the existing open-access behaviour for local + development. Once a secret is set, the caller must supply a matching + X-API-Secret header or receive HTTP 401. + """ + secret: str = getattr(request.app.state, "api_secret", "") + if not secret: + return + header_value = request.headers.get("X-API-Secret", "") + if header_value != secret: + raise HTTPException(status_code=401, detail="Invalid or missing API secret.") + + # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @@ -94,6 +177,8 @@ async def chat( body: ChatRequest, service: Annotated[ChatService, Depends(_get_chat_service)], rules: Annotated[RuleRepository, Depends(_get_rule_repository)], + _auth: Annotated[None, Depends(_verify_api_secret)], + _rate: Annotated[None, Depends(_check_rate_limit)], ) -> ChatResponse: """Handle a rules Q&A request. @@ -152,6 +237,7 @@ async def health( async def stats( rules: Annotated[RuleRepository, Depends(_get_rule_repository)], request: Request, + _auth: Annotated[None, Depends(_verify_api_secret)], ) -> dict: """Return extended statistics about the knowledge base and configuration.""" kb_stats = rules.get_stats() diff --git a/app/discord_bot.py b/app/discord_bot.py index 83f0e69..3cd4b39 100644 --- a/app/discord_bot.py +++ b/app/discord_bot.py @@ -1,5 +1,6 @@ """Discord bot for Strat-O-Matic rules Q&A.""" +import logging import discord from discord import app_commands from discord.ext import commands @@ -8,6 +9,8 @@ from typing import Optional from .config import settings +logger = logging.getLogger(__name__) + class StratChatbotBot(commands.Bot): """Discord bot for the rules chatbot.""" @@ -29,10 +32,10 @@ class StratChatbotBot(commands.Bot): guild = discord.Object(id=int(settings.discord_guild_id)) self.tree.copy_global_to(guild=guild) await self.tree.sync(guild=guild) - print(f"Slash commands synced to guild {settings.discord_guild_id}") + logger.info("Slash commands synced to guild %s", settings.discord_guild_id) else: await self.tree.sync() - print("Slash commands synced globally") + logger.info("Slash commands synced globally") async def close(self): """Cleanup on shutdown.""" @@ -67,7 +70,14 @@ class StratChatbotBot(commands.Bot): ) as response: if response.status != 200: error_text = await response.text() - raise RuntimeError(f"API error {response.status}: {error_text}") + logger.error( + "API returned %s for %s %s — body: %s", + response.status, + response.method, + response.url, + error_text, + ) + raise RuntimeError(f"API error {response.status}") return await response.json() @@ -79,8 +89,8 @@ async def on_ready(): """Called when the bot is ready.""" if not bot.user: return - print(f"🤖 Bot logged in as {bot.user} (ID: {bot.user.id})") - print("Ready to answer Strat-O-Matic rules questions!") + logger.info("Bot logged in as %s (ID: %s)", bot.user, bot.user.id) + logger.info("Ready to answer Strat-O-Matic rules questions!") @bot.tree.command( @@ -134,10 +144,16 @@ async def ask_command(interaction: discord.Interaction, question: str): await interaction.followup.send(embed=embed) except Exception as e: + logger.error( + "Error handling /ask from user %s: %s", + interaction.user.id, + e, + exc_info=True, + ) await interaction.followup.send( embed=discord.Embed( title="❌ Error", - description=f"Failed to get answer: {str(e)}", + description="Something went wrong while fetching your answer. Please try again later.", color=discord.Color.red(), ) ) @@ -232,11 +248,18 @@ async def on_message(message: discord.Message): await loading_msg.edit(content=None, embed=response_embed) except Exception as e: + logger.error( + "Error handling follow-up from user %s in channel %s: %s", + message.author.id, + message.channel.id, + e, + exc_info=True, + ) await loading_msg.edit( content=None, embed=discord.Embed( title="❌ Error", - description=f"Failed to process follow-up: {str(e)}", + description="Something went wrong while processing your follow-up. Please try again later.", color=discord.Color.red(), ), ) @@ -247,7 +270,7 @@ def run_bot(api_base_url: str = "http://localhost:8000"): bot.api_base_url = api_base_url if not settings.discord_bot_token: - print("❌ DISCORD_BOT_TOKEN environment variable is required") + logger.critical("DISCORD_BOT_TOKEN environment variable is required") exit(1) bot.run(settings.discord_bot_token) diff --git a/config/container.py b/config/container.py index bf8bbaa..6c3b44f 100644 --- a/config/container.py +++ b/config/container.py @@ -107,6 +107,7 @@ def _make_lifespan(settings: Settings): # Expose via app.state for the router's Depends helpers app.state.chat_service = service app.state.rule_repository = chroma_repo + app.state.api_secret = settings.api_secret app.state.config_snapshot = { "openrouter_model": settings.openrouter_model, "top_k_rules": settings.top_k_rules, diff --git a/config/settings.py b/config/settings.py index 7b8fe19..0c5a669 100644 --- a/config/settings.py +++ b/config/settings.py @@ -64,6 +64,11 @@ class Settings(BaseSettings): default="sqlite+aiosqlite:///./data/conversations.db", alias="DB_URL" ) + # ------------------------------------------------------------------ + # API authentication + # ------------------------------------------------------------------ + api_secret: str = Field(default="", alias="API_SECRET") + # ------------------------------------------------------------------ # Conversation / retrieval tuning # ------------------------------------------------------------------ diff --git a/domain/services.py b/domain/services.py index de2e6ce..5130c2f 100644 --- a/domain/services.py +++ b/domain/services.py @@ -3,6 +3,7 @@ ChatService orchestrates the Q&A flow using only domain ports. """ +import asyncio import logging from typing import Optional @@ -59,8 +60,14 @@ class ChatService: parent_id=parent_message_id, ) - # Search for relevant rules - search_results = self.rules.search(query=message, top_k=self.top_k_rules) + # Search for relevant rules — offload the synchronous (CPU-bound) + # RuleRepository.search() to a thread so the event loop is not blocked + # while SentenceTransformer encodes the query. + loop = asyncio.get_running_loop() + search_results = await loop.run_in_executor( + None, + lambda: self.rules.search(query=message, top_k=self.top_k_rules), + ) # Get conversation history for context history = await self.conversations.get_conversation_history(conv_id, limit=10) diff --git a/tests/adapters/test_api.py b/tests/adapters/test_api.py index 0b051f9..23d300b 100644 --- a/tests/adapters/test_api.py +++ b/tests/adapters/test_api.py @@ -22,10 +22,15 @@ What is tested - POST /chat with a message that exceeds 4000 characters returns HTTP 422. - POST /chat with a user_id that exceeds 64 characters returns HTTP 422. - POST /chat when ChatService.answer_question raises returns HTTP 500. +- RateLimiter allows requests within the window and blocks once the limit is hit. +- RateLimiter resets after the window expires so the caller can send again. +- POST /chat returns 429 when the per-user rate limit is exceeded. """ from __future__ import annotations +from unittest.mock import patch + import pytest import httpx from fastapi import FastAPI @@ -33,7 +38,7 @@ from httpx import ASGITransport from domain.models import RuleDocument from domain.services import ChatService -from adapters.inbound.api import router +from adapters.inbound.api import router, RateLimiter from tests.fakes import ( FakeRuleRepository, FakeLLM, @@ -53,6 +58,7 @@ def make_test_app( conversations: FakeConversationStore | None = None, issues: FakeIssueTracker | None = None, top_k_rules: int = 5, + api_secret: str = "", ) -> FastAPI: """Build a minimal FastAPI app with fakes wired into app.state. @@ -78,6 +84,7 @@ def make_test_app( app.state.chat_service = service app.state.rule_repository = _rules + app.state.api_secret = api_secret app.state.config_snapshot = { "openrouter_model": "fake-model", "top_k_rules": top_k_rules, @@ -436,3 +443,206 @@ async def test_chat_with_parent_message_id_returns_200(client: httpx.AsyncClient # The response's parent_message_id is the user turn message id, # not the one we passed in — that's the service's threading model. assert body["parent_message_id"] is not None + + +# --------------------------------------------------------------------------- +# API secret authentication +# --------------------------------------------------------------------------- + +_CHAT_PAYLOAD = {"message": "Test question", "user_id": "u1", "channel_id": "ch1"} + + +@pytest.mark.asyncio +async def test_chat_no_secret_configured_allows_any_request(): + """When api_secret is empty (the default for local dev), POST /chat must + succeed without any X-API-Secret header. + + This preserves the existing open-access behaviour so developers can run + the service locally without configuring a secret. + """ + app = make_test_app(api_secret="") + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + resp = await ac.post("/chat", json=_CHAT_PAYLOAD) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_chat_missing_secret_header_returns_401(): + """When api_secret is configured, POST /chat without X-API-Secret must + return HTTP 401, preventing unauthenticated access to the LLM endpoint. + """ + app = make_test_app(api_secret="supersecret") + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + resp = await ac.post("/chat", json=_CHAT_PAYLOAD) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_chat_wrong_secret_header_returns_401(): + """A request with an incorrect X-API-Secret value must return HTTP 401. + + This guards against callers who know a header is required but are + guessing or have an outdated secret. + """ + app = make_test_app(api_secret="supersecret") + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + resp = await ac.post( + "/chat", json=_CHAT_PAYLOAD, headers={"X-API-Secret": "wrongvalue"} + ) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_chat_correct_secret_header_returns_200(): + """A request with the correct X-API-Secret header must succeed and return + HTTP 200 when api_secret is configured. + """ + app = make_test_app(api_secret="supersecret") + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + resp = await ac.post( + "/chat", json=_CHAT_PAYLOAD, headers={"X-API-Secret": "supersecret"} + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_health_always_public(): + """GET /health must return 200 regardless of whether api_secret is set. + + Health checks are used by monitoring systems that do not hold application + secrets; requiring auth there would break uptime probes. + """ + app = make_test_app(api_secret="supersecret") + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + resp = await ac.get("/health") + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_stats_missing_secret_header_returns_401(): + """GET /stats without X-API-Secret must return HTTP 401 when a secret is + configured. + + The stats endpoint exposes configuration details (model names, retrieval + settings) that should be restricted to authenticated callers. + """ + app = make_test_app(api_secret="supersecret") + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + resp = await ac.get("/stats") + assert resp.status_code == 401 + + +# --------------------------------------------------------------------------- +# RateLimiter unit tests +# --------------------------------------------------------------------------- + + +def test_rate_limiter_allows_requests_within_limit(): + """Requests below max_requests within the window must all return True. + + We create a limiter with max_requests=3 and verify that three consecutive + calls for the same user are all permitted. + """ + limiter = RateLimiter(max_requests=3, window_seconds=60.0) + assert limiter.check("user-a") is True + assert limiter.check("user-a") is True + assert limiter.check("user-a") is True + + +def test_rate_limiter_blocks_when_limit_exceeded(): + """The (max_requests + 1)-th call within the window must return False. + + This confirms the sliding-window boundary is enforced correctly: once a + user has consumed all allowed slots, further requests are rejected until + the window advances. + """ + limiter = RateLimiter(max_requests=3, window_seconds=60.0) + for _ in range(3): + limiter.check("user-b") + assert limiter.check("user-b") is False + + +def test_rate_limiter_resets_after_window_expires(): + """After the window has fully elapsed, a previously rate-limited user must + be allowed to send again. + + We use unittest.mock.patch to freeze time.monotonic so the test runs + instantly: first we consume the quota at t=0, then advance the clock past + the window boundary and confirm the limiter grants the next request. + """ + limiter = RateLimiter(max_requests=2, window_seconds=10.0) + + with patch("adapters.inbound.api.time") as mock_time: + # All requests happen at t=0. + mock_time.monotonic.return_value = 0.0 + limiter.check("user-c") + limiter.check("user-c") + assert limiter.check("user-c") is False # quota exhausted at t=0 + + # Advance time past the full window so all timestamps are stale. + mock_time.monotonic.return_value = 11.0 + assert limiter.check("user-c") is True # window reset; request allowed + + +def test_rate_limiter_isolates_different_users(): + """Rate limiting must be per-user: consuming user-x's quota must not affect + user-y's available requests. + + This covers the dict-keying logic — a bug that shares state across users + would cause false 429s for innocent callers. + """ + limiter = RateLimiter(max_requests=1, window_seconds=60.0) + limiter.check("user-x") # exhausts user-x's single slot + assert limiter.check("user-x") is False # user-x is blocked + assert limiter.check("user-y") is True # user-y has their own fresh bucket + + +# --------------------------------------------------------------------------- +# POST /chat — rate limit integration (HTTP 429) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_chat_returns_429_when_rate_limit_exceeded(): + """POST /chat must return HTTP 429 once the per-user rate limit is hit. + + We patch the module-level _rate_limiter so we can exercise the integration + between the FastAPI dependency and the limiter without waiting for real time + to pass. The first call returns 200; after patching check() to return False, + the second call must return 429. + """ + import adapters.inbound.api as api_module + + # Use a tight limiter (1 request per 60 s) injected into the module so + # both the app and the dependency share the same instance. + tight_limiter = RateLimiter(max_requests=1, window_seconds=60.0) + original = api_module._rate_limiter + api_module._rate_limiter = tight_limiter + + payload = {"message": "Hello", "user_id": "rl-user", "channel_id": "ch1"} + app = make_test_app() + + try: + async with httpx.AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + resp1 = await ac.post("/chat", json=payload) + assert resp1.status_code == 200 # first request is within limit + + resp2 = await ac.post("/chat", json=payload) + assert resp2.status_code == 429 # second request is blocked + assert "Rate limit" in resp2.json()["detail"] + finally: + api_module._rate_limiter = original # restore to avoid polluting other tests