From c42fea66bac2c061a0750682686b575968c817fb Mon Sep 17 00:00:00 2001 From: Cal Corum Date: Sun, 8 Mar 2026 15:19:26 -0500 Subject: [PATCH] feat: initial chatbot implementation with FastAPI, ChromaDB, Discord bot, and Gitea integration - Add vector store with sentence-transformers for semantic search - FastAPI backend with /chat and /health endpoints - Conversation state persistence via SQLite - OpenRouter integration with structured JSON responses - Discord bot with /ask slash command and reply-based follow-ups - Automated Gitea issue creation for unanswered questions - Docker support with docker-compose for easy deployment - Example rule file and ingestion script - Comprehensive documentation in README --- .env.example | 22 ++++ .gitignore | 43 +++++++ Dockerfile | 48 ++++++++ README.md | 229 +++++++++++++++++++++++++++++++++++ app/__init__.py | 1 + app/config.py | 53 ++++++++ app/database.py | 161 +++++++++++++++++++++++++ app/discord_bot.py | 240 +++++++++++++++++++++++++++++++++++++ app/gitea.py | 108 +++++++++++++++++ app/llm.py | 179 +++++++++++++++++++++++++++ app/main.py | 198 ++++++++++++++++++++++++++++++ app/models.py | 100 ++++++++++++++++ app/vector_store.py | 166 +++++++++++++++++++++++++ data/rules/example_rule.md | 20 ++++ docker-compose.yml | 87 ++++++++++++++ pyproject.toml | 44 +++++++ scripts/ingest_rules.py | 144 ++++++++++++++++++++++ setup.sh | 58 +++++++++ tests/test_basic.py | 63 ++++++++++ 19 files changed, 1964 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 app/__init__.py create mode 100644 app/config.py create mode 100644 app/database.py create mode 100644 app/discord_bot.py create mode 100644 app/gitea.py create mode 100644 app/llm.py create mode 100644 app/main.py create mode 100644 app/models.py create mode 100644 app/vector_store.py create mode 100644 data/rules/example_rule.md create mode 100644 docker-compose.yml create 
mode 100644 pyproject.toml create mode 100644 scripts/ingest_rules.py create mode 100755 setup.sh create mode 100644 tests/test_basic.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..741da48 --- /dev/null +++ b/.env.example @@ -0,0 +1,22 @@ +# OpenRouter Configuration +OPENROUTER_API_KEY=your_openrouter_api_key_here +OPENROUTER_MODEL=stepfun/step-3.5-flash:free + +# Discord Bot Configuration +DISCORD_BOT_TOKEN=your_discord_bot_token_here +DISCORD_GUILD_ID=your_guild_id_here # Optional, speeds up slash command sync + +# Gitea Configuration (for issue creation) +GITEA_TOKEN=your_gitea_token_here +GITEA_OWNER=cal +GITEA_REPO=strat-chatbot +GITEA_BASE_URL=https://git.manticorum.com/api/v1 + +# Application Configuration +DATA_DIR=./data +RULES_DIR=./data/rules +CHROMA_DIR=./data/chroma +DB_URL=sqlite+aiosqlite:///./data/conversations.db +CONVERSATION_TTL=1800 +TOP_K_RULES=10 +EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5c11d88 --- /dev/null +++ b/.gitignore @@ -0,0 +1,43 @@ +# Python +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +env/ +venv/ +.venv/ +*.egg-info/ +dist/ +build/ +poetry.lock + +# Data files (except example rules) +data/chroma/ +data/conversations.db + +# Environment +.env +.env.local + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log + +# Temporary +tmp/ +temp/ +.mypy_cache/ +.pytest_cache/ + +# Docker +.dockerignore diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ee5d1af --- /dev/null +++ b/Dockerfile @@ -0,0 +1,48 @@ +# Multi-stage build for Strat-Chatbot +FROM python:3.12-slim AS builder + +WORKDIR /app + +# Install system dependencies for sentence-transformers (PyTorch, etc.) 
+RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + g++ \ + && rm -rf /var/lib/apt/lists/* + +# Install Poetry +RUN pip install --no-cache-dir poetry + +# Copy dependencies +COPY pyproject.toml ./ +COPY README.md ./ + +# Install dependencies +RUN poetry config virtualenvs.in-project true && \ + poetry install --no-interaction --no-ansi --only main + +# Final stage +FROM python:3.12-slim + +WORKDIR /app + +# Copy virtual environment from builder +COPY --from=builder /app/.venv .venv +ENV PATH="/app/.venv/bin:$PATH" + +# Create non-root user +RUN useradd --create-home --shell /bin/bash app && chown -R app:app /app +USER app + +# Copy application code +COPY --chown=app:app app/ ./app/ +COPY --chown=app:app data/ ./data/ +COPY --chown=app:app scripts/ ./scripts/ + +# Create data directories +RUN mkdir -p data/chroma data/rules + +# Expose ports +EXPOSE 8000 + +# Run FastAPI server +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..1e398cb --- /dev/null +++ b/README.md @@ -0,0 +1,229 @@ +# Strat-Chatbot + +AI-powered Q&A chatbot for Strat-O-Matic baseball league rules. 
+ +## Features + +- **Natural language Q&A**: Ask questions about league rules in plain English +- **Semantic search**: Uses ChromaDB vector embeddings to find relevant rules +- **Rule citations**: Always cites specific rule IDs (e.g., "Rule 5.2.1(b)") +- **Conversation threading**: Maintains conversation context for follow-up questions +- **Gitea integration**: Automatically creates issues for unanswered questions +- **Discord integration**: Slash command `/ask` with reply-based follow-ups + +## Architecture + +``` +┌─────────┐ ┌──────────────┐ ┌─────────────┐ +│ Discord │────│ FastAPI │────│ ChromaDB │ +│ Bot │ │ (port 8000) │ │ (vectors) │ +└─────────┘ └──────────────┘ └─────────────┘ + │ + ┌───────▼──────┐ + │ Markdown │ + │ Rule Files │ + └──────────────┘ + │ + ┌───────▼──────┐ + │ OpenRouter │ + │ (LLM API) │ + └──────────────┘ + │ + ┌───────▼──────┐ + │ Gitea │ + │ Issues │ + └──────────────┘ +``` + +## Quick Start + +### Prerequisites + +- Docker & Docker Compose +- OpenRouter API key +- Discord bot token +- Gitea token (optional, for issue creation) + +### Setup + +1. **Clone and configure** + +```bash +cd strat-chatbot +cp .env.example .env +# Edit .env with your API keys and tokens +``` + +2. **Prepare rules** + +Place your rule documents in `data/rules/` as Markdown files with YAML frontmatter: + +```markdown +--- +rule_id: "5.2.1(b)" +title: "Stolen Base Attempts" +section: "Baserunning" +parent_rule: "5.2" +page_ref: "32" +--- + +When a runner attempts to steal... +``` + +3. **Ingest rules** + +```bash +# With Docker Compose (recommended) +docker compose up -d +docker compose exec api python scripts/ingest_rules.py + +# Or locally +uv sync +uv run scripts/ingest_rules.py +``` + +4. **Start services** + +```bash +docker compose up -d +``` + +The API will be available at http://localhost:8000 + +The Discord bot will connect and sync slash commands. + +### Runtime Configuration + +| Environment Variable | Required? 
| Description | +|---------------------|-----------|-------------| +| `OPENROUTER_API_KEY` | Yes | OpenRouter API key | +| `OPENROUTER_MODEL` | No | Model ID (default: `stepfun/step-3.5-flash:free`) | +| `DISCORD_BOT_TOKEN` | No | Discord bot token (omit to run API only) | +| `DISCORD_GUILD_ID` | No | Guild ID for slash command sync (faster than global) | +| `GITEA_TOKEN` | No | Gitea API token (for issue creation) | +| `GITEA_OWNER` | No | Gitea username (default: `cal`) | +| `GITEA_REPO` | No | Repository name (default: `strat-chatbot`) | + +## API Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/health` | GET | Health check with stats | +| `/chat` | POST | Send a question and get a response | +| `/stats` | GET | Knowledge base and system statistics | + +### Chat Request + +```json +{ + "message": "Can a runner steal on a 2-2 count?", + "user_id": "123456789", + "channel_id": "987654321", + "conversation_id": "optional-uuid", + "parent_message_id": "optional-parent-uuid" +} +``` + +### Chat Response + +```json +{ + "response": "Yes, according to Rule 5.2.1(b)...", + "conversation_id": "conv-uuid", + "message_id": "msg-uuid", + "cited_rules": ["5.2.1(b)", "5.3"], + "confidence": 0.85, + "needs_human": false +} +``` + +## Development + +### Local Development (without Docker) + +```bash +# Install dependencies +uv sync + +# Ingest rules +uv run scripts/ingest_rules.py + +# Run API server +uv run app/main.py + +# In another terminal, run Discord bot +uv run app/discord_bot.py +``` + +### Project Structure + +``` +strat-chatbot/ +├── app/ +│ ├── __init__.py +│ ├── config.py # Configuration management +│ ├── database.py # SQLAlchemy conversation state +│ ├── gitea.py # Gitea API client +│ ├── llm.py # OpenRouter integration +│ ├── main.py # FastAPI app +│ ├── models.py # Pydantic models +│ ├── vector_store.py # ChromaDB wrapper +│ └── discord_bot.py # Discord bot +├── data/ +│ ├── chroma/ # Vector DB (auto-created) +│ └── 
rules/ # Your markdown rule files +├── scripts/ +│ └── ingest_rules.py # Ingestion pipeline +├── tests/ # Test files +├── .env.example +├── Dockerfile +├── docker-compose.yml +└── pyproject.toml +``` + +## Performance Optimizations + +- **Embedding cache**: ChromaDB persists embeddings on disk +- **Rule chunking**: Each rule is a separate document, no context fragmentation +- **Top-k search**: Configurable number of rules to retrieve (default: 10) +- **Conversation TTL**: 30 minutes to limit database size +- **Async operations**: All I/O is non-blocking + +## Testing the API + +```bash +curl -X POST http://localhost:8000/chat \ + -H "Content-Type: application/json" \ + -d '{ + "message": "What happens if the pitcher balks?", + "user_id": "test123", + "channel_id": "general" + }' +``` + +## Gitea Integration + +When the bot encounters a question it can't answer confidently (confidence < 0.4), it will automatically: + +1. Log the question to console +2. Create an issue in your configured Gitea repo +3. 
Include: user ID, channel, question, attempted rules, conversation link + +Issues are labeled with: +- `rules-gap` - needs a rule addition or clarification +- `ai-generated` - created by AI bot +- `needs-review` - requires human administrator attention + +## To-Do + +- [ ] Build OpenRouter Docker client with proper torch dependencies +- [ ] Add PDF ingestion support (convert PDF → Markdown) +- [ ] Implement rule change detection and incremental updates +- [ ] Add rate limiting per Discord user +- [ ] Create admin endpoints for rule management +- [ ] Add Prometheus metrics for monitoring +- [ ] Build unit and integration tests + +## License + +TBD diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..aacf55a --- /dev/null +++ b/app/__init__.py @@ -0,0 +1 @@ +"""Strat-Chatbot application package.""" diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..203a064 --- /dev/null +++ b/app/config.py @@ -0,0 +1,53 @@ +"""Configuration management using Pydantic Settings.""" + +from pathlib import Path +from pydantic_settings import BaseSettings +from pydantic import Field + + +class Settings(BaseSettings): + """Application settings with environment variable overrides.""" + + # OpenRouter + openrouter_api_key: str = Field(default="", env="OPENROUTER_API_KEY") + openrouter_model: str = Field( + default="stepfun/step-3.5-flash:free", env="OPENROUTER_MODEL" + ) + + # Discord + discord_bot_token: str = Field(default="", env="DISCORD_BOT_TOKEN") + discord_guild_id: str | None = Field(default=None, env="DISCORD_GUILD_ID") + + # Gitea + gitea_token: str = Field(default="", env="GITEA_TOKEN") + gitea_owner: str = Field(default="cal", env="GITEA_OWNER") + gitea_repo: str = Field(default="strat-chatbot", env="GITEA_REPO") + gitea_base_url: str = Field( + default="https://git.manticorum.com/api/v1", env="GITEA_BASE_URL" + ) + + # Paths + data_dir: Path = Field(default=Path("./data"), env="DATA_DIR") + rules_dir: Path = 
Field(default=Path("./data/rules"), env="RULES_DIR") + chroma_dir: Path = Field(default=Path("./data/chroma"), env="CHROMA_DIR") + + # Database + db_url: str = Field( + default="sqlite+aiosqlite:///./data/conversations.db", env="DB_URL" + ) + + # Conversation state TTL (seconds) + conversation_ttl: int = Field(default=1800, env="CONVERSATION_TTL") + + # Vector search + top_k_rules: int = Field(default=10, env="TOP_K_RULES") + embedding_model: str = Field( + default="sentence-transformers/all-MiniLM-L6-v2", env="EMBEDDING_MODEL" + ) + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + + +settings = Settings() diff --git a/app/database.py b/app/database.py new file mode 100644 index 0000000..e937541 --- /dev/null +++ b/app/database.py @@ -0,0 +1,161 @@ +"""SQLAlchemy-based conversation state management with aiosqlite.""" + +from datetime import datetime, timedelta +from typing import Optional +import uuid +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy import Column, String, DateTime, Boolean, ForeignKey, select +from .config import settings + +Base = declarative_base() + + +class ConversationTable(Base): + """SQLAlchemy model for conversations.""" + + __tablename__ = "conversations" + + id = Column(String, primary_key=True) + user_id = Column(String, nullable=False) + channel_id = Column(String, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + last_activity = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + +class MessageTable(Base): + """SQLAlchemy model for messages.""" + + __tablename__ = "messages" + + id = Column(String, primary_key=True) + conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False) + content = Column(String, nullable=False) + is_user = Column(Boolean, nullable=False) + parent_id = Column(String, ForeignKey("messages.id"), nullable=True) 
+ created_at = Column(DateTime, default=datetime.utcnow) + + +class ConversationManager: + """Manages conversation state in SQLite.""" + + def __init__(self, db_url: str): + """Initialize database engine and session factory.""" + self.engine = create_async_engine(db_url, echo=False) + self.async_session = sessionmaker( + self.engine, class_=AsyncSession, expire_on_commit=False + ) + + async def init_db(self): + """Create tables if they don't exist.""" + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async def get_or_create_conversation( + self, user_id: str, channel_id: str, conversation_id: Optional[str] = None + ) -> str: + """Get existing conversation or create a new one.""" + async with self.async_session() as session: + if conversation_id: + result = await session.execute( + select(ConversationTable).where( + ConversationTable.id == conversation_id + ) + ) + conv = result.scalar_one_or_none() + if conv: + conv.last_activity = datetime.utcnow() + await session.commit() + return conv.id + + # Create new conversation + new_id = str(uuid.uuid4()) + conv = ConversationTable(id=new_id, user_id=user_id, channel_id=channel_id) + session.add(conv) + await session.commit() + return new_id + + async def add_message( + self, + conversation_id: str, + content: str, + is_user: bool, + parent_id: Optional[str] = None, + ) -> str: + """Add a message to a conversation.""" + message_id = str(uuid.uuid4()) + async with self.async_session() as session: + msg = MessageTable( + id=message_id, + conversation_id=conversation_id, + content=content, + is_user=is_user, + parent_id=parent_id, + ) + session.add(msg) + + # Update conversation activity + result = await session.execute( + select(ConversationTable).where(ConversationTable.id == conversation_id) + ) + conv = result.scalar_one_or_none() + if conv: + conv.last_activity = datetime.utcnow() + + await session.commit() + return message_id + + async def get_conversation_history( + self, 
conversation_id: str, limit: int = 10 + ) -> list[dict]: + """Get recent messages from a conversation in OpenAI format.""" + async with self.async_session() as session: + result = await session.execute( + select(MessageTable) + .where(MessageTable.conversation_id == conversation_id) + .order_by(MessageTable.created_at.desc()) + .limit(limit) + ) + messages = result.scalars().all() + # Reverse to get chronological order and convert to API format + history = [] + for msg in reversed(messages): + role = "user" if msg.is_user else "assistant" + history.append({"role": role, "content": msg.content}) + + return history + + async def cleanup_old_conversations(self, ttl_seconds: int = 1800): + """Delete conversations older than TTL to free up storage.""" + cutoff = datetime.utcnow() - timedelta(seconds=ttl_seconds) + async with self.async_session() as session: + # Find old conversations + result = await session.execute( + select(ConversationTable).where( + ConversationTable.last_activity < cutoff + ) + ) + old_convs = result.scalars().all() + + conv_ids = [conv.id for conv in old_convs] + if conv_ids: + # Delete messages first (cascade would handle but explicit is clear) + await session.execute( + sa.delete(MessageTable).where( + MessageTable.conversation_id.in_(conv_ids) + ) + ) + # Delete conversations + await session.execute( + sa.delete(ConversationTable).where( + ConversationTable.id.in_(conv_ids) + ) + ) + await session.commit() + print(f"Cleaned up {len(conv_ids)} old conversations") + + +async def get_conversation_manager() -> ConversationManager: + """Dependency for FastAPI to get a ConversationManager instance.""" + return ConversationManager(settings.db_url) diff --git a/app/discord_bot.py b/app/discord_bot.py new file mode 100644 index 0000000..0504c08 --- /dev/null +++ b/app/discord_bot.py @@ -0,0 +1,240 @@ +"""Discord bot for Strat-O-Matic rules Q&A.""" + +import discord +from discord import app_commands +from discord.ext import commands +import asyncio 
+import aiohttp +from typing import Optional +import uuid + +from .config import settings + + +class StratChatbotBot(commands.Bot): + """Discord bot for the rules chatbot.""" + + def __init__(self): + """Initialize the bot with default intents.""" + intents = discord.Intents.default() + intents.message_content = True + super().__init__(command_prefix="!", intents=intents) + + self.api_base_url: Optional[str] = None + self.session: Optional[aiohttp.ClientSession] = None + + async def setup_hook(self): + """Set up the bot's HTTP session and sync commands.""" + self.session = aiohttp.ClientSession() + # Sync slash commands with Discord + if settings.discord_guild_id: + guild = discord.Object(id=int(settings.discord_guild_id)) + self.tree.copy_global_to(guild=guild) + await self.tree.sync(guild=guild) + print(f"Slash commands synced to guild {settings.discord_guild_id}") + else: + await self.tree.sync() + print("Slash commands synced globally") + + async def close(self): + """Cleanup on shutdown.""" + if self.session: + await self.session.close() + await super().close() + + async def query_chat_api( + self, + message: str, + user_id: str, + channel_id: str, + conversation_id: Optional[str] = None, + parent_message_id: Optional[str] = None, + ) -> dict: + """Send a request to the FastAPI chat endpoint.""" + payload = { + "message": message, + "user_id": user_id, + "channel_id": channel_id, + "conversation_id": conversation_id, + "parent_message_id": parent_message_id, + } + + async with self.session.post( + f"{self.api_base_url}/chat", + json=payload, + timeout=aiohttp.ClientTimeout(total=120), + ) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"API error {response.status}: {error_text}") + return await response.json() + + +bot = StratChatbotBot() + + +@bot.event +async def on_ready(): + """Called when the bot is ready.""" + print(f"🤖 Bot logged in as {bot.user} (ID: {bot.user.id})") + print("Ready to answer 
Strat-O-Matic rules questions!") + + +@bot.tree.command( + name="ask", description="Ask a question about Strat-O-Matic league rules" +) +@app_commands.describe( + question="Your rules question (e.g., 'Can a runner steal on a 2-2 count?')" +) +async def ask_command(interaction: discord.Interaction, question: str): + """Handle /ask command.""" + await interaction.response.defer(ephemeral=False) + + try: + result = await bot.query_chat_api( + message=question, + user_id=str(interaction.user.id), + channel_id=str(interaction.channel_id), + conversation_id=None, # New conversation + parent_message_id=None, + ) + + # Build response embed + embed = discord.Embed( + title="Rules Answer", + description=result["response"][:4000], # Discord limit + color=discord.Color.blue(), + ) + + # Add cited rules if any + if result.get("cited_rules"): + embed.add_field( + name="📋 Cited Rules", + value=", ".join([f"`{rid}`" for rid in result["cited_rules"]]), + inline=False, + ) + + # Add confidence indicator + confidence = result.get("confidence", 0.0) + if confidence < 0.4: + embed.add_field( + name="⚠️ Confidence", + value=f"Low ({confidence:.0%}) - A human review has been requested", + inline=False, + ) + + # Add conversation ID for follow-ups + embed.set_footer( + text=f"Conversation: {result['conversation_id'][:8]}... 
" + f"| Reply to ask a follow-up" + ) + + await interaction.followup.send(embed=embed) + + except Exception as e: + await interaction.followup.send( + embed=discord.Embed( + title="❌ Error", + description=f"Failed to get answer: {str(e)}", + color=discord.Color.red(), + ) + ) + + +@bot.event +async def on_message(message: discord.Message): + """Handle follow-up messages via reply.""" + # Ignore bot messages + if message.author.bot: + return + + # Only handle replies to the bot's messages + if not message.reference: + return + + referenced = await message.channel.fetch_message(message.reference.message_id) + + # Check if the referenced message was from this bot + if referenced.author != bot.user: + return + + # Try to extract conversation ID from the footer + embed = referenced.embeds[0] if referenced.embeds else None + if not embed or not embed.footer: + await message.reply( + "❓ I couldn't find this conversation. Please use `/ask` to start a new question.", + mention_author=True, + ) + return + + footer_text = embed.footer.text or "" + if "Conversation:" not in footer_text: + await message.reply( + "❓ Could not determine conversation. 
Use `/ask` to start fresh.", + mention_author=True, + ) + return + + # Extract conversation ID (rough parsing) + try: + conv_id = footer_text.split("Conversation:")[1].split("...")[0].strip() + conversation_id = footer_text.split("Conversation:")[1].split(" ")[0].strip() + except: + conversation_id = None + + # Get parent message ID (the original answer message) + parent_message_id = str(referenced.id) + + # Forward to API + await message.reply("🔍 Looking into that follow-up...", mention_author=True) + + try: + result = await bot.query_chat_api( + message=message.content, + user_id=str(message.author.id), + channel_id=str(message.channel_id), + conversation_id=conversation_id, + parent_message_id=parent_message_id, + ) + + response_embed = discord.Embed( + title="Follow-up Answer", + description=result["response"][:4000], + color=discord.Color.green(), + ) + + if result.get("cited_rules"): + response_embed.add_field( + name="📋 Cited Rules", + value=", ".join([f"`{rid}`" for rid in result["cited_rules"]]), + inline=False, + ) + + if result.get("confidence", 0.0) < 0.4: + response_embed.add_field( + name="⚠️ Confidence", + value=f"Low - Human review requested", + inline=False, + ) + + await message.reply(embed=response_embed, mention_author=True) + + except Exception as e: + await message.reply( + embed=discord.Embed( + title="❌ Error", + description=f"Failed to process follow-up: {str(e)}", + color=discord.Color.red(), + ) + ) + + +def run_bot(api_base_url: str = "http://localhost:8000"): + """Entry point to run the Discord bot.""" + bot.api_base_url = api_base_url + + if not settings.discord_bot_token: + print("❌ DISCORD_BOT_TOKEN environment variable is required") + exit(1) + + bot.run(settings.discord_bot_token) diff --git a/app/gitea.py b/app/gitea.py new file mode 100644 index 0000000..9d72622 --- /dev/null +++ b/app/gitea.py @@ -0,0 +1,108 @@ +"""Gitea client for creating issues when questions need human review.""" + +import httpx +from typing import Optional 
+from .config import settings + + +class GiteaClient: + """Client for Gitea API interactions.""" + + def __init__(self): + """Initialize Gitea client with credentials.""" + self.token = settings.gitea_token + self.owner = settings.gitea_owner + self.repo = settings.gitea_repo + self.base_url = settings.gitea_base_url.rstrip("/") + self.headers = { + "Authorization": f"token {self.token}", + "Content-Type": "application/json", + "Accept": "application/json", + } + self._client: Optional[httpx.AsyncClient] = None + + async def __aenter__(self): + """Async context manager entry.""" + self._client = httpx.AsyncClient(timeout=30.0) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._client: + await self._client.aclose() + + async def create_issue( + self, + title: str, + body: str, + labels: Optional[list[str]] = None, + assignee: Optional[str] = None, + ) -> dict: + """Create a new issue in the configured repository.""" + if not self._client: + raise RuntimeError("GiteaClient must be used as async context manager") + + url = f"{self.base_url}/repos/{self.owner}/{self.repo}/issues" + + payload = {"title": title, "body": body} + + if labels: + payload["labels"] = labels + + if assignee: + payload["assignee"] = assignee + + response = await self._client.post(url, headers=self.headers, json=payload) + + if response.status_code not in (200, 201): + error_detail = response.text + raise RuntimeError( + f"Gitea API error creating issue: {response.status_code} - {error_detail}" + ) + + return response.json() + + async def create_unanswered_issue( + self, + question: str, + user_id: str, + channel_id: str, + attempted_rules: list[str], + conversation_id: str, + ) -> str: + """Create an issue for an unanswered question needing human review.""" + title = f"🤔 Unanswered rules question: {question[:80]}{'...' 
if len(question) > 80 else ''}" + + body = f"""## Unanswered Question + +**User:** {user_id} + +**Channel:** {channel_id} + +**Conversation ID:** {conversation_id} + +**Question:** +{question} + +**Searched Rules:** +{', '.join(attempted_rules) if attempted_rules else 'None'} + +**Additional Context:** +This question was asked in Discord and the bot could not provide a confident answer. The rules either don't cover this question or the information was ambiguous. + +--- + +*This issue was automatically created by the Strat-Chatbot.*""" + + labels = ["rules-gap", "ai-generated", "needs-review"] + + issue = await self.create_issue(title=title, body=body, labels=labels) + + return issue.get("html_url", "") + + +def get_gitea_client() -> Optional[GiteaClient]: + """Factory to get Gitea client if token is configured.""" + if settings.gitea_token: + return GiteaClient() + return None diff --git a/app/llm.py b/app/llm.py new file mode 100644 index 0000000..e804fc4 --- /dev/null +++ b/app/llm.py @@ -0,0 +1,179 @@ +"""OpenRouter LLM integration for answering rules questions.""" + +from typing import Optional +import json +import httpx +from .config import settings +from .models import RuleSearchResult, ChatResponse + +SYSTEM_PROMPT = """You are a helpful assistant for a Strat-O-Matic baseball league. +Your job is to answer questions about league rules and procedures using the provided rule excerpts. + +CRITICAL RULES: +1. ONLY use information from the provided rules. If the rules don't contain the answer, say so clearly. +2. ALWAYS cite rule IDs when referencing a rule (e.g., "Rule 5.2.1(b) states that...") +3. If multiple rules are relevant, cite all of them. +4. If you're uncertain or the rules are ambiguous, say so and suggest asking a league administrator. +5. Keep responses concise but complete. Use examples when helpful from the rules. +6. Do NOT make up rules or infer beyond what's explicitly stated. 
class OpenRouterClient:
    """Client for the OpenRouter chat-completions API.

    Builds a prompt from the retrieved rules and the user's question,
    calls the configured model, and parses the model's structured JSON
    reply into a ChatResponse.
    """

    def __init__(self):
        """Initialize the client from application settings.

        Raises:
            ValueError: If OPENROUTER_API_KEY is not configured.
        """
        self.api_key = settings.openrouter_api_key
        if not self.api_key:
            raise ValueError("OPENROUTER_API_KEY is required")
        self.model = settings.openrouter_model
        self.base_url = "https://openrouter.ai/api/v1/chat/completions"

    @staticmethod
    def _extract_json(content: str) -> str:
        """Strip an optional markdown code fence from an LLM reply.

        Models frequently wrap JSON in ```json ... ``` fences even when
        asked for raw JSON; return just the fenced body.
        """
        if "```json" in content:
            return content.split("```json")[1].split("```")[0].strip()
        if "```" in content:
            # Fix: also handle a bare ``` fence (no language tag), which
            # previously fell through to the raw-text fallback path.
            return content.split("```")[1].split("```")[0].strip()
        return content.strip()

    async def generate_response(
        self,
        question: str,
        rules: list[RuleSearchResult],
        conversation_history: Optional[list[dict]] = None,
    ) -> ChatResponse:
        """Generate a response using the LLM with retrieved rules as context.

        Args:
            question: The user's question.
            rules: Rules retrieved from the vector store.
            conversation_history: Prior turns as ``{"role", "content"}`` dicts.

        Returns:
            ChatResponse with empty conversation/message IDs; the caller
            fills those in after persisting the messages.

        Raises:
            RuntimeError: If the OpenRouter API returns a non-200 status.
        """
        # Present the retrieved rules to the model, or say none were found.
        if rules:
            rules_context = "\n\n".join(
                f"Rule {r.rule_id}: {r.title}\n{r.content}" for r in rules
            )
            context_msg = (
                f"Here are the relevant rules for the question:\n\n{rules_context}"
            )
        else:
            context_msg = "No relevant rules were found in the knowledge base."

        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        if conversation_history:
            # Keep only the last 3 exchanges (user + assistant) to avoid
            # overflowing the model's context window.
            messages.extend(conversation_history[-6:])

        user_message = f"{context_msg}\n\nUser question: {question}\n\nAnswer the question based on the rules provided."
        messages.append({"role": "user", "content": user_message})

        # Call the OpenRouter API (non-streaming; body is fully read here).
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                self.base_url,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": self.model,
                    "messages": messages,
                    "temperature": 0.3,
                    "max_tokens": 1000,
                    "top_p": 0.9,
                },
            )

        if response.status_code != 200:
            raise RuntimeError(
                f"OpenRouter API error: {response.status_code} - {response.text}"
            )

        result = response.json()
        content = result["choices"][0]["message"]["content"]

        try:
            parsed = json.loads(self._extract_json(content))

            cited_rules = parsed.get("cited_rules", [])
            if not cited_rules and rules:
                # Fallback: pull rule IDs such as "5.2.1(b)" out of the
                # answer text when the model omits the cited_rules field.
                import re

                # Fix: the original class [\d\.\(\)a-b] matched only the
                # letters "a" and "b"; a-z covers all lettered sub-rules.
                rule_ids = re.findall(
                    r"Rule\s+([\d.()a-z]+)", parsed.get("answer", "")
                )
                cited_rules = list(set(rule_ids))

            return ChatResponse(
                response=parsed["answer"],
                conversation_id="",  # Set by the caller after persisting.
                message_id="",  # Set by the caller after persisting.
                cited_rules=cited_rules,
                confidence=float(parsed.get("confidence", 0.5)),
                needs_human=bool(parsed.get("needs_human", False)),
            )
        except (json.JSONDecodeError, KeyError):
            # The model did not return valid JSON; surface the raw text
            # with neutral confidence rather than failing the request.
            return ChatResponse(
                response=content,
                conversation_id="",
                message_id="",
                cited_rules=[],
                confidence=0.5,
                needs_human=False,
            )


class MockLLMClient:
    """Mock LLM client for testing without API calls."""

    async def generate_response(
        self,
        question: str,
        rules: list[RuleSearchResult],
        conversation_history: Optional[list[dict]] = None,
    ) -> ChatResponse:
        """Return a canned response derived only from the retrieved rules."""
        if rules:
            rule_list = ", ".join([r.rule_id for r in rules])
            answer = f"Based on rule(s) {rule_list}, here's what you need to know..."
        else:
            answer = "I don't have a rule that addresses this question. You should ask a league administrator."

        return ChatResponse(
            response=answer,
            conversation_id="",
            message_id="",
            cited_rules=[r.rule_id for r in rules],
            confidence=1.0 if rules else 0.0,
            needs_human=not rules,
        )


def get_llm_client(use_mock: bool = False):
    """Return the mock client when requested or when no API key is set,
    otherwise a real OpenRouter client."""
    if use_mock or not settings.openrouter_api_key:
        return MockLLMClient()
    return OpenRouterClient()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan - startup and shutdown.

    Builds the vector store, conversation database, LLM client and
    (optionally) the Gitea client once, sharing them via ``app.state``.
    """
    print("Initializing Strat-Chatbot...")

    # Vector store holding the rule embeddings.
    chroma_dir = settings.data_dir / "chroma"
    vector_store = VectorStore(chroma_dir, settings.embedding_model)
    print(f"Vector store ready at {chroma_dir} ({vector_store.count()} rules loaded)")

    # Conversation persistence (SQLite via SQLAlchemy).
    db_manager = ConversationManager(settings.db_url)
    await db_manager.init_db()
    print("Database initialized")

    # Real client when an API key is configured, mock otherwise.
    llm_client = get_llm_client(use_mock=not settings.openrouter_api_key)
    print(f"LLM client ready (model: {settings.openrouter_model})")

    # Gitea issue creation is optional; enabled only when a token is set.
    gitea_client = GiteaClient() if settings.gitea_token else None

    app.state.vector_store = vector_store
    app.state.db_manager = db_manager
    app.state.llm_client = llm_client
    app.state.gitea_client = gitea_client

    print("Strat-Chatbot ready!")

    yield

    # Shutdown
    print("Shutting down...")


app = FastAPI(
    title="Strat-Chatbot",
    description="Strat-O-Matic rules Q&A API",
    version="0.1.0",
    lifespan=lifespan,
)


@app.get("/health")
async def health_check():
    """Health check endpoint reporting knowledge-base size."""
    vector_store: VectorStore = app.state.vector_store
    stats = vector_store.get_stats()
    return {
        "status": "healthy",
        "rules_count": stats["total_rules"],
        "sections": stats["sections"],
    }


@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    db_manager: ConversationManager = Depends(get_conversation_manager),
):
    """Handle a chat request from the Discord bot.

    Persists the user message, retrieves candidate rules, asks the LLM,
    persists the answer, and files a Gitea issue when the bot could not
    answer confidently.

    Raises:
        HTTPException: 500 when retrieval/LLM processing fails.
    """
    vector_store: VectorStore = app.state.vector_store
    llm_client = app.state.llm_client
    gitea_client = app.state.gitea_client

    # Bail out early with an explanatory message when no API key is set.
    if not settings.openrouter_api_key:
        return ChatResponse(
            response="⚠️ OpenRouter API key not configured. Set OPENROUTER_API_KEY environment variable.",
            conversation_id=request.conversation_id or str(uuid.uuid4()),
            message_id=str(uuid.uuid4()),
            cited_rules=[],
            confidence=0.0,
            needs_human=True,
        )

    # Get or create conversation
    conversation_id = await db_manager.get_or_create_conversation(
        user_id=request.user_id,
        channel_id=request.channel_id,
        conversation_id=request.conversation_id,
    )

    # Save user message before processing so it is kept even on failure.
    user_message_id = await db_manager.add_message(
        conversation_id=conversation_id,
        content=request.message,
        is_user=True,
        parent_id=request.parent_message_id,
    )

    try:
        # Semantic search for relevant rules.
        search_results = vector_store.search(
            query=request.message, top_k=settings.top_k_rules
        )

        # Prior turns give the LLM follow-up context.
        history = await db_manager.get_conversation_history(conversation_id, limit=10)

        response = await llm_client.generate_response(
            question=request.message,
            rules=search_results,
            conversation_history=history,
        )

        assistant_message_id = await db_manager.add_message(
            conversation_id=conversation_id,
            content=response.response,
            is_user=False,
            parent_id=user_message_id,
        )

        # Low confidence or an explicit flag escalates to a human via Gitea.
        if gitea_client and (response.needs_human or response.confidence < 0.4):
            try:
                issue_url = await gitea_client.create_unanswered_issue(
                    question=request.message,
                    user_id=request.user_id,
                    channel_id=request.channel_id,
                    attempted_rules=[r.rule_id for r in search_results],
                    conversation_id=conversation_id,
                )
                print(f"Created Gitea issue: {issue_url}")
            except Exception as e:
                # Issue creation is best-effort; never fail the chat for it.
                print(f"Failed to create Gitea issue: {e}")

        return ChatResponse(
            response=response.response,
            conversation_id=conversation_id,
            message_id=assistant_message_id,
            parent_message_id=user_message_id,
            cited_rules=response.cited_rules,
            confidence=response.confidence,
            needs_human=response.needs_human,
        )

    except Exception as e:
        print(f"Error processing chat request: {e}")
        # Fix: chain the original exception so the traceback is preserved
        # in logs instead of being replaced by the HTTPException's.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/stats")
async def stats():
    """Get statistics about the knowledge base and system."""
    vector_store: VectorStore = app.state.vector_store
    db_manager: ConversationManager = app.state.db_manager

    vs_stats = vector_store.get_stats()

    # Raw COUNT(*) queries keep this endpoint independent of ORM models.
    async with db_manager.async_session() as session:
        conv_count = await session.execute(
            sa.text("SELECT COUNT(*) FROM conversations")
        )
        msg_count = await session.execute(sa.text("SELECT COUNT(*) FROM messages"))

        total_conversations = conv_count.scalar() or 0
        total_messages = msg_count.scalar() or 0

    return {
        "knowledge_base": vs_stats,
        "conversations": {
            "total": total_conversations,
            "total_messages": total_messages,
        },
        "config": {
            "openrouter_model": settings.openrouter_model,
            "top_k_rules": settings.top_k_rules,
            "embedding_model": settings.embedding_model,
        },
    }


if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
'5.2.1(b)'") + title: str = Field(..., description="Rule title") + section: str = Field(..., description="Section/category name") + parent_rule: Optional[str] = Field( + None, description="Parent rule ID for hierarchical rules" + ) + last_updated: str = Field( + default_factory=lambda: datetime.now().strftime("%Y-%m-%d"), + description="Last update date", + ) + page_ref: Optional[str] = Field( + None, description="Reference to page number in rulebook" + ) + + +class RuleDocument(BaseModel): + """Complete rule document with metadata and content.""" + + metadata: RuleMetadata + content: str = Field(..., description="Rule text and examples") + source_file: str = Field(..., description="Source file path") + embedding: Optional[list[float]] = None + + def to_chroma_metadata(self) -> dict: + """Convert to ChromaDB metadata format.""" + return { + "rule_id": self.metadata.rule_id, + "title": self.metadata.title, + "section": self.metadata.section, + "parent_rule": self.metadata.parent_rule or "", + "page_ref": self.metadata.page_ref or "", + "last_updated": self.metadata.last_updated, + "source_file": self.source_file, + } + + +class Conversation(BaseModel): + """Conversation session.""" + + id: str + user_id: str # Discord user ID + channel_id: str # Discord channel ID + created_at: datetime = Field(default_factory=datetime.now) + last_activity: datetime = Field(default_factory=datetime.now) + + +class Message(BaseModel): + """Individual message in a conversation.""" + + id: str + conversation_id: str + content: str + is_user: bool + parent_id: Optional[str] = None + created_at: datetime = Field(default_factory=datetime.now) + + +class ChatRequest(BaseModel): + """Incoming chat request from Discord.""" + + message: str + conversation_id: Optional[str] = None + parent_message_id: Optional[str] = None + user_id: str + channel_id: str + + +class ChatResponse(BaseModel): + """Response to chat request.""" + + response: str + conversation_id: str + message_id: str + 
class ChatResponse(BaseModel):
    """Response to chat request."""

    response: str
    conversation_id: str
    message_id: str
    parent_message_id: Optional[str] = None
    cited_rules: list[str] = Field(default_factory=list)
    confidence: float = Field(..., ge=0.0, le=1.0)
    needs_human: bool = Field(
        default=False,
        description="Whether the question needs human review (unanswered)",
    )


class RuleSearchResult(BaseModel):
    """Result from vector search."""

    rule_id: str
    title: str
    content: str
    section: str
    # Cosine similarity in [0, 1]; producers must clamp before building.
    similarity: float = Field(..., ge=0.0, le=1.0)


class VectorStore:
    """Wrapper around ChromaDB for rule retrieval.

    Embeds rule text with a sentence-transformers model and stores the
    vectors in a persistent ChromaDB collection named "rules".
    """

    def __init__(self, persist_dir: Path, embedding_model: str):
        """Initialize vector store with embedding model.

        Args:
            persist_dir: Directory for ChromaDB persistence (created if
                missing).
            embedding_model: sentence-transformers model name.
        """
        self.persist_dir = Path(persist_dir)
        self.persist_dir.mkdir(parents=True, exist_ok=True)

        # Fix: "is_persist_directory_actually_writable" is not a chromadb
        # Settings field; the pydantic settings model rejects unknown
        # kwargs, so passing it crashed initialization.
        chroma_settings = ChromaSettings(anonymized_telemetry=False)

        self.client = chromadb.PersistentClient(
            path=str(self.persist_dir), settings=chroma_settings
        )

        self.embedding_model = SentenceTransformer(embedding_model)

    def get_collection(self):
        """Get or create the rules collection (cosine distance space)."""
        return self.client.get_or_create_collection(
            name="rules", metadata={"hnsw:space": "cosine"}
        )

    def add_document(self, doc: RuleDocument) -> None:
        """Add a single rule document to the vector store."""
        embedding = self.embedding_model.encode(doc.content).tolist()

        collection = self.get_collection()
        collection.add(
            ids=[doc.metadata.rule_id],
            embeddings=[embedding],
            documents=[doc.content],
            metadatas=[doc.to_chroma_metadata()],
        )

    def add_documents(self, docs: list[RuleDocument]) -> None:
        """Add multiple documents in one batched encode + insert."""
        if not docs:
            return

        ids = [doc.metadata.rule_id for doc in docs]
        contents = [doc.content for doc in docs]
        # Batch-encoding is much faster than one encode() call per doc.
        embeddings = self.embedding_model.encode(contents).tolist()
        metadatas = [doc.to_chroma_metadata() for doc in docs]

        collection = self.get_collection()
        collection.add(
            ids=ids, embeddings=embeddings, documents=contents, metadatas=metadatas
        )

    def search(
        self, query: str, top_k: int = 10, section_filter: Optional[str] = None
    ) -> list[RuleSearchResult]:
        """Search for relevant rules using semantic similarity.

        Args:
            query: Free-text question.
            top_k: Maximum number of results.
            section_filter: Optional exact-match filter on the section
                metadata field.

        Returns:
            Results ordered most-similar first.
        """
        query_embedding = self.embedding_model.encode(query).tolist()

        collection = self.get_collection()

        where = {"section": section_filter} if section_filter else None

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        search_results = []
        if results and results["documents"] and results["documents"][0]:
            for i in range(len(results["documents"][0])):
                metadata = results["metadatas"][0][i]
                distance = results["distances"][0][i]
                # Fix: cosine distance ranges over [0, 2], so 1 - distance
                # can be negative; clamp into [0, 1] so the ge=0.0 bound
                # on RuleSearchResult.similarity is never violated.
                similarity = max(0.0, min(1.0, 1 - distance))

                search_results.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=results["documents"][0][i],
                        section=metadata["section"],
                        similarity=similarity,
                    )
                )

        return search_results

    def delete_rule(self, rule_id: str) -> None:
        """Remove a rule by its ID."""
        collection = self.get_collection()
        collection.delete(ids=[rule_id])

    def clear_all(self) -> None:
        """Delete all rules from the collection."""
        self.client.delete_collection("rules")
        self.get_collection()  # Recreate empty collection

    def get_rule(self, rule_id: str) -> Optional[RuleSearchResult]:
        """Retrieve a specific rule by ID, or None if it does not exist."""
        collection = self.get_collection()
        result = collection.get(ids=[rule_id], include=["documents", "metadatas"])

        # Fix: unlike query(), collection.get() returns FLAT lists (one
        # entry per matched id), so index with [0], not [0][0] — the old
        # code took the first character of the document string and would
        # KeyError on the metadata dict.
        if result and result["documents"]:
            metadata = result["metadatas"][0]
            return RuleSearchResult(
                rule_id=metadata["rule_id"],
                title=metadata["title"],
                content=result["documents"][0],
                section=metadata["section"],
                similarity=1.0,
            )
        return None

    def list_all_rules(self) -> list[RuleSearchResult]:
        """Return all rules in the store (similarity fixed at 1.0)."""
        collection = self.get_collection()
        result = collection.get(include=["documents", "metadatas"])

        all_rules = []
        if result and result["documents"]:
            for i in range(len(result["documents"])):
                metadata = result["metadatas"][i]
                all_rules.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=result["documents"][i],
                        section=metadata["section"],
                        similarity=1.0,
                    )
                )

        return all_rules

    def count(self) -> int:
        """Return the number of rules in the store."""
        collection = self.get_collection()
        return collection.count()

    def get_stats(self) -> dict:
        """Get statistics about the vector store (totals per section)."""
        collection = self.get_collection()
        all_rules = self.list_all_rules()
        sections = {}
        for rule in all_rules:
            sections[rule.section] = sections.get(rule.section, 0) + 1

        return {
            "total_rules": len(all_rules),
            "sections": sections,
            "persist_directory": str(self.persist_dir),
        }
If the total equals or exceeds the CS number, the runner is successful.

**Example**: Runner with SB-2 rolls a 7. Total = 7 + 2 = 9. Catcher's CS is 11. 9 < 11, so the runner is caught stealing.

**Important notes**:
- Runners can only steal if they are on base and there are no outs.
- Do not attempt to steal when the pitcher has a **Pickoff** rating of 5 or higher.
- A failed steal results in an out and advances any other runners only if instructed by the result.
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..33d4a16
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,87 @@
version: '3.8'

services:
  chroma:
    image: chromadb/chroma:latest
    volumes:
      - ./data/chroma:/chroma/chroma_storage
    ports:
      - "8001:8000"
    environment:
      - CHROMA_SERVER_HOST=0.0.0.0
      - CHROMA_SERVER_PORT=8000
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/api/v1/heartbeat"]
      interval: 10s
      timeout: 5s
      retries: 5

  api:
    build:
      context: .
+ dockerfile: Dockerfile + volumes: + - ./data:/app/data + - ./app:/app/app + ports: + - "8000:8000" + environment: + - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-} + - OPENROUTER_MODEL=${OPENROUTER_MODEL:-stepfun/step-3.5-flash:free} + - GITEA_TOKEN=${GITEA_TOKEN:-} + - GITEA_OWNER=${GITEA_OWNER:-cal} + - GITEA_REPO=${GITEA_REPO:-strat-chatbot} + - DATA_DIR=/app/data + - RULES_DIR=/app/data/rules + - CHROMA_DIR=/app/data/chroma + - DB_URL=sqlite+aiosqlite:///./data/conversations.db + - CONVERSATION_TTL=1800 + - TOP_K_RULES=10 + - EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 + depends_on: + chroma: + condition: service_healthy + command: > + sh -c " + # Wait for database file creation on first run + sleep 2 && + # Initialize database if it doesn't exist + python -c 'import asyncio; from app.database import ConversationManager; mgr = ConversationManager(\"sqlite+aiosqlite:///./data/conversations.db\"); asyncio.run(mgr.init_db())' || true && + uvicorn app.main:app --host 0.0.0.0 --port 8000 + " + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 15s + timeout: 10s + retries: 3 + start_period: 30s + + discord-bot: + build: + context: . + dockerfile: Dockerfile + volumes: + - ./data:/app/data + - ./app:/app/app + environment: + - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-} + - OPENROUTER_MODEL=${OPENROUTER_MODEL:-stepfun/step-3.5-flash:free} + - DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-} + - DISCORD_GUILD_ID=${DISCORD_GUILD_ID:-} + - API_BASE_URL=http://api:8000 + depends_on: + api: + condition: service_healthy + # Override the default command to run the Discord bot + command: > + sh -c " + echo 'Waiting for API to be ready...' && + while ! curl -s http://api:8000/health > /dev/null; do sleep 2; done && + echo 'API ready, starting Discord bot...' 
&&
        python -m app.discord_bot
      "
    restart: unless-stopped

volumes:
  chroma_data:
  app_data:
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..d7ae7d7
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,44 @@
[project]
name = "strat-chatbot"
version = "0.1.0"
description = "Strat-O-Matic rules Q&A chatbot"
requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.115.0",
    "uvicorn[standard]>=0.30.0",
    "discord.py>=2.5.0",
    "chromadb>=0.5.0",
    "sentence-transformers>=3.0.0",
    "openai>=1.0.0",
    "python-dotenv>=1.0.0",
    "sqlalchemy>=2.0.0",
    "aiosqlite>=0.19.0",
    "pydantic>=2.0.0",
    "pydantic-settings>=2.0.0",
    "httpx>=0.27.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-asyncio>=0.23.0",
    "black>=24.0.0",
    "ruff>=0.5.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.black]
line-length = 88
target-version = ['py311']

[tool.ruff]
line-length = 88
select = ["E", "F", "B", "I"]
target-version = "py311"

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
diff --git a/scripts/ingest_rules.py b/scripts/ingest_rules.py
new file mode 100644
index 0000000..e6862e3
--- /dev/null
+++ b/scripts/ingest_rules.py
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
Ingest rule documents from markdown files into ChromaDB.

The script reads all markdown files from the rules directory and adds them
to the vector store. Each file should have YAML frontmatter with metadata
fields matching RuleMetadata.

Example frontmatter:
---
rule_id: "5.2.1(b)"
title: "Stolen Base Attempts"
section: "Baserunning"
parent_rule: "5.2"
page_ref: "32"
---

Rule content here...
def parse_frontmatter(content: str) -> tuple[dict, str]:
    """Parse YAML frontmatter from markdown content.

    Args:
        content: Full file text beginning with a ``---`` fenced block.

    Returns:
        Tuple of (frontmatter dict, body text with surrounding
        whitespace stripped).

    Raises:
        ValueError: If the content has no ``---`` fenced frontmatter.
    """
    pattern = r"^---\s*\n(.*?)\n---\s*\n(.*)$"
    match = re.match(pattern, content, re.DOTALL)

    if match:
        frontmatter_str = match.group(1)
        body_content = match.group(2).strip()
        # An empty frontmatter block parses to None; normalize to {}.
        metadata = yaml.safe_load(frontmatter_str) or {}
        return metadata, body_content
    else:
        raise ValueError("No valid YAML frontmatter found")


def load_markdown_file(filepath: Path) -> Optional[RuleDocument]:
    """Load a single markdown file and convert to RuleDocument.

    Returns None (and logs to stderr) on any parse/validation error so a
    bad file does not abort the whole ingestion run.
    """
    try:
        content = filepath.read_text(encoding="utf-8")
        metadata_dict, body = parse_frontmatter(content)

        # Validate frontmatter against the RuleMetadata schema.
        metadata = RuleMetadata(**metadata_dict)

        # Record the source path relative to the working directory.
        try:
            source_file = str(filepath.relative_to(Path.cwd()))
        except ValueError:
            # Fix: relative_to() raises when the file is not under cwd
            # (e.g. an absolute --rules-dir); previously this was
            # swallowed by the outer except and the file was skipped.
            source_file = str(filepath)

        return RuleDocument(metadata=metadata, content=body, source_file=source_file)
    except Exception as e:
        print(f"Error loading {filepath}: {e}", file=sys.stderr)
        return None


def ingest_rules(
    rules_dir: Path, vector_store: VectorStore, clear_existing: bool = False
) -> None:
    """Ingest all markdown rule files under *rules_dir* into the store.

    Args:
        rules_dir: Directory searched recursively for ``*.md`` files.
        vector_store: Destination store.
        clear_existing: When True, wipe the collection first.

    Exits the process with status 1 when the directory or files are
    missing (script-style error handling).
    """
    if not rules_dir.exists():
        print(f"Rules directory does not exist: {rules_dir}")
        sys.exit(1)

    if clear_existing:
        print("Clearing existing vector store...")
        vector_store.clear_all()

    md_files = list(rules_dir.rglob("*.md"))
    if not md_files:
        print(f"No markdown files found in {rules_dir}")
        sys.exit(1)

    print(f"Found {len(md_files)} markdown files to ingest")

    # Load and validate documents; bad files are reported and skipped.
    documents = []
    for filepath in md_files:
        doc = load_markdown_file(filepath)
        if doc:
            documents.append(doc)
            print(f"  Loaded: {doc.metadata.rule_id} - {doc.metadata.title}")

    print(f"Successfully loaded {len(documents)} documents")

    print("Adding to vector store (this may take a moment)...")
    vector_store.add_documents(documents)

    # Fix: plain string — the original used an f-string with no
    # placeholders.
    print("\nIngestion complete!")
    print(f"Total rules in store: {vector_store.count()}")
    stats = vector_store.get_stats()
    print("Sections:", ", ".join(f"{k}: {v}" for k, v in stats["sections"].items()))


def main():
    """Parse CLI arguments and run the ingestion."""
    import argparse

    parser = argparse.ArgumentParser(description="Ingest rule documents into ChromaDB")
    parser.add_argument(
        "--rules-dir",
        type=Path,
        default=settings.rules_dir,
        help="Directory containing markdown rule files",
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=settings.data_dir,
        help="Data directory (chroma will be stored in data/chroma)",
    )
    parser.add_argument(
        "--clear",
        action="store_true",
        help="Clear existing vector store before ingesting",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default=settings.embedding_model,
        help="Sentence transformer model name",
    )

    args = parser.parse_args()

    chroma_dir = args.data_dir / "chroma"
    print(f"Initializing vector store at: {chroma_dir}")
    print(f"Using embedding model: {args.embedding_model}")

    vector_store = VectorStore(chroma_dir, args.embedding_model)
    ingest_rules(args.rules_dir, vector_store, clear_existing=args.clear)


if __name__ == "__main__":
    main()
+ cp .env.example .env + echo "⚠️ Please edit .env and add your OpenRouter API key (and optionally Discord/Gitea keys)" + exit 1 +fi + +# Create necessary directories +mkdir -p data/rules +mkdir -p data/chroma + +# Check if uv is installed +if ! command -v uv &> /dev/null; then + echo "Installing uv package manager..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +fi + +# Install dependencies +echo "Installing Python dependencies..." +uv sync + +# Initialize database +echo "Initializing database..." +uv run python -c "from app.database import ConversationManager; import asyncio; mgr = ConversationManager('sqlite+aiosqlite:///./data/conversations.db'); asyncio.run(mgr.init_db())" + +# Check if rules exist +if ! ls data/rules/*.md 1> /dev/null 2>&1; then + echo "⚠️ No rule files found in data/rules/" + echo " Please add your markdown rule files to data/rules/" + exit 1 +fi + +# Ingest rules +echo "Ingesting rules into vector store..." +uv run python scripts/ingest_rules.py + +echo "✅ Setup complete!" +echo "" +echo "Next steps:" +echo "1. Ensure your .env file has OPENROUTER_API_KEY set" +echo "2. (Optional) Set DISCORD_BOT_TOKEN to enable Discord bot" +echo "3. 
"""Basic test to verify the vector store and ingestion."""

import sys
from pathlib import Path

# Fix: put the PROJECT ROOT on sys.path, not the app/ directory — the
# imports below reference the ``app`` package (``app.config`` etc.), so
# appending app/ itself did nothing useful.
sys.path.insert(0, str(Path(__file__).parent.parent))

from app.config import settings
from app.vector_store import VectorStore
from app.models import RuleDocument, RuleMetadata


def test_ingest_example_rule():
    """Ingest the bundled example rule and verify search retrieves it."""
    # Use a dedicated chroma directory so the real index is untouched.
    test_data_dir = Path(__file__).parent.parent / "data"
    test_chroma_dir = test_data_dir / "chroma_test"
    test_rules_dir = test_data_dir / "rules"

    vs = VectorStore(test_chroma_dir, settings.embedding_model)
    vs.clear_all()

    example_rule_path = test_rules_dir / "example_rule.md"
    if not example_rule_path.exists():
        print(f"Example rule not found at {example_rule_path}, skipping test")
        return

    content = example_rule_path.read_text(encoding="utf-8")
    import re
    import yaml

    # Minimal inline frontmatter parse (mirrors scripts/ingest_rules.py).
    pattern = r"^---\s*\n(.*?)\n---\s*\n(.*)$"
    match = re.match(pattern, content, re.DOTALL)
    if match:
        metadata_dict = yaml.safe_load(match.group(1))
        body = match.group(2).strip()
        metadata = RuleMetadata(**metadata_dict)
        doc = RuleDocument(
            metadata=metadata, content=body, source_file=str(example_rule_path)
        )
        vs.add_document(doc)

    # Exactly one rule should now be indexed.
    assert vs.count() == 1, f"Expected 1 rule, got {vs.count()}"

    # A semantically related query must surface the example rule first.
    results = vs.search("runner steal base", top_k=5)
    assert len(results) > 0, "Expected at least one search result"
    assert (
        results[0].rule_id == "5.2.1(b)"
    ), f"Expected rule 5.2.1(b), got {results[0].rule_id}"

    print("✓ Test passed: Ingestion and search work correctly")
    print(f"  Found rule: {results[0].title}")
    print(f"  Similarity: {results[0].similarity:.2%}")

    # Cleanup so repeated runs start from an empty collection.
    vs.clear_all()


if __name__ == "__main__":
    test_ingest_example_rule()