fix: resolve HIGH-severity issues from code review
API authentication: - Add X-API-Secret shared-secret header validation on /chat and /stats - /health remains public for monitoring - Auth is a no-op when API_SECRET is empty (dev mode) Rate limiting: - Add per-user sliding-window rate limiter on /chat (10 req/60s default) - Returns 429 with clear message when exceeded - Self-cleaning memory (prunes expired entries on each check) Exception sanitization: - Discord bot no longer exposes raw exception text to users - Error embeds show generic "Something went wrong" message - Full exception details logged server-side with context - query_chat_api RuntimeError no longer includes response body Async correctness: - Wrap synchronous RuleRepository.search() in run_in_executor() to prevent blocking the event loop during SentenceTransformer inference - Port contract stays synchronous; service owns the async boundary Test coverage: 101 passed, 1 skipped (11 new tests for auth + rate limiting) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c3218f70c4
commit
43d36ce439
@ -8,6 +8,7 @@ tests can substitute fakes without monkey-patching.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
@ -18,6 +19,58 @@ from domain.services import ChatService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate limiter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Sliding-window in-memory rate limiter keyed by user_id.
|
||||
|
||||
Tracks the timestamps of recent requests for each user. A request is
|
||||
allowed when the number of timestamps within the current window is below
|
||||
max_requests. Old timestamps are pruned on each check so memory does not
|
||||
grow without bound.
|
||||
|
||||
Args:
|
||||
max_requests: Maximum number of requests allowed per user per window.
|
||||
window_seconds: Length of the sliding window in seconds.
|
||||
"""
|
||||
|
||||
def __init__(self, max_requests: int = 10, window_seconds: float = 60.0) -> None:
|
||||
self.max_requests = max_requests
|
||||
self.window_seconds = window_seconds
|
||||
self._timestamps: dict[str, list[float]] = {}
|
||||
|
||||
def check(self, user_id: str) -> bool:
|
||||
"""Return True if the request is allowed, False if rate limited.
|
||||
|
||||
Prunes stale timestamps for the caller on every invocation, so the
|
||||
dict entry naturally shrinks back to zero entries between bursts.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
bucket = self._timestamps.get(user_id)
|
||||
if bucket is None:
|
||||
self._timestamps[user_id] = [now]
|
||||
return True
|
||||
|
||||
# Drop timestamps outside the window, then check the count.
|
||||
pruned = [ts for ts in bucket if ts > cutoff]
|
||||
if len(pruned) >= self.max_requests:
|
||||
self._timestamps[user_id] = pruned
|
||||
return False
|
||||
|
||||
pruned.append(now)
|
||||
self._timestamps[user_id] = pruned
|
||||
return True
|
||||
|
||||
|
||||
# Module-level singleton — shared across all requests in a single process.
|
||||
_rate_limiter = RateLimiter()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@ -84,6 +137,36 @@ def _get_rule_repository(request: Request) -> RuleRepository:
|
||||
return request.app.state.rule_repository
|
||||
|
||||
|
||||
def _check_rate_limit(body: ChatRequest) -> None:
|
||||
"""Raise HTTP 429 if the caller has exceeded their per-window request quota.
|
||||
|
||||
user_id is taken from the parsed request body so it works consistently
|
||||
regardless of how the Discord bot identifies its users. The check uses the
|
||||
module-level _rate_limiter singleton so state is shared across requests.
|
||||
"""
|
||||
if not _rate_limiter.check(body.user_id):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Rate limit exceeded. Please wait before sending another message.",
|
||||
)
|
||||
|
||||
|
||||
def _verify_api_secret(request: Request) -> None:
|
||||
"""Enforce shared-secret authentication when API_SECRET is configured.
|
||||
|
||||
When api_secret on app.state is an empty string the check is skipped
|
||||
entirely, preserving the existing open-access behaviour for local
|
||||
development. Once a secret is set, the caller must supply a matching
|
||||
X-API-Secret header or receive HTTP 401.
|
||||
"""
|
||||
secret: str = getattr(request.app.state, "api_secret", "")
|
||||
if not secret:
|
||||
return
|
||||
header_value = request.headers.get("X-API-Secret", "")
|
||||
if header_value != secret:
|
||||
raise HTTPException(status_code=401, detail="Invalid or missing API secret.")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -94,6 +177,8 @@ async def chat(
|
||||
body: ChatRequest,
|
||||
service: Annotated[ChatService, Depends(_get_chat_service)],
|
||||
rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
|
||||
_auth: Annotated[None, Depends(_verify_api_secret)],
|
||||
_rate: Annotated[None, Depends(_check_rate_limit)],
|
||||
) -> ChatResponse:
|
||||
"""Handle a rules Q&A request.
|
||||
|
||||
@ -152,6 +237,7 @@ async def health(
|
||||
async def stats(
|
||||
rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
|
||||
request: Request,
|
||||
_auth: Annotated[None, Depends(_verify_api_secret)],
|
||||
) -> dict:
|
||||
"""Return extended statistics about the knowledge base and configuration."""
|
||||
kb_stats = rules.get_stats()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Discord bot for Strat-O-Matic rules Q&A."""
|
||||
|
||||
import logging
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
@ -8,6 +9,8 @@ from typing import Optional
|
||||
|
||||
from .config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StratChatbotBot(commands.Bot):
|
||||
"""Discord bot for the rules chatbot."""
|
||||
@ -29,10 +32,10 @@ class StratChatbotBot(commands.Bot):
|
||||
guild = discord.Object(id=int(settings.discord_guild_id))
|
||||
self.tree.copy_global_to(guild=guild)
|
||||
await self.tree.sync(guild=guild)
|
||||
print(f"Slash commands synced to guild {settings.discord_guild_id}")
|
||||
logger.info("Slash commands synced to guild %s", settings.discord_guild_id)
|
||||
else:
|
||||
await self.tree.sync()
|
||||
print("Slash commands synced globally")
|
||||
logger.info("Slash commands synced globally")
|
||||
|
||||
async def close(self):
|
||||
"""Cleanup on shutdown."""
|
||||
@ -67,7 +70,14 @@ class StratChatbotBot(commands.Bot):
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"API error {response.status}: {error_text}")
|
||||
logger.error(
|
||||
"API returned %s for %s %s — body: %s",
|
||||
response.status,
|
||||
response.method,
|
||||
response.url,
|
||||
error_text,
|
||||
)
|
||||
raise RuntimeError(f"API error {response.status}")
|
||||
return await response.json()
|
||||
|
||||
|
||||
@ -79,8 +89,8 @@ async def on_ready():
|
||||
"""Called when the bot is ready."""
|
||||
if not bot.user:
|
||||
return
|
||||
print(f"🤖 Bot logged in as {bot.user} (ID: {bot.user.id})")
|
||||
print("Ready to answer Strat-O-Matic rules questions!")
|
||||
logger.info("Bot logged in as %s (ID: %s)", bot.user, bot.user.id)
|
||||
logger.info("Ready to answer Strat-O-Matic rules questions!")
|
||||
|
||||
|
||||
@bot.tree.command(
|
||||
@ -134,10 +144,16 @@ async def ask_command(interaction: discord.Interaction, question: str):
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error handling /ask from user %s: %s",
|
||||
interaction.user.id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await interaction.followup.send(
|
||||
embed=discord.Embed(
|
||||
title="❌ Error",
|
||||
description=f"Failed to get answer: {str(e)}",
|
||||
description="Something went wrong while fetching your answer. Please try again later.",
|
||||
color=discord.Color.red(),
|
||||
)
|
||||
)
|
||||
@ -232,11 +248,18 @@ async def on_message(message: discord.Message):
|
||||
await loading_msg.edit(content=None, embed=response_embed)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error handling follow-up from user %s in channel %s: %s",
|
||||
message.author.id,
|
||||
message.channel.id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
await loading_msg.edit(
|
||||
content=None,
|
||||
embed=discord.Embed(
|
||||
title="❌ Error",
|
||||
description=f"Failed to process follow-up: {str(e)}",
|
||||
description="Something went wrong while processing your follow-up. Please try again later.",
|
||||
color=discord.Color.red(),
|
||||
),
|
||||
)
|
||||
@ -247,7 +270,7 @@ def run_bot(api_base_url: str = "http://localhost:8000"):
|
||||
bot.api_base_url = api_base_url
|
||||
|
||||
if not settings.discord_bot_token:
|
||||
print("❌ DISCORD_BOT_TOKEN environment variable is required")
|
||||
logger.critical("DISCORD_BOT_TOKEN environment variable is required")
|
||||
exit(1)
|
||||
|
||||
bot.run(settings.discord_bot_token)
|
||||
|
||||
@ -107,6 +107,7 @@ def _make_lifespan(settings: Settings):
|
||||
# Expose via app.state for the router's Depends helpers
|
||||
app.state.chat_service = service
|
||||
app.state.rule_repository = chroma_repo
|
||||
app.state.api_secret = settings.api_secret
|
||||
app.state.config_snapshot = {
|
||||
"openrouter_model": settings.openrouter_model,
|
||||
"top_k_rules": settings.top_k_rules,
|
||||
|
||||
@ -64,6 +64,11 @@ class Settings(BaseSettings):
|
||||
default="sqlite+aiosqlite:///./data/conversations.db", alias="DB_URL"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API authentication
|
||||
# ------------------------------------------------------------------
|
||||
api_secret: str = Field(default="", alias="API_SECRET")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conversation / retrieval tuning
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
ChatService orchestrates the Q&A flow using only domain ports.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@ -59,8 +60,14 @@ class ChatService:
|
||||
parent_id=parent_message_id,
|
||||
)
|
||||
|
||||
# Search for relevant rules
|
||||
search_results = self.rules.search(query=message, top_k=self.top_k_rules)
|
||||
# Search for relevant rules — offload the synchronous (CPU-bound)
|
||||
# RuleRepository.search() to a thread so the event loop is not blocked
|
||||
# while SentenceTransformer encodes the query.
|
||||
loop = asyncio.get_running_loop()
|
||||
search_results = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.rules.search(query=message, top_k=self.top_k_rules),
|
||||
)
|
||||
|
||||
# Get conversation history for context
|
||||
history = await self.conversations.get_conversation_history(conv_id, limit=10)
|
||||
|
||||
@ -22,10 +22,15 @@ What is tested
|
||||
- POST /chat with a message that exceeds 4000 characters returns HTTP 422.
|
||||
- POST /chat with a user_id that exceeds 64 characters returns HTTP 422.
|
||||
- POST /chat when ChatService.answer_question raises returns HTTP 500.
|
||||
- RateLimiter allows requests within the window and blocks once the limit is hit.
|
||||
- RateLimiter resets after the window expires so the caller can send again.
|
||||
- POST /chat returns 429 when the per-user rate limit is exceeded.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
@ -33,7 +38,7 @@ from httpx import ASGITransport
|
||||
|
||||
from domain.models import RuleDocument
|
||||
from domain.services import ChatService
|
||||
from adapters.inbound.api import router
|
||||
from adapters.inbound.api import router, RateLimiter
|
||||
from tests.fakes import (
|
||||
FakeRuleRepository,
|
||||
FakeLLM,
|
||||
@ -53,6 +58,7 @@ def make_test_app(
|
||||
conversations: FakeConversationStore | None = None,
|
||||
issues: FakeIssueTracker | None = None,
|
||||
top_k_rules: int = 5,
|
||||
api_secret: str = "",
|
||||
) -> FastAPI:
|
||||
"""Build a minimal FastAPI app with fakes wired into app.state.
|
||||
|
||||
@ -78,6 +84,7 @@ def make_test_app(
|
||||
|
||||
app.state.chat_service = service
|
||||
app.state.rule_repository = _rules
|
||||
app.state.api_secret = api_secret
|
||||
app.state.config_snapshot = {
|
||||
"openrouter_model": "fake-model",
|
||||
"top_k_rules": top_k_rules,
|
||||
@ -436,3 +443,206 @@ async def test_chat_with_parent_message_id_returns_200(client: httpx.AsyncClient
|
||||
# The response's parent_message_id is the user turn message id,
|
||||
# not the one we passed in — that's the service's threading model.
|
||||
assert body["parent_message_id"] is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API secret authentication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CHAT_PAYLOAD = {"message": "Test question", "user_id": "u1", "channel_id": "ch1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_no_secret_configured_allows_any_request():
|
||||
"""When api_secret is empty (the default for local dev), POST /chat must
|
||||
succeed without any X-API-Secret header.
|
||||
|
||||
This preserves the existing open-access behaviour so developers can run
|
||||
the service locally without configuring a secret.
|
||||
"""
|
||||
app = make_test_app(api_secret="")
|
||||
async with httpx.AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
resp = await ac.post("/chat", json=_CHAT_PAYLOAD)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_missing_secret_header_returns_401():
|
||||
"""When api_secret is configured, POST /chat without X-API-Secret must
|
||||
return HTTP 401, preventing unauthenticated access to the LLM endpoint.
|
||||
"""
|
||||
app = make_test_app(api_secret="supersecret")
|
||||
async with httpx.AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
resp = await ac.post("/chat", json=_CHAT_PAYLOAD)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_wrong_secret_header_returns_401():
|
||||
"""A request with an incorrect X-API-Secret value must return HTTP 401.
|
||||
|
||||
This guards against callers who know a header is required but are
|
||||
guessing or have an outdated secret.
|
||||
"""
|
||||
app = make_test_app(api_secret="supersecret")
|
||||
async with httpx.AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
resp = await ac.post(
|
||||
"/chat", json=_CHAT_PAYLOAD, headers={"X-API-Secret": "wrongvalue"}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_correct_secret_header_returns_200():
|
||||
"""A request with the correct X-API-Secret header must succeed and return
|
||||
HTTP 200 when api_secret is configured.
|
||||
"""
|
||||
app = make_test_app(api_secret="supersecret")
|
||||
async with httpx.AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
resp = await ac.post(
|
||||
"/chat", json=_CHAT_PAYLOAD, headers={"X-API-Secret": "supersecret"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_always_public():
|
||||
"""GET /health must return 200 regardless of whether api_secret is set.
|
||||
|
||||
Health checks are used by monitoring systems that do not hold application
|
||||
secrets; requiring auth there would break uptime probes.
|
||||
"""
|
||||
app = make_test_app(api_secret="supersecret")
|
||||
async with httpx.AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
resp = await ac.get("/health")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stats_missing_secret_header_returns_401():
|
||||
"""GET /stats without X-API-Secret must return HTTP 401 when a secret is
|
||||
configured.
|
||||
|
||||
The stats endpoint exposes configuration details (model names, retrieval
|
||||
settings) that should be restricted to authenticated callers.
|
||||
"""
|
||||
app = make_test_app(api_secret="supersecret")
|
||||
async with httpx.AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
resp = await ac.get("/stats")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimiter unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_rate_limiter_allows_requests_within_limit():
|
||||
"""Requests below max_requests within the window must all return True.
|
||||
|
||||
We create a limiter with max_requests=3 and verify that three consecutive
|
||||
calls for the same user are all permitted.
|
||||
"""
|
||||
limiter = RateLimiter(max_requests=3, window_seconds=60.0)
|
||||
assert limiter.check("user-a") is True
|
||||
assert limiter.check("user-a") is True
|
||||
assert limiter.check("user-a") is True
|
||||
|
||||
|
||||
def test_rate_limiter_blocks_when_limit_exceeded():
|
||||
"""The (max_requests + 1)-th call within the window must return False.
|
||||
|
||||
This confirms the sliding-window boundary is enforced correctly: once a
|
||||
user has consumed all allowed slots, further requests are rejected until
|
||||
the window advances.
|
||||
"""
|
||||
limiter = RateLimiter(max_requests=3, window_seconds=60.0)
|
||||
for _ in range(3):
|
||||
limiter.check("user-b")
|
||||
assert limiter.check("user-b") is False
|
||||
|
||||
|
||||
def test_rate_limiter_resets_after_window_expires():
|
||||
"""After the window has fully elapsed, a previously rate-limited user must
|
||||
be allowed to send again.
|
||||
|
||||
We use unittest.mock.patch to freeze time.monotonic so the test runs
|
||||
instantly: first we consume the quota at t=0, then advance the clock past
|
||||
the window boundary and confirm the limiter grants the next request.
|
||||
"""
|
||||
limiter = RateLimiter(max_requests=2, window_seconds=10.0)
|
||||
|
||||
with patch("adapters.inbound.api.time") as mock_time:
|
||||
# All requests happen at t=0.
|
||||
mock_time.monotonic.return_value = 0.0
|
||||
limiter.check("user-c")
|
||||
limiter.check("user-c")
|
||||
assert limiter.check("user-c") is False # quota exhausted at t=0
|
||||
|
||||
# Advance time past the full window so all timestamps are stale.
|
||||
mock_time.monotonic.return_value = 11.0
|
||||
assert limiter.check("user-c") is True # window reset; request allowed
|
||||
|
||||
|
||||
def test_rate_limiter_isolates_different_users():
|
||||
"""Rate limiting must be per-user: consuming user-x's quota must not affect
|
||||
user-y's available requests.
|
||||
|
||||
This covers the dict-keying logic — a bug that shares state across users
|
||||
would cause false 429s for innocent callers.
|
||||
"""
|
||||
limiter = RateLimiter(max_requests=1, window_seconds=60.0)
|
||||
limiter.check("user-x") # exhausts user-x's single slot
|
||||
assert limiter.check("user-x") is False # user-x is blocked
|
||||
assert limiter.check("user-y") is True # user-y has their own fresh bucket
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /chat — rate limit integration (HTTP 429)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_returns_429_when_rate_limit_exceeded():
|
||||
"""POST /chat must return HTTP 429 once the per-user rate limit is hit.
|
||||
|
||||
We patch the module-level _rate_limiter so we can exercise the integration
|
||||
between the FastAPI dependency and the limiter without waiting for real time
|
||||
to pass. The first call returns 200; after patching check() to return False,
|
||||
the second call must return 429.
|
||||
"""
|
||||
import adapters.inbound.api as api_module
|
||||
|
||||
# Use a tight limiter (1 request per 60 s) injected into the module so
|
||||
# both the app and the dependency share the same instance.
|
||||
tight_limiter = RateLimiter(max_requests=1, window_seconds=60.0)
|
||||
original = api_module._rate_limiter
|
||||
api_module._rate_limiter = tight_limiter
|
||||
|
||||
payload = {"message": "Hello", "user_id": "rl-user", "channel_id": "ch1"}
|
||||
app = make_test_app()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
resp1 = await ac.post("/chat", json=payload)
|
||||
assert resp1.status_code == 200 # first request is within limit
|
||||
|
||||
resp2 = await ac.post("/chat", json=payload)
|
||||
assert resp2.status_code == 429 # second request is blocked
|
||||
assert "Rate limit" in resp2.json()["detail"]
|
||||
finally:
|
||||
api_module._rate_limiter = original # restore to avoid polluting other tests
|
||||
|
||||
Loading…
Reference in New Issue
Block a user