fix: resolve HIGH-severity issues from code review

API authentication:
- Add X-API-Secret shared-secret header validation on /chat and /stats
- /health remains public for monitoring
- Auth is a no-op when API_SECRET is empty (dev mode)

Rate limiting:
- Add per-user sliding-window rate limiter on /chat (10 req/60s default)
- Returns 429 with clear message when exceeded
- Self-cleaning memory (prunes expired entries on each check)

Exception sanitization:
- Discord bot no longer exposes raw exception text to users
- Error embeds show generic "Something went wrong" message
- Full exception details logged server-side with context
- query_chat_api RuntimeError no longer includes response body

Async correctness:
- Wrap synchronous RuleRepository.search() in run_in_executor()
  to prevent blocking the event loop during SentenceTransformer inference
- Port contract stays synchronous; service owns the async boundary

Test coverage: 101 passed, 1 skipped (11 new tests for auth + rate limiting)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-03-08 16:00:26 -05:00
parent c3218f70c4
commit 43d36ce439
6 changed files with 343 additions and 11 deletions

View File

@ -8,6 +8,7 @@ tests can substitute fakes without monkey-patching.
"""
import logging
import time
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
@ -18,6 +19,58 @@ from domain.services import ChatService
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Rate limiter
# ---------------------------------------------------------------------------
class RateLimiter:
    """Sliding-window in-memory rate limiter keyed by user_id.

    Tracks the timestamps of recent requests for each user. A request is
    allowed when the number of timestamps inside the current window is below
    ``max_requests``.

    Memory is bounded in two ways:

    * the caller's own bucket is pruned on every ``check``; and
    * at most once per window, every bucket whose newest timestamp has left
      the window is evicted, so users who never come back (or throw-away
      user_ids) do not accumulate forever.

    All state changes happen under a lock, so one instance can be shared by
    callers running on different threads.

    Args:
        max_requests: Maximum number of requests allowed per user per window.
            A value of zero or less rejects every request.
        window_seconds: Length of the sliding window in seconds.
    """

    def __init__(self, max_requests: int = 10, window_seconds: float = 60.0) -> None:
        # Function-scope import so this class does not depend on the module's
        # import block being edited alongside it.
        import threading

        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # user_id -> request timestamps in ascending order (oldest first).
        self._timestamps: dict[str, list[float]] = {}
        # Monotonic time of the last full sweep; None until the first check.
        self._last_sweep: Optional[float] = None
        self._lock = threading.Lock()

    def check(self, user_id: str) -> bool:
        """Record a request attempt and report whether it is allowed.

        Args:
            user_id: Key identifying the caller.

        Returns:
            True if the request fits in the caller's quota (and was recorded),
            False if the caller is currently rate limited.
        """
        with self._lock:
            # Read the clock inside the lock so timestamps are appended in
            # non-decreasing order; the sweep relies on bucket[-1] being newest.
            now = time.monotonic()
            cutoff = now - self.window_seconds
            self._sweep_if_due(now, cutoff)

            recent = [ts for ts in self._timestamps.get(user_id, ()) if ts > cutoff]
            allowed = len(recent) < self.max_requests
            if allowed:
                recent.append(now)

            if recent:
                self._timestamps[user_id] = recent
            else:
                # Nothing left to remember for this user; do not keep the key.
                self._timestamps.pop(user_id, None)
            return allowed

    def _sweep_if_due(self, now: float, cutoff: float) -> None:
        """Evict every fully expired bucket, at most once per window.

        Must be called with ``self._lock`` held.
        """
        if self._last_sweep is not None and now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        expired = [
            uid
            for uid, bucket in self._timestamps.items()
            if not bucket or bucket[-1] <= cutoff
        ]
        for uid in expired:
            del self._timestamps[uid]
# Module-level singleton — shared across all requests in a single process.
# Built with the RateLimiter defaults (10 requests per 60-second window); its
# counters are held in this process's memory only.
_rate_limiter = RateLimiter()
# Router that the endpoint functions in this module are registered on.
router = APIRouter()
@ -84,6 +137,36 @@ def _get_rule_repository(request: Request) -> RuleRepository:
return request.app.state.rule_repository
def _check_rate_limit(body: ChatRequest) -> None:
    """Reject the request with HTTP 429 when the caller is over quota.

    The quota is tracked per ``body.user_id`` by the module-level
    ``_rate_limiter`` singleton, so every request handled by this process
    shares the same counters.

    Args:
        body: Parsed chat request; only its ``user_id`` is read.

    Raises:
        HTTPException: status 429 when the limiter refuses the request.
    """
    within_quota = _rate_limiter.check(body.user_id)
    if within_quota:
        return

    raise HTTPException(
        status_code=429,
        detail="Rate limit exceeded. Please wait before sending another message.",
    )
def _verify_api_secret(request: Request) -> None:
    """Enforce shared-secret authentication when API_SECRET is configured.

    When ``api_secret`` on ``app.state`` is empty (or absent) the check is
    skipped entirely, preserving open access for local development. Once a
    secret is set, the caller must supply a matching ``X-API-Secret`` header.

    Args:
        request: Incoming request; ``request.app.state.api_secret`` and the
            ``X-API-Secret`` header are read.

    Raises:
        HTTPException: status 401 when a secret is configured and the header
            is missing or does not match.
    """
    # Function-scope import so this block does not depend on the module's
    # import list being edited alongside it.
    import hmac

    secret: str = getattr(request.app.state, "api_secret", "")
    if not secret:
        return

    header_value = request.headers.get("X-API-Secret", "")
    # Constant-time comparison: a plain != can leak, through response timing,
    # how much of a guessed secret was correct. Both sides are encoded to
    # bytes because compare_digest rejects non-ASCII str arguments.
    matches = hmac.compare_digest(
        header_value.encode("utf-8"), secret.encode("utf-8")
    )
    if not matches:
        raise HTTPException(status_code=401, detail="Invalid or missing API secret.")
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@ -94,6 +177,8 @@ async def chat(
body: ChatRequest,
service: Annotated[ChatService, Depends(_get_chat_service)],
rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
_auth: Annotated[None, Depends(_verify_api_secret)],
_rate: Annotated[None, Depends(_check_rate_limit)],
) -> ChatResponse:
"""Handle a rules Q&A request.
@ -152,6 +237,7 @@ async def health(
async def stats(
rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
request: Request,
_auth: Annotated[None, Depends(_verify_api_secret)],
) -> dict:
"""Return extended statistics about the knowledge base and configuration."""
kb_stats = rules.get_stats()

View File

@ -1,5 +1,6 @@
"""Discord bot for Strat-O-Matic rules Q&A."""
import logging
import discord
from discord import app_commands
from discord.ext import commands
@ -8,6 +9,8 @@ from typing import Optional
from .config import settings
logger = logging.getLogger(__name__)
class StratChatbotBot(commands.Bot):
"""Discord bot for the rules chatbot."""
@ -29,10 +32,10 @@ class StratChatbotBot(commands.Bot):
guild = discord.Object(id=int(settings.discord_guild_id))
self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild)
print(f"Slash commands synced to guild {settings.discord_guild_id}")
logger.info("Slash commands synced to guild %s", settings.discord_guild_id)
else:
await self.tree.sync()
print("Slash commands synced globally")
logger.info("Slash commands synced globally")
async def close(self):
"""Cleanup on shutdown."""
@ -67,7 +70,14 @@ class StratChatbotBot(commands.Bot):
) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(f"API error {response.status}: {error_text}")
logger.error(
"API returned %s for %s %s — body: %s",
response.status,
response.method,
response.url,
error_text,
)
raise RuntimeError(f"API error {response.status}")
return await response.json()
@ -79,8 +89,8 @@ async def on_ready():
"""Called when the bot is ready."""
if not bot.user:
return
print(f"🤖 Bot logged in as {bot.user} (ID: {bot.user.id})")
print("Ready to answer Strat-O-Matic rules questions!")
logger.info("Bot logged in as %s (ID: %s)", bot.user, bot.user.id)
logger.info("Ready to answer Strat-O-Matic rules questions!")
@bot.tree.command(
@ -134,10 +144,16 @@ async def ask_command(interaction: discord.Interaction, question: str):
await interaction.followup.send(embed=embed)
except Exception as e:
logger.error(
"Error handling /ask from user %s: %s",
interaction.user.id,
e,
exc_info=True,
)
await interaction.followup.send(
embed=discord.Embed(
title="❌ Error",
description=f"Failed to get answer: {str(e)}",
description="Something went wrong while fetching your answer. Please try again later.",
color=discord.Color.red(),
)
)
@ -232,11 +248,18 @@ async def on_message(message: discord.Message):
await loading_msg.edit(content=None, embed=response_embed)
except Exception as e:
logger.error(
"Error handling follow-up from user %s in channel %s: %s",
message.author.id,
message.channel.id,
e,
exc_info=True,
)
await loading_msg.edit(
content=None,
embed=discord.Embed(
title="❌ Error",
description=f"Failed to process follow-up: {str(e)}",
description="Something went wrong while processing your follow-up. Please try again later.",
color=discord.Color.red(),
),
)
@ -247,7 +270,7 @@ def run_bot(api_base_url: str = "http://localhost:8000"):
bot.api_base_url = api_base_url
if not settings.discord_bot_token:
print("DISCORD_BOT_TOKEN environment variable is required")
logger.critical("DISCORD_BOT_TOKEN environment variable is required")
exit(1)
bot.run(settings.discord_bot_token)

View File

@ -107,6 +107,7 @@ def _make_lifespan(settings: Settings):
# Expose via app.state for the router's Depends helpers
app.state.chat_service = service
app.state.rule_repository = chroma_repo
app.state.api_secret = settings.api_secret
app.state.config_snapshot = {
"openrouter_model": settings.openrouter_model,
"top_k_rules": settings.top_k_rules,

View File

@ -64,6 +64,11 @@ class Settings(BaseSettings):
default="sqlite+aiosqlite:///./data/conversations.db", alias="DB_URL"
)
# ------------------------------------------------------------------
# API authentication
# ------------------------------------------------------------------
api_secret: str = Field(default="", alias="API_SECRET")
# ------------------------------------------------------------------
# Conversation / retrieval tuning
# ------------------------------------------------------------------

View File

@ -3,6 +3,7 @@
ChatService orchestrates the Q&A flow using only domain ports.
"""
import asyncio
import logging
from typing import Optional
@ -59,8 +60,14 @@ class ChatService:
parent_id=parent_message_id,
)
# Search for relevant rules
search_results = self.rules.search(query=message, top_k=self.top_k_rules)
# Search for relevant rules — offload the synchronous (CPU-bound)
# RuleRepository.search() to a thread so the event loop is not blocked
# while SentenceTransformer encodes the query.
loop = asyncio.get_running_loop()
search_results = await loop.run_in_executor(
None,
lambda: self.rules.search(query=message, top_k=self.top_k_rules),
)
# Get conversation history for context
history = await self.conversations.get_conversation_history(conv_id, limit=10)

View File

@ -22,10 +22,15 @@ What is tested
- POST /chat with a message that exceeds 4000 characters returns HTTP 422.
- POST /chat with a user_id that exceeds 64 characters returns HTTP 422.
- POST /chat when ChatService.answer_question raises returns HTTP 500.
- RateLimiter allows requests within the window and blocks once the limit is hit.
- RateLimiter resets after the window expires so the caller can send again.
- POST /chat returns 429 when the per-user rate limit is exceeded.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
import httpx
from fastapi import FastAPI
@ -33,7 +38,7 @@ from httpx import ASGITransport
from domain.models import RuleDocument
from domain.services import ChatService
from adapters.inbound.api import router
from adapters.inbound.api import router, RateLimiter
from tests.fakes import (
FakeRuleRepository,
FakeLLM,
@ -53,6 +58,7 @@ def make_test_app(
conversations: FakeConversationStore | None = None,
issues: FakeIssueTracker | None = None,
top_k_rules: int = 5,
api_secret: str = "",
) -> FastAPI:
"""Build a minimal FastAPI app with fakes wired into app.state.
@ -78,6 +84,7 @@ def make_test_app(
app.state.chat_service = service
app.state.rule_repository = _rules
app.state.api_secret = api_secret
app.state.config_snapshot = {
"openrouter_model": "fake-model",
"top_k_rules": top_k_rules,
@ -436,3 +443,206 @@ async def test_chat_with_parent_message_id_returns_200(client: httpx.AsyncClient
# The response's parent_message_id is the user turn message id,
# not the one we passed in — that's the service's threading model.
assert body["parent_message_id"] is not None
# ---------------------------------------------------------------------------
# API secret authentication
# ---------------------------------------------------------------------------
# Request body reused by the /chat authentication tests below.
_CHAT_PAYLOAD = {"message": "Test question", "user_id": "u1", "channel_id": "ch1"}
@pytest.mark.asyncio
async def test_chat_no_secret_configured_allows_any_request():
    """With an empty api_secret, POST /chat succeeds without X-API-Secret.

    An empty secret is the local-development default, so the endpoint must
    stay open when nothing has been configured.
    """
    open_app = make_test_app(api_secret="")
    transport = ASGITransport(app=open_app)

    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post("/chat", json=_CHAT_PAYLOAD)

    assert response.status_code == 200
@pytest.mark.asyncio
async def test_chat_missing_secret_header_returns_401():
    """With a secret configured, POST /chat without the header yields 401.

    This is what keeps unauthenticated callers away from the LLM endpoint.
    """
    secured_app = make_test_app(api_secret="supersecret")
    transport = ASGITransport(app=secured_app)

    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post("/chat", json=_CHAT_PAYLOAD)

    assert response.status_code == 401
@pytest.mark.asyncio
async def test_chat_wrong_secret_header_returns_401():
    """An incorrect X-API-Secret value is rejected with HTTP 401.

    Covers callers that know a header is required but send a guessed or
    outdated secret.
    """
    secured_app = make_test_app(api_secret="supersecret")
    transport = ASGITransport(app=secured_app)
    bad_headers = {"X-API-Secret": "wrongvalue"}

    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post("/chat", json=_CHAT_PAYLOAD, headers=bad_headers)

    assert response.status_code == 401
@pytest.mark.asyncio
async def test_chat_correct_secret_header_returns_200():
    """The matching X-API-Secret header is accepted with HTTP 200 when a
    secret is configured.
    """
    secured_app = make_test_app(api_secret="supersecret")
    transport = ASGITransport(app=secured_app)
    good_headers = {"X-API-Secret": "supersecret"}

    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post("/chat", json=_CHAT_PAYLOAD, headers=good_headers)

    assert response.status_code == 200
@pytest.mark.asyncio
async def test_health_always_public():
    """GET /health returns 200 even when api_secret is set.

    Monitoring probes do not hold application secrets, so requiring auth on
    the health check would break uptime checks.
    """
    secured_app = make_test_app(api_secret="supersecret")
    transport = ASGITransport(app=secured_app)

    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/health")

    assert response.status_code == 200
@pytest.mark.asyncio
async def test_stats_missing_secret_header_returns_401():
    """GET /stats without X-API-Secret yields 401 when a secret is configured.

    The stats endpoint reports configuration details (model names, retrieval
    settings) that should only be visible to authenticated callers.
    """
    secured_app = make_test_app(api_secret="supersecret")
    transport = ASGITransport(app=secured_app)

    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/stats")

    assert response.status_code == 401
# ---------------------------------------------------------------------------
# RateLimiter unit tests
# ---------------------------------------------------------------------------
def test_rate_limiter_allows_requests_within_limit():
    """Every request up to max_requests inside the window is permitted.

    A limiter with max_requests=3 must return True for three consecutive
    calls from the same user.
    """
    quota = 3
    limiter = RateLimiter(max_requests=quota, window_seconds=60.0)

    for _ in range(quota):
        assert limiter.check("user-a") is True
def test_rate_limiter_blocks_when_limit_exceeded():
    """The call after the quota is used up, still inside the window, is refused.

    Once a user has consumed every allowed slot, further requests must be
    rejected until the window advances.
    """
    limiter = RateLimiter(max_requests=3, window_seconds=60.0)

    # Three calls use up the quota; the fourth is the one under test.
    outcomes = [limiter.check("user-b") for _ in range(4)]

    assert outcomes[-1] is False
def test_rate_limiter_resets_after_window_expires():
    """A rate-limited user may send again once the window has fully elapsed.

    The ``time`` name inside the API module is replaced with a mock so the
    clock can be moved by hand: the quota is used up at t=0, then the clock
    jumps past the 10-second window and the next request must be granted.
    """
    limiter = RateLimiter(max_requests=2, window_seconds=10.0)

    with patch("adapters.inbound.api.time") as fake_time:
        clock = fake_time.monotonic

        # Every request in this phase is stamped t=0.
        clock.return_value = 0.0
        limiter.check("user-c")
        limiter.check("user-c")
        blocked_at_start = limiter.check("user-c")
        assert blocked_at_start is False  # quota exhausted at t=0

        # Jump beyond the window so every recorded timestamp is stale.
        clock.return_value = 11.0
        allowed_after_window = limiter.check("user-c")
        assert allowed_after_window is True  # window reset; request allowed
def test_rate_limiter_isolates_different_users():
    """Quotas are per user: exhausting user-x's quota leaves user-y untouched.

    This exercises the per-user keying. A bug that shared state across users
    would produce false 429s for callers who have not sent anything yet.
    """
    limiter = RateLimiter(max_requests=1, window_seconds=60.0)

    limiter.check("user-x")  # uses up user-x's only slot
    x_second_attempt = limiter.check("user-x")
    y_first_attempt = limiter.check("user-y")

    assert x_second_attempt is False  # user-x is blocked
    assert y_first_attempt is True  # user-y has their own fresh bucket
# ---------------------------------------------------------------------------
# POST /chat — rate limit integration (HTTP 429)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_returns_429_when_rate_limit_exceeded():
    """POST /chat returns HTTP 429 once the per-user rate limit is hit.

    The module-level ``_rate_limiter`` is temporarily replaced with a limiter
    that allows one request per 60 seconds, so the dependency and the limiter
    are exercised together without waiting for real time to pass. The first
    request must succeed; the second, from the same user, must be rejected.
    """
    import adapters.inbound.api as api_module

    request_body = {"message": "Hello", "user_id": "rl-user", "channel_id": "ch1"}
    single_shot = RateLimiter(max_requests=1, window_seconds=60.0)

    # patch.object restores the original limiter on exit, even if an
    # assertion fails, so other tests are not affected.
    with patch.object(api_module, "_rate_limiter", single_shot):
        app = make_test_app()
        transport = ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            first = await client.post("/chat", json=request_body)
            assert first.status_code == 200  # first request is within limit

            second = await client.post("/chat", json=request_body)
            assert second.status_code == 429  # second request is blocked
            assert "Rate limit" in second.json()["detail"]