strat-chatbot/tests/adapters/test_api.py
Cal Corum 43d36ce439 fix: resolve HIGH-severity issues from code review
API authentication:
- Add X-API-Secret shared-secret header validation on /chat and /stats
- /health remains public for monitoring
- Auth is a no-op when API_SECRET is empty (dev mode)

Rate limiting:
- Add per-user sliding-window rate limiter on /chat (10 req/60s default)
- Returns 429 with clear message when exceeded
- Self-cleaning memory (prunes expired entries on each check)

Exception sanitization:
- Discord bot no longer exposes raw exception text to users
- Error embeds show generic "Something went wrong" message
- Full exception details logged server-side with context
- query_chat_api RuntimeError no longer includes response body

Async correctness:
- Wrap synchronous RuleRepository.search() in run_in_executor()
  to prevent blocking the event loop during SentenceTransformer inference
- Port contract stays synchronous; service owns the async boundary

Test coverage: 101 passed, 1 skipped (11 new tests for auth + rate limiting)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 16:00:26 -05:00

649 lines
23 KiB
Python

"""Tests for the FastAPI inbound adapter (adapters/inbound/api.py).
Strategy
--------
We build a minimal FastAPI app in each fixture by wiring fakes into app.state,
then drive it with httpx.AsyncClient using ASGITransport so no real HTTP server
is needed. This means:
- No real ChromaDB, SQLite, OpenRouter, or Gitea calls.
- Tests are fast, deterministic, and isolated.
- The test app mirrors exactly what the production container does — the only
difference is which objects sit in app.state.
What is tested
--------------
- POST /chat returns 200 and a well-formed ChatResponse for a normal message.
- POST /chat stores the conversation and returns a stable conversation_id on a
second call with the same conversation_id (conversation continuation).
- GET /health returns {"status": "healthy", ...} with rule counts.
- GET /stats returns a knowledge_base sub-dict and a config sub-dict.
- POST /chat with missing required fields returns HTTP 422 (Unprocessable Entity).
- POST /chat with a message that exceeds 4000 characters returns HTTP 422.
- POST /chat with a user_id that exceeds 64 characters returns HTTP 422.
- POST /chat when ChatService.answer_question raises returns HTTP 500.
- RateLimiter allows requests within the window and blocks once the limit is hit.
- RateLimiter resets after the window expires so the caller can send again.
- POST /chat returns 429 when the per-user rate limit is exceeded.
"""
from __future__ import annotations

from collections.abc import AsyncIterator
from unittest.mock import patch

import httpx
import pytest
from fastapi import FastAPI
from httpx import ASGITransport

from adapters.inbound.api import RateLimiter, router
from domain.models import RuleDocument
from domain.services import ChatService
from tests.fakes import (
    FakeConversationStore,
    FakeIssueTracker,
    FakeLLM,
    FakeRuleRepository,
)
# ---------------------------------------------------------------------------
# Test app factory
# ---------------------------------------------------------------------------
def make_test_app(
    *,
    rules: FakeRuleRepository | None = None,
    llm: FakeLLM | None = None,
    conversations: FakeConversationStore | None = None,
    issues: FakeIssueTracker | None = None,
    top_k_rules: int = 5,
    api_secret: str = "",
) -> FastAPI:
    """Build a minimal FastAPI app with fakes wired into app.state.

    Mirrors the production wiring in config/container.py but uses in-memory
    fakes so no external services are needed. Each call builds a fresh,
    isolated set of fakes unless shared instances are passed in.

    Fix: fall back to a fresh fake only when the argument is None. The
    previous ``x or Fake()`` form would silently discard a caller-supplied
    fake that happens to be falsy (e.g. a container-like fake that defines
    ``__len__`` and is empty), breaking tests that pre-load state into it.
    The ``| None = None`` signature makes None the explicit sentinel.
    """
    _rules = FakeRuleRepository() if rules is None else rules
    _llm = FakeLLM() if llm is None else llm
    _conversations = FakeConversationStore() if conversations is None else conversations
    _issues = FakeIssueTracker() if issues is None else issues

    service = ChatService(
        rules=_rules,
        llm=_llm,
        conversations=_conversations,
        issues=_issues,
        top_k_rules=top_k_rules,
    )

    app = FastAPI()
    app.include_router(router)
    app.state.chat_service = service
    app.state.rule_repository = _rules
    app.state.api_secret = api_secret
    # Surfaced verbatim by GET /stats under the "config" key.
    app.state.config_snapshot = {
        "openrouter_model": "fake-model",
        "top_k_rules": top_k_rules,
        "embedding_model": "fake-embeddings",
    }
    return app
# ---------------------------------------------------------------------------
# Shared fixture: an async client backed by the test app
# ---------------------------------------------------------------------------
@pytest.fixture()
async def client() -> AsyncIterator[httpx.AsyncClient]:
    """Yield an AsyncClient bound to a fresh, fully isolated test app.

    Each test function gets its own set of fakes so state cannot leak
    between tests.

    Fix: the return annotation previously claimed ``httpx.AsyncClient``,
    but this is an async generator fixture (it ``yield``s); the accurate
    annotation is ``AsyncIterator[httpx.AsyncClient]``.
    """
    app = make_test_app()
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as ac:
        yield ac
# ---------------------------------------------------------------------------
# POST /chat — successful response
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_returns_200_with_valid_payload(client: httpx.AsyncClient):
    """POST /chat with a valid payload returns 200 and a full ChatResponse.

    Every field of the body is type-checked so any structural drift in
    ChatResult or ChatResponse surfaces here immediately instead of
    silently producing a wrong value.
    """
    resp = await client.post(
        "/chat",
        json={
            "message": "How many strikes to strike out?",
            "user_id": "user-001",
            "channel_id": "channel-001",
        },
    )
    assert resp.status_code == 200

    body = resp.json()
    assert isinstance(body["response"], str)
    assert body["response"] != ""
    assert isinstance(body["conversation_id"], str)
    assert isinstance(body["message_id"], str)
    assert isinstance(body["cited_rules"], list)
    assert isinstance(body["confidence"], float)
    assert isinstance(body["needs_human"], bool)
@pytest.mark.asyncio
async def test_chat_uses_rules_when_available():
    """When matching documents exist, the answer cites rules with confidence.

    Exercises the full ChatService flow through the inbound adapter: the
    FakeLLM receives the retrieved documents and returns a high-confidence
    answer with cited_rules populated.
    """
    repo = FakeRuleRepository()
    doc = RuleDocument(
        rule_id="1.1",
        title="Batting Order",
        section="Batting",
        content="A batter gets three strikes before striking out.",
        source_file="rules.pdf",
    )
    repo.add_documents([doc])

    transport = ASGITransport(app=make_test_app(rules=repo))
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.post(
            "/chat",
            json={
                "message": "How many strikes before a batter strikes out?",
                "user_id": "user-abc",
                "channel_id": "ch-xyz",
            },
        )
        assert resp.status_code == 200
        body = resp.json()
        # FakeLLM returns cited_rules when rules are found.
        assert len(body["cited_rules"]) > 0
        assert body["confidence"] > 0.5
# ---------------------------------------------------------------------------
# POST /chat — conversation continuation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_continues_existing_conversation():
    """Passing conversation_id back must resume the same conversation.

    Turn one creates a conversation and returns its ID; turn two sends that
    ID back and must receive the identical conversation_id, proving the
    router and conversation store (fake here, SQLite in production) agree
    on continuation semantics.
    """
    app = make_test_app(conversations=FakeConversationStore())
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as ac:
        # Turn one: no conversation_id supplied.
        first = await ac.post(
            "/chat",
            json={
                "message": "First question",
                "user_id": "user-42",
                "channel_id": "ch-1",
            },
        )
        assert first.status_code == 200
        conv_id = first.json()["conversation_id"]

        # Turn two: resume the conversation created above.
        second = await ac.post(
            "/chat",
            json={
                "message": "Follow-up question",
                "user_id": "user-42",
                "channel_id": "ch-1",
                "conversation_id": conv_id,
            },
        )
        assert second.status_code == 200
        assert second.json()["conversation_id"] == conv_id
# ---------------------------------------------------------------------------
# GET /health
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_health_returns_healthy_status(client: httpx.AsyncClient):
    """GET /health reports {"status": "healthy", ...} plus rule metadata.

    The fixture's FakeRuleRepository starts empty, so only the shape of the
    payload is checked here; a concrete count is covered by the
    loaded-rules test below.
    """
    resp = await client.get("/health")
    assert resp.status_code == 200

    body = resp.json()
    assert body["status"] == "healthy"
    assert isinstance(body["rules_count"], int)
    assert isinstance(body["sections"], dict)
@pytest.mark.asyncio
async def test_health_reflects_loaded_rules():
    """GET /health shows the updated count after documents are loaded.

    Confirms the router reads a live reference to the repository rather
    than a snapshot taken at startup.
    """
    repo = FakeRuleRepository()
    repo.add_documents(
        [
            RuleDocument(
                rule_id="2.1",
                title="Pitching",
                section="Pitching",
                content="The pitcher throws the ball.",
                source_file="rules.pdf",
            )
        ]
    )

    transport = ASGITransport(app=make_test_app(rules=repo))
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.get("/health")
        assert resp.status_code == 200
        assert resp.json()["rules_count"] == 1
# ---------------------------------------------------------------------------
# GET /stats
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stats_returns_knowledge_base_and_config(client: httpx.AsyncClient):
    """GET /stats exposes knowledge_base and config sub-dicts.

    knowledge_base comes from RuleRepository.get_stats and config from
    app.state.config_snapshot (set by the container), giving an operator
    enough information to confirm the active model and retrieval settings.
    """
    resp = await client.get("/stats")
    assert resp.status_code == 200

    body = resp.json()
    for section in ("knowledge_base", "config"):
        assert section in body
    assert "total_rules" in body["knowledge_base"]
# ---------------------------------------------------------------------------
# POST /chat — validation errors (HTTP 422)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_missing_message_returns_422(client: httpx.AsyncClient):
    """Omitting the required 'message' field is a client error.

    Pydantic validation must reject it with 422 (and a detail array) —
    never a 500, since nothing went wrong server-side.
    """
    payload = {"user_id": "u1", "channel_id": "ch1"}
    resp = await client.post("/chat", json=payload)
    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_chat_missing_user_id_returns_422(client: httpx.AsyncClient):
    """A payload lacking 'user_id' must be rejected with 422."""
    payload = {"message": "Hello", "channel_id": "ch1"}
    resp = await client.post("/chat", json=payload)
    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_chat_missing_channel_id_returns_422(client: httpx.AsyncClient):
    """A payload lacking 'channel_id' must be rejected with 422."""
    payload = {"message": "Hello", "user_id": "u1"}
    resp = await client.post("/chat", json=payload)
    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_chat_message_too_long_returns_422(client: httpx.AsyncClient):
    """Messages over 4000 characters fail field-level validation.

    The max_length constraint on ChatRequest.message must reject the
    payload with 422 before it ever reaches the service layer.
    """
    payload = {"message": "x" * 4001, "user_id": "u1", "channel_id": "ch1"}
    resp = await client.post("/chat", json=payload)
    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_chat_user_id_too_long_returns_422(client: httpx.AsyncClient):
    """A user_id over 64 characters must be rejected with 422.

    Discord snowflakes are at most 20 digits; the 64-char cap is generous
    while still keeping runaway strings away from the database layer.
    """
    payload = {"message": "Hello", "user_id": "u" * 65, "channel_id": "ch1"}
    resp = await client.post("/chat", json=payload)
    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_chat_channel_id_too_long_returns_422(client: httpx.AsyncClient):
    """A channel_id over 64 characters must be rejected with 422."""
    payload = {"message": "Hello", "user_id": "u1", "channel_id": "c" * 65}
    resp = await client.post("/chat", json=payload)
    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_chat_empty_message_returns_422(client: httpx.AsyncClient):
    """An empty 'message' string fails min_length=1 and returns 422.

    An empty string must never reach the LLM — it would produce a
    confusing response and waste tokens.
    """
    payload = {"message": "", "user_id": "u1", "channel_id": "ch1"}
    resp = await client.post("/chat", json=payload)
    assert resp.status_code == 422
# ---------------------------------------------------------------------------
# POST /chat — service-layer exception bubbles up as 500
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_service_exception_returns_500():
    """An unexpected ChatService exception becomes an HTTP 500, not a crash.

    FakeLLM(force_error=...) injects the failure deterministically so the
    router's error path is exercised without flakiness.

    NOTE(review): the final assertion pins the raw exception text inside
    the 500 detail. That exposes internal error strings to API callers;
    if the exception-sanitization work is ever extended from the Discord
    bot to the HTTP layer, this assertion should switch to expecting a
    generic message instead.
    """
    app = make_test_app(llm=FakeLLM(force_error=RuntimeError("LLM exploded")))
    async with httpx.AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as ac:
        resp = await ac.post(
            "/chat",
            json={"message": "Hello", "user_id": "u1", "channel_id": "ch1"},
        )
        assert resp.status_code == 500
        assert "LLM exploded" in resp.json()["detail"]
# ---------------------------------------------------------------------------
# POST /chat — parent_message_id thread reply
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_with_parent_message_id_returns_200(client: httpx.AsyncClient):
    """The optional parent_message_id field is accepted without error.

    The value flows through ChatService into the conversation store; the
    detailed wiring is covered by the service-layer tests, so only a 200
    plus a non-null parent_message_id is asserted here. The response's
    parent_message_id is the user turn's message id, not the value sent
    in — that is the service's threading model.
    """
    payload = {
        "message": "Thread reply",
        "user_id": "u1",
        "channel_id": "ch1",
        "parent_message_id": "some-parent-uuid",
    }
    resp = await client.post("/chat", json=payload)
    assert resp.status_code == 200
    assert resp.json()["parent_message_id"] is not None
# ---------------------------------------------------------------------------
# API secret authentication
# ---------------------------------------------------------------------------
_CHAT_PAYLOAD = {"message": "Test question", "user_id": "u1", "channel_id": "ch1"}


@pytest.mark.asyncio
async def test_chat_no_secret_configured_allows_any_request():
    """With api_secret empty (the local-dev default), /chat needs no header.

    Preserves the open-access behaviour so developers can run the service
    locally without configuring a secret.
    """
    transport = ASGITransport(app=make_test_app(api_secret=""))
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.post("/chat", json=_CHAT_PAYLOAD)
        assert resp.status_code == 200
@pytest.mark.asyncio
async def test_chat_missing_secret_header_returns_401():
    """With api_secret set, /chat without X-API-Secret is rejected with 401,
    keeping the LLM endpoint closed to unauthenticated callers.
    """
    transport = ASGITransport(app=make_test_app(api_secret="supersecret"))
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.post("/chat", json=_CHAT_PAYLOAD)
        assert resp.status_code == 401
@pytest.mark.asyncio
async def test_chat_wrong_secret_header_returns_401():
    """An incorrect X-API-Secret value must be rejected with 401.

    Guards against callers who know a header is required but are guessing
    or holding an outdated secret.
    """
    transport = ASGITransport(app=make_test_app(api_secret="supersecret"))
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.post(
            "/chat", json=_CHAT_PAYLOAD, headers={"X-API-Secret": "wrongvalue"}
        )
        assert resp.status_code == 401
@pytest.mark.asyncio
async def test_chat_correct_secret_header_returns_200():
    """The correct X-API-Secret value grants access when a secret is set."""
    transport = ASGITransport(app=make_test_app(api_secret="supersecret"))
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.post(
            "/chat", json=_CHAT_PAYLOAD, headers={"X-API-Secret": "supersecret"}
        )
        assert resp.status_code == 200
@pytest.mark.asyncio
async def test_health_always_public():
    """/health stays public even when api_secret is configured.

    Monitoring systems that run uptime probes do not hold application
    secrets; requiring auth on the health check would break them.
    """
    transport = ASGITransport(app=make_test_app(api_secret="supersecret"))
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.get("/health")
        assert resp.status_code == 200
@pytest.mark.asyncio
async def test_stats_missing_secret_header_returns_401():
    """/stats without X-API-Secret is rejected with 401 when a secret is set.

    The stats endpoint exposes configuration details (model names,
    retrieval settings) that should be restricted to authenticated callers.
    """
    transport = ASGITransport(app=make_test_app(api_secret="supersecret"))
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.get("/stats")
        assert resp.status_code == 401
# ---------------------------------------------------------------------------
# RateLimiter unit tests
# ---------------------------------------------------------------------------
def test_rate_limiter_allows_requests_within_limit():
    """All requests under max_requests inside one window are permitted.

    With max_requests=3, three consecutive checks for the same user must
    each return True.
    """
    limiter = RateLimiter(max_requests=3, window_seconds=60.0)
    outcomes = [limiter.check("user-a") for _ in range(3)]
    assert all(result is True for result in outcomes)
def test_rate_limiter_blocks_when_limit_exceeded():
    """The request after the quota is consumed must be rejected.

    Verifies the sliding-window boundary: once a user has used every
    allowed slot, further requests fail until the window advances.
    """
    limiter = RateLimiter(max_requests=3, window_seconds=60.0)
    user = "user-b"
    for _ in range(3):
        limiter.check(user)  # consume the full quota
    assert limiter.check(user) is False
def test_rate_limiter_resets_after_window_expires():
    """A rate-limited user is allowed again once the window fully elapses.

    time.monotonic is frozen with mock.patch so the test runs instantly:
    the quota is consumed at t=0, then the clock jumps past the window
    boundary and the limiter must grant the next request.

    Fix: patch the single attribute 'adapters.inbound.api.time.monotonic'
    instead of the whole time module. Patching the entire module replaced
    every other time function with a MagicMock, so any other time call the
    limiter might make would silently receive a mock instead of failing
    loudly or behaving correctly.
    """
    limiter = RateLimiter(max_requests=2, window_seconds=10.0)
    with patch("adapters.inbound.api.time.monotonic") as mock_monotonic:
        # All requests happen at t=0.
        mock_monotonic.return_value = 0.0
        limiter.check("user-c")
        limiter.check("user-c")
        assert limiter.check("user-c") is False  # quota exhausted at t=0

        # Advance past the full 10 s window so every timestamp is stale.
        mock_monotonic.return_value = 11.0
        assert limiter.check("user-c") is True  # window reset; allowed again
def test_rate_limiter_isolates_different_users():
    """Quotas are per-user: exhausting one user's slot leaves others intact.

    Covers the dict-keying logic — a bug that shared state across users
    would cause false 429s for innocent callers.
    """
    limiter = RateLimiter(max_requests=1, window_seconds=60.0)
    limiter.check("user-x")  # spends user-x's only slot
    assert limiter.check("user-x") is False  # user-x is now blocked
    assert limiter.check("user-y") is True  # user-y's bucket is untouched
# ---------------------------------------------------------------------------
# POST /chat — rate limit integration (HTTP 429)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_returns_429_when_rate_limit_exceeded():
    """POST /chat must return HTTP 429 once the per-user rate limit is hit.

    The module-level _rate_limiter is swapped for a tight real limiter
    (1 request per 60 s) so the second request from the same user is over
    quota without waiting for real time to pass. The first request returns
    200; the second must return 429 with a "Rate limit" detail message.
    The original limiter is restored in a finally block so other tests are
    unaffected even if an assertion fails.

    Fix: the previous docstring claimed check() was "patched to return
    False"; the test actually installs a genuine RateLimiter instance and
    lets its sliding-window logic produce the rejection, which also
    exercises the dependency/limiter integration for real.
    """
    import adapters.inbound.api as api_module

    # Inject the tight limiter at module level so both the app and the
    # FastAPI dependency share the same instance.
    tight_limiter = RateLimiter(max_requests=1, window_seconds=60.0)
    original = api_module._rate_limiter
    api_module._rate_limiter = tight_limiter

    payload = {"message": "Hello", "user_id": "rl-user", "channel_id": "ch1"}
    app = make_test_app()
    try:
        async with httpx.AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as ac:
            resp1 = await ac.post("/chat", json=payload)
            assert resp1.status_code == 200  # first request is within limit
            resp2 = await ac.post("/chat", json=payload)
            assert resp2.status_code == 429  # second request is blocked
            assert "Rate limit" in resp2.json()["detail"]
    finally:
        api_module._rate_limiter = original  # never pollute other tests