API authentication:
- Add X-API-Secret shared-secret header validation on /chat and /stats
- /health remains public for monitoring
- Auth is a no-op when API_SECRET is empty (dev mode)

Rate limiting:
- Add per-user sliding-window rate limiter on /chat (10 req/60s default)
- Returns 429 with clear message when exceeded
- Self-cleaning memory (prunes expired entries on each check)

Exception sanitization:
- Discord bot no longer exposes raw exception text to users
- Error embeds show generic "Something went wrong" message
- Full exception details logged server-side with context
- query_chat_api RuntimeError no longer includes response body

Async correctness:
- Wrap synchronous RuleRepository.search() in run_in_executor() to prevent blocking the event loop during SentenceTransformer inference
- Port contract stays synchronous; service owns the async boundary

Test coverage: 101 passed, 1 skipped (11 new tests for auth + rate limiting)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
277 lines
8.7 KiB
Python
277 lines
8.7 KiB
Python
"""Discord bot for Strat-O-Matic rules Q&A."""
|
|
|
|
import logging
|
|
import discord
|
|
from discord import app_commands
|
|
from discord.ext import commands
|
|
import aiohttp
|
|
from typing import Optional
|
|
|
|
from .config import settings
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class StratChatbotBot(commands.Bot):
    """Discord bot that forwards Strat-O-Matic rules questions to the chat API.

    Attributes:
        api_base_url: Base URL of the FastAPI service (e.g.
            ``http://localhost:8000``). ``None`` until ``run_bot`` sets it.
        session: Shared aiohttp session; created in ``setup_hook`` and closed
            in ``close``.
    """

    def __init__(self):
        """Initialize the bot with default intents plus message content."""
        intents = discord.Intents.default()
        # Message content is needed to read the text of follow-up replies.
        intents.message_content = True
        super().__init__(command_prefix="!", intents=intents)

        self.api_base_url: Optional[str] = None
        self.session: Optional[aiohttp.ClientSession] = None

    async def setup_hook(self):
        """Create the HTTP session and sync slash commands with Discord.

        When ``settings.discord_guild_id`` is set, global commands are copied to
        that guild and synced there; otherwise they are synced globally.
        """
        self.session = aiohttp.ClientSession()

        if settings.discord_guild_id:
            guild = discord.Object(id=int(settings.discord_guild_id))
            self.tree.copy_global_to(guild=guild)
            await self.tree.sync(guild=guild)
            logger.info("Slash commands synced to guild %s", settings.discord_guild_id)
        else:
            await self.tree.sync()
            logger.info("Slash commands synced globally")

    async def close(self):
        """Close the HTTP session (if one was created), then shut down the bot."""
        if self.session:
            await self.session.close()
        await super().close()

    async def query_chat_api(
        self,
        message: str,
        user_id: str,
        channel_id: str,
        conversation_id: Optional[str] = None,
        parent_message_id: Optional[str] = None,
    ) -> dict:
        """POST a question to the API's ``/chat`` endpoint and return its JSON body.

        Args:
            message: The user's question text.
            user_id: Discord user ID of the asker, as a string.
            channel_id: Discord channel ID the question came from, as a string.
            conversation_id: Existing conversation to continue, or ``None`` to
                start a new one.
            parent_message_id: ID of the message being replied to, if any.

        Returns:
            The decoded JSON response.

        Raises:
            RuntimeError: If the HTTP session or base URL is not set up, or the
                API answers with a non-200 status. The message carries only the
                status code; the response body is logged, never surfaced.
        """
        if not self.session:
            raise RuntimeError("Bot HTTP session not initialized")
        if not self.api_base_url:
            # Without this, the request would go to the literal URL "None/chat".
            raise RuntimeError("Chat API base URL not configured")

        payload = {
            "message": message,
            "user_id": user_id,
            "channel_id": channel_id,
            "conversation_id": conversation_id,
            "parent_message_id": parent_message_id,
        }

        # The API validates an X-API-Secret shared-secret header on /chat and
        # skips the check when its secret is empty, so only send it when one is
        # configured. ASSUMPTION: the secret is exposed as
        # ``settings.api_secret`` (env API_SECRET) — confirm against config.
        headers = {}
        api_secret = getattr(settings, "api_secret", "") or ""
        if api_secret:
            headers["X-API-Secret"] = api_secret

        # Strip a trailing slash so "http://host:8000/" doesn't yield "//chat".
        url = f"{self.api_base_url.rstrip('/')}/chat"

        async with self.session.post(
            url,
            json=payload,
            headers=headers,
            timeout=aiohttp.ClientTimeout(total=120),
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(
                    "API returned %s for %s %s — body: %s",
                    response.status,
                    response.method,
                    response.url,
                    error_text,
                )
                raise RuntimeError(f"API error {response.status}")
            return await response.json()
|
|
|
|
|
|
# Module-level bot instance: the event handlers and the /ask command below
# register on it, and run_bot() sets its api_base_url and starts it.
bot = StratChatbotBot()
|
|
|
|
|
|
@bot.event
async def on_ready():
    """Log the bot's identity once its Discord connection is ready.

    Does nothing if ``bot.user`` is not yet populated.
    """
    me = bot.user
    if me:
        logger.info("Bot logged in as %s (ID: %s)", me, me.id)
        logger.info("Ready to answer Strat-O-Matic rules questions!")
|
|
|
|
|
|
@bot.tree.command(
    name="ask", description="Ask a question about Strat-O-Matic league rules"
)
@app_commands.describe(
    question="Your rules question (e.g., 'Can a runner steal on a 2-2 count?')"
)
async def ask_command(interaction: discord.Interaction, question: str):
    """Handle the ``/ask`` slash command by starting a new conversation.

    Defers the interaction publicly, queries the chat API, and replies with an
    embed containing the answer, any cited rules, a low-confidence warning when
    confidence is below 0.4, and a footer carrying the full conversation ID so
    that replies to the answer can continue the same conversation.

    Errors are logged with a traceback; the user sees only a generic message.
    """
    await interaction.response.defer(ephemeral=False)

    try:
        result = await bot.query_chat_api(
            message=question,
            user_id=str(interaction.user.id),
            channel_id=str(interaction.channel_id),
            conversation_id=None,  # New conversation
            parent_message_id=None,
        )

        embed = discord.Embed(
            title="Rules Answer",
            description=result["response"][:4000],  # Discord limit
            color=discord.Color.blue(),
        )

        cited_rules = result.get("cited_rules")
        if cited_rules:
            # Embed field values are capped at 1024 characters; a long citation
            # list would otherwise make the send fail and hide a valid answer.
            cited_text = ", ".join(f"`{rid}`" for rid in cited_rules)
            if len(cited_text) > 1024:
                cited_text = cited_text[:1021] + "..."
            embed.add_field(
                name="📋 Cited Rules",
                value=cited_text,
                inline=False,
            )

        # ``or 0.0`` also covers an explicit null confidence from the API, which
        # would otherwise raise TypeError on the comparison below.
        confidence = result.get("confidence") or 0.0
        if confidence < 0.4:
            embed.add_field(
                name="⚠️ Confidence",
                value=f"Low ({confidence:.0%}) - A human review has been requested",
                inline=False,
            )

        # Add conversation ID for follow-ups (full UUID so replies can be threaded)
        embed.set_footer(
            text=f"conv:{result['conversation_id']} | Reply to ask a follow-up"
        )

        await interaction.followup.send(embed=embed)

    except Exception as e:
        # Top-level boundary: log full details server-side, show a generic message.
        logger.exception(
            "Error handling /ask from user %s: %s",
            interaction.user.id,
            e,
        )
        await interaction.followup.send(
            embed=discord.Embed(
                title="❌ Error",
                description="Something went wrong while fetching your answer. Please try again later.",
                color=discord.Color.red(),
            )
        )
|
|
|
|
|
|
@bot.event
async def on_message(message: discord.Message):
    """Handle follow-up questions sent as replies to the bot's answer embeds.

    A reply counts as a follow-up when it references a message authored by this
    bot whose first embed footer has the form ``conv:<id> | ...``. That footer
    is written by ``/ask`` and by this handler. The reply text is sent to the
    chat API with that conversation ID and the referenced message's ID as
    parent. A placeholder reply is then edited in place with either the answer
    embed or a generic error embed. Error details are logged and never shown
    to the user.

    NOTE(review): per the discord.py FAQ, registering on_message this way stops
    prefix commands from being processed unless bot.process_commands() is
    called; none are defined in this file.
    """
    # Ignore messages from bots (including this one) to avoid reply loops.
    if message.author.bot:
        return

    # Only replies can be follow-ups.
    reference = message.reference
    if reference is None or reference.message_id is None:
        return

    # The referenced message may have been deleted, live in another channel
    # (e.g. a forwarded message), or be unreadable. fetch_message raises in all
    # of those cases, and then there is nothing to follow up on.
    try:
        referenced = await message.channel.fetch_message(reference.message_id)
    except discord.HTTPException:
        logger.warning(
            "Could not fetch referenced message %s in channel %s",
            reference.message_id,
            message.channel.id,
        )
        return

    # Only continue conversations started by this bot.
    if referenced.author != bot.user:
        return

    # The conversation ID lives in the footer of the answer embed.
    embed = referenced.embeds[0] if referenced.embeds else None
    if not embed or not embed.footer:
        await message.reply(
            "❓ I couldn't find this conversation. Please use `/ask` to start a new question.",
            mention_author=True,
        )
        return

    footer_text = embed.footer.text or ""
    if "conv:" not in footer_text:
        await message.reply(
            "❓ Could not determine conversation. Use `/ask` to start fresh.",
            mention_author=True,
        )
        return

    # Footer format: "conv:<uuid> | Reply to ask a follow-up". Take the first
    # token between "conv:" and "|". An empty result means a malformed footer,
    # which previously slipped through as an empty conversation ID.
    id_tokens = footer_text.partition("conv:")[2].partition("|")[0].split()
    conversation_id = id_tokens[0] if id_tokens else ""
    if not conversation_id:
        await message.reply(
            "❓ Could not parse conversation ID. Use `/ask` to start fresh.",
            mention_author=True,
        )
        return

    # Replies without text (e.g. attachment-only) carry no question to send.
    if not message.content.strip():
        await message.reply(
            "❓ Please type your follow-up question as text.",
            mention_author=True,
        )
        return

    # The answer being replied to becomes the parent of this follow-up.
    parent_message_id = str(referenced.id)

    # Send a loading placeholder and replace it with the real answer when ready.
    loading_msg = await message.reply(
        "🔍 Looking into that follow-up...", mention_author=True
    )

    try:
        result = await bot.query_chat_api(
            message=message.content,
            user_id=str(message.author.id),
            channel_id=str(message.channel.id),
            conversation_id=conversation_id,
            parent_message_id=parent_message_id,
        )

        response_embed = discord.Embed(
            title="Follow-up Answer",
            description=result["response"][:4000],
            color=discord.Color.green(),
        )

        cited_rules = result.get("cited_rules")
        if cited_rules:
            # Embed field values are capped at 1024 characters.
            cited_text = ", ".join(f"`{rid}`" for rid in cited_rules)
            if len(cited_text) > 1024:
                cited_text = cited_text[:1021] + "..."
            response_embed.add_field(
                name="📋 Cited Rules",
                value=cited_text,
                inline=False,
            )

        # ``or 0.0`` also covers an explicit null confidence from the API.
        if (result.get("confidence") or 0.0) < 0.4:
            response_embed.add_field(
                name="⚠️ Confidence",
                value="Low - Human review requested",
                inline=False,
            )

        # Carry the conversation ID forward so further replies stay in the same thread
        response_embed.set_footer(
            text=f"conv:{result['conversation_id']} | Reply to ask a follow-up"
        )

        await loading_msg.edit(content=None, embed=response_embed)

    except Exception as e:
        # Top-level boundary: log full details server-side, show a generic message.
        logger.exception(
            "Error handling follow-up from user %s in channel %s: %s",
            message.author.id,
            message.channel.id,
            e,
        )
        await loading_msg.edit(
            content=None,
            embed=discord.Embed(
                title="❌ Error",
                description="Something went wrong while processing your follow-up. Please try again later.",
                color=discord.Color.red(),
            ),
        )
|
|
|
|
|
|
def run_bot(api_base_url: str = "http://localhost:8000"):
    """Configure the module-level bot and run it; blocks until the bot stops.

    Args:
        api_base_url: Base URL of the chat API. A trailing slash is ignored.

    Exits the process with status 1 if ``settings.discord_bot_token`` is empty.
    """
    import sys

    # Strip a trailing slash so "http://host:8000/" doesn't produce "//chat".
    bot.api_base_url = api_base_url.rstrip("/")

    if not settings.discord_bot_token:
        logger.critical("DISCORD_BOT_TOKEN environment variable is required")
        # The builtin exit() is a site-module convenience that may be absent
        # (python -S, frozen/embedded interpreters); sys.exit always exists.
        sys.exit(1)

    bot.run(settings.discord_bot_token)
|