feat: initial chatbot implementation with FastAPI, ChromaDB, Discord bot, and Gitea integration
- Add vector store with sentence-transformers for semantic search
- FastAPI backend with /chat and /health endpoints
- Conversation state persistence via SQLite
- OpenRouter integration with structured JSON responses
- Discord bot with /ask slash command and reply-based follow-ups
- Automated Gitea issue creation for unanswered questions
- Docker support with docker-compose for easy deployment
- Example rule file and ingestion script
- Comprehensive documentation in README
commit c42fea66ba

.env.example (new file, 22 lines)
@@ -0,0 +1,22 @@
# OpenRouter Configuration
OPENROUTER_API_KEY=your_openrouter_api_key_here
OPENROUTER_MODEL=stepfun/step-3.5-flash:free

# Discord Bot Configuration
DISCORD_BOT_TOKEN=your_discord_bot_token_here
DISCORD_GUILD_ID=your_guild_id_here  # Optional, speeds up slash command sync

# Gitea Configuration (for issue creation)
GITEA_TOKEN=your_gitea_token_here
GITEA_OWNER=cal
GITEA_REPO=strat-chatbot
GITEA_BASE_URL=https://git.manticorum.com/api/v1

# Application Configuration
DATA_DIR=./data
RULES_DIR=./data/rules
CHROMA_DIR=./data/chroma
DB_URL=sqlite+aiosqlite:///./data/conversations.db
CONVERSATION_TTL=1800
TOP_K_RULES=10
EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
.gitignore (new file, vendored, 43 lines)
@@ -0,0 +1,43 @@
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.venv/
*.egg-info/
dist/
build/
poetry.lock

# Data files (except example rules)
data/chroma/
data/conversations.db

# Environment
.env
.env.local

# IDE
.vscode/
.idea/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Logs
*.log

# Temporary
tmp/
temp/
.mypy_cache/
.pytest_cache/

# Docker
.dockerignore
Dockerfile (new file, 48 lines)
@@ -0,0 +1,48 @@
# Multi-stage build for Strat-Chatbot
FROM python:3.12-slim AS builder

WORKDIR /app

# Install system dependencies for sentence-transformers (PyTorch, etc.)
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Install Poetry
RUN pip install --no-cache-dir poetry

# Copy dependencies
COPY pyproject.toml ./
COPY README.md ./

# Install dependencies
RUN poetry config virtualenvs.in-project true && \
    poetry install --no-interaction --no-ansi --only main

# Final stage
FROM python:3.12-slim

WORKDIR /app

# Copy virtual environment from builder
COPY --from=builder /app/.venv .venv
ENV PATH="/app/.venv/bin:$PATH"

# Create non-root user
RUN useradd --create-home --shell /bin/bash app && chown -R app:app /app
USER app

# Copy application code
COPY --chown=app:app app/ ./app/
COPY --chown=app:app data/ ./data/
COPY --chown=app:app scripts/ ./scripts/

# Create data directories
RUN mkdir -p data/chroma data/rules

# Expose ports
EXPOSE 8000

# Run FastAPI server
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md (new file, 229 lines)
@@ -0,0 +1,229 @@
# Strat-Chatbot

AI-powered Q&A chatbot for Strat-O-Matic baseball league rules.

## Features

- **Natural language Q&A**: Ask questions about league rules in plain English
- **Semantic search**: Uses ChromaDB vector embeddings to find relevant rules
- **Rule citations**: Always cites specific rule IDs (e.g., "Rule 5.2.1(b)")
- **Conversation threading**: Maintains conversation context for follow-up questions
- **Gitea integration**: Automatically creates issues for unanswered questions
- **Discord integration**: Slash command `/ask` with reply-based follow-ups

## Architecture

```
┌─────────┐      ┌──────────────┐      ┌─────────────┐
│ Discord │──────│   FastAPI    │──────│  ChromaDB   │
│   Bot   │      │ (port 8000)  │      │  (vectors)  │
└─────────┘      └──────┬───────┘      └─────────────┘
                        │
                 ┌──────▼───────┐
                 │   Markdown   │
                 │  Rule Files  │
                 └──────┬───────┘
                        │
                 ┌──────▼───────┐
                 │  OpenRouter  │
                 │  (LLM API)   │
                 └──────┬───────┘
                        │
                 ┌──────▼───────┐
                 │    Gitea     │
                 │    Issues    │
                 └──────────────┘
```

## Quick Start

### Prerequisites

- Docker & Docker Compose
- OpenRouter API key
- Discord bot token
- Gitea token (optional, for issue creation)

### Setup

1. **Clone and configure**

   ```bash
   cd strat-chatbot
   cp .env.example .env
   # Edit .env with your API keys and tokens
   ```

2. **Prepare rules**

   Place your rule documents in `data/rules/` as Markdown files with YAML frontmatter:

   ```markdown
   ---
   rule_id: "5.2.1(b)"
   title: "Stolen Base Attempts"
   section: "Baserunning"
   parent_rule: "5.2"
   page_ref: "32"
   ---

   When a runner attempts to steal...
   ```
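The ingestion script itself is not shown in this commit message context; as a rough sketch, flat frontmatter like the example above can be split off with plain string handling (`parse_rule_file` is a hypothetical helper, not the actual `scripts/ingest_rules.py`; nested YAML would call for PyYAML):

```python
# Hypothetical frontmatter parser sketch; not the actual ingestion code.
# Handles only the flat `key: "value"` frontmatter shown above.
def parse_rule_file(text: str) -> tuple[dict, str]:
    """Split a rule document into (frontmatter metadata, rule body)."""
    if not text.startswith("---"):
        return {}, text.strip()
    _, frontmatter, body = text.split("---", 2)
    meta = {}
    for line in frontmatter.strip().splitlines():
        key, _, value = line.partition(":")
        meta[key.strip()] = value.strip().strip('"')
    return meta, body.strip()

doc = """---
rule_id: "5.2.1(b)"
title: "Stolen Base Attempts"
section: "Baserunning"
---

When a runner attempts to steal...
"""
meta, body = parse_rule_file(doc)
```

The pipeline would then embed `body` with the configured `EMBEDDING_MODEL` and store it in ChromaDB keyed by `rule_id`.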
3. **Ingest rules**

   ```bash
   # With Docker Compose (recommended)
   docker compose up -d
   docker compose exec api python scripts/ingest_rules.py

   # Or locally
   uv sync
   uv run scripts/ingest_rules.py
   ```

4. **Start services**

   ```bash
   docker compose up -d
   ```

   The API will be available at http://localhost:8000

   The Discord bot will connect and sync slash commands.

### Runtime Configuration

| Environment Variable | Required? | Description |
|----------------------|-----------|-------------|
| `OPENROUTER_API_KEY` | Yes | OpenRouter API key |
| `OPENROUTER_MODEL` | No | Model ID (default: `stepfun/step-3.5-flash:free`) |
| `DISCORD_BOT_TOKEN` | No | Discord bot token (omit to run the API only) |
| `DISCORD_GUILD_ID` | No | Guild ID for slash command sync (faster than global) |
| `GITEA_TOKEN` | No | Gitea API token (for issue creation) |
| `GITEA_OWNER` | No | Gitea username (default: `cal`) |
| `GITEA_REPO` | No | Repository name (default: `strat-chatbot`) |

## API Endpoints

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check with stats |
| `/chat` | POST | Send a question and get a response |
| `/stats` | GET | Knowledge base and system statistics |

### Chat Request

```json
{
  "message": "Can a runner steal on a 2-2 count?",
  "user_id": "123456789",
  "channel_id": "987654321",
  "conversation_id": "optional-uuid",
  "parent_message_id": "optional-parent-uuid"
}
```

### Chat Response

```json
{
  "response": "Yes, according to Rule 5.2.1(b)...",
  "conversation_id": "conv-uuid",
  "message_id": "msg-uuid",
  "cited_rules": ["5.2.1(b)", "5.3"],
  "confidence": 0.85,
  "needs_human": false
}
```
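Clients mostly care about the citations and the confidence gate in this payload; a minimal sketch of consuming it (`format_answer` is a hypothetical helper, and the 0.4 cutoff mirrors the low-confidence threshold used elsewhere in this project):

```python
# Hypothetical consumer of a /chat response; 0.4 mirrors the bot's
# low-confidence threshold for requesting human review.
LOW_CONFIDENCE = 0.4

def format_answer(resp: dict) -> str:
    """Render a /chat response dict as display text."""
    lines = [resp["response"]]
    if resp.get("cited_rules"):
        lines.append("Cited rules: " + ", ".join(resp["cited_rules"]))
    if resp.get("confidence", 0.0) < LOW_CONFIDENCE:
        lines.append("Low confidence - a human review has been requested.")
    return "\n".join(lines)

example = {
    "response": "Yes, according to Rule 5.2.1(b)...",
    "cited_rules": ["5.2.1(b)", "5.3"],
    "confidence": 0.85,
    "needs_human": False,
}
```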
## Development

### Local Development (without Docker)

```bash
# Install dependencies
uv sync

# Ingest rules
uv run scripts/ingest_rules.py

# Run API server
uv run app/main.py

# In another terminal, run the Discord bot
uv run app/discord_bot.py
```

### Project Structure

```
strat-chatbot/
├── app/
│   ├── __init__.py
│   ├── config.py          # Configuration management
│   ├── database.py        # SQLAlchemy conversation state
│   ├── gitea.py           # Gitea API client
│   ├── llm.py             # OpenRouter integration
│   ├── main.py            # FastAPI app
│   ├── models.py          # Pydantic models
│   ├── vector_store.py    # ChromaDB wrapper
│   └── discord_bot.py     # Discord bot
├── data/
│   ├── chroma/            # Vector DB (auto-created)
│   └── rules/             # Your markdown rule files
├── scripts/
│   └── ingest_rules.py    # Ingestion pipeline
├── tests/                 # Test files
├── .env.example
├── Dockerfile
├── docker-compose.yml
└── pyproject.toml
```

## Performance Optimizations

- **Embedding cache**: ChromaDB persists embeddings on disk, so rules are only embedded at ingestion time
- **Rule chunking**: Each rule is stored as a separate document, so retrieval returns whole rules without context fragmentation
- **Top-k search**: Configurable number of rules to retrieve (default: 10)
- **Conversation TTL**: Conversations expire after 30 minutes to limit database size
- **Async operations**: All I/O is non-blocking

## Testing the API

```bash
curl -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{
    "message": "What happens if the pitcher balks?",
    "user_id": "test123",
    "channel_id": "general"
  }'
```

## Gitea Integration

When the bot encounters a question it can't answer confidently (confidence < 0.4), it automatically:

1. Logs the question to the console
2. Creates an issue in your configured Gitea repo
3. Includes the user ID, channel, question, attempted rules, and a conversation link

Issues are labeled with:

- `rules-gap` - needs a rule addition or clarification
- `ai-generated` - created by the AI bot
- `needs-review` - requires human administrator attention
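The hand-off above might be wired roughly like this (an illustrative sketch, not the actual `app/main.py` code; it assumes a client object with the `create_unanswered_issue` signature from `app/gitea.py` in this commit):

```python
# Illustrative sketch of the low-confidence escalation flow; the real
# wiring lives in app/main.py, which may differ.
async def maybe_escalate(result, request, gitea_client):
    """Create a Gitea issue when confidence is below the 0.4 cutoff."""
    if gitea_client is None or result.confidence >= 0.4:
        return None
    # 1. Log the question to the console
    print(f"Unanswered question from {request.user_id}: {request.message}")
    # 2./3. Create the issue with the question and its context
    async with gitea_client as client:
        return await client.create_unanswered_issue(
            question=request.message,
            user_id=request.user_id,
            channel_id=request.channel_id,
            attempted_rules=result.cited_rules,
            conversation_id=result.conversation_id,
        )
```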
## To-Do

- [ ] Build OpenRouter Docker client with proper torch dependencies
- [ ] Add PDF ingestion support (convert PDF → Markdown)
- [ ] Implement rule change detection and incremental updates
- [ ] Add rate limiting per Discord user
- [ ] Create admin endpoints for rule management
- [ ] Add Prometheus metrics for monitoring
- [ ] Build unit and integration tests

## License

TBD
app/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Strat-Chatbot application package."""
app/config.py (new file, 53 lines)
@@ -0,0 +1,53 @@
"""Configuration management using Pydantic Settings."""

from pathlib import Path
from pydantic_settings import BaseSettings
from pydantic import Field


class Settings(BaseSettings):
    """Application settings with environment variable overrides."""

    # OpenRouter
    openrouter_api_key: str = Field(default="", env="OPENROUTER_API_KEY")
    openrouter_model: str = Field(
        default="stepfun/step-3.5-flash:free", env="OPENROUTER_MODEL"
    )

    # Discord
    discord_bot_token: str = Field(default="", env="DISCORD_BOT_TOKEN")
    discord_guild_id: str | None = Field(default=None, env="DISCORD_GUILD_ID")

    # Gitea
    gitea_token: str = Field(default="", env="GITEA_TOKEN")
    gitea_owner: str = Field(default="cal", env="GITEA_OWNER")
    gitea_repo: str = Field(default="strat-chatbot", env="GITEA_REPO")
    gitea_base_url: str = Field(
        default="https://git.manticorum.com/api/v1", env="GITEA_BASE_URL"
    )

    # Paths
    data_dir: Path = Field(default=Path("./data"), env="DATA_DIR")
    rules_dir: Path = Field(default=Path("./data/rules"), env="RULES_DIR")
    chroma_dir: Path = Field(default=Path("./data/chroma"), env="CHROMA_DIR")

    # Database
    db_url: str = Field(
        default="sqlite+aiosqlite:///./data/conversations.db", env="DB_URL"
    )

    # Conversation state TTL (seconds)
    conversation_ttl: int = Field(default=1800, env="CONVERSATION_TTL")

    # Vector search
    top_k_rules: int = Field(default=10, env="TOP_K_RULES")
    embedding_model: str = Field(
        default="sentence-transformers/all-MiniLM-L6-v2", env="EMBEDDING_MODEL"
    )

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


settings = Settings()
app/database.py (new file, 161 lines)
@@ -0,0 +1,161 @@
"""SQLAlchemy-based conversation state management with aiosqlite."""

from datetime import datetime, timedelta
from typing import Optional
import uuid

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, DateTime, Boolean, ForeignKey, select

from .config import settings

Base = declarative_base()


class ConversationTable(Base):
    """SQLAlchemy model for conversations."""

    __tablename__ = "conversations"

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    channel_id = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_activity = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class MessageTable(Base):
    """SQLAlchemy model for messages."""

    __tablename__ = "messages"

    id = Column(String, primary_key=True)
    conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False)
    content = Column(String, nullable=False)
    is_user = Column(Boolean, nullable=False)
    parent_id = Column(String, ForeignKey("messages.id"), nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)


class ConversationManager:
    """Manages conversation state in SQLite."""

    def __init__(self, db_url: str):
        """Initialize database engine and session factory."""
        self.engine = create_async_engine(db_url, echo=False)
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )

    async def init_db(self):
        """Create tables if they don't exist."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def get_or_create_conversation(
        self, user_id: str, channel_id: str, conversation_id: Optional[str] = None
    ) -> str:
        """Get existing conversation or create a new one."""
        async with self.async_session() as session:
            if conversation_id:
                result = await session.execute(
                    select(ConversationTable).where(
                        ConversationTable.id == conversation_id
                    )
                )
                conv = result.scalar_one_or_none()
                if conv:
                    conv.last_activity = datetime.utcnow()
                    await session.commit()
                    return conv.id

            # Create new conversation
            new_id = str(uuid.uuid4())
            conv = ConversationTable(id=new_id, user_id=user_id, channel_id=channel_id)
            session.add(conv)
            await session.commit()
            return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Add a message to a conversation."""
        message_id = str(uuid.uuid4())
        async with self.async_session() as session:
            msg = MessageTable(
                id=message_id,
                conversation_id=conversation_id,
                content=content,
                is_user=is_user,
                parent_id=parent_id,
            )
            session.add(msg)

            # Update conversation activity
            result = await session.execute(
                select(ConversationTable).where(ConversationTable.id == conversation_id)
            )
            conv = result.scalar_one_or_none()
            if conv:
                conv.last_activity = datetime.utcnow()

            await session.commit()
            return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict]:
        """Get recent messages from a conversation in OpenAI format."""
        async with self.async_session() as session:
            result = await session.execute(
                select(MessageTable)
                .where(MessageTable.conversation_id == conversation_id)
                .order_by(MessageTable.created_at.desc())
                .limit(limit)
            )
            messages = result.scalars().all()
            # Reverse to get chronological order and convert to API format
            history = []
            for msg in reversed(messages):
                role = "user" if msg.is_user else "assistant"
                history.append({"role": role, "content": msg.content})

            return history

    async def cleanup_old_conversations(self, ttl_seconds: int = 1800):
        """Delete conversations older than the TTL to free up storage."""
        cutoff = datetime.utcnow() - timedelta(seconds=ttl_seconds)
        async with self.async_session() as session:
            # Find old conversations
            result = await session.execute(
                select(ConversationTable).where(
                    ConversationTable.last_activity < cutoff
                )
            )
            old_convs = result.scalars().all()

            conv_ids = [conv.id for conv in old_convs]
            if conv_ids:
                # Delete messages first (cascade would handle this, but explicit is clearer)
                await session.execute(
                    sa.delete(MessageTable).where(
                        MessageTable.conversation_id.in_(conv_ids)
                    )
                )
                # Delete conversations
                await session.execute(
                    sa.delete(ConversationTable).where(
                        ConversationTable.id.in_(conv_ids)
                    )
                )
                await session.commit()
                print(f"Cleaned up {len(conv_ids)} old conversations")


async def get_conversation_manager() -> ConversationManager:
    """Dependency for FastAPI to get a ConversationManager instance."""
    return ConversationManager(settings.db_url)
app/discord_bot.py (new file, 240 lines)
@@ -0,0 +1,240 @@
"""Discord bot for Strat-O-Matic rules Q&A."""

import discord
from discord import app_commands
from discord.ext import commands
import asyncio
import aiohttp
from typing import Optional
import uuid

from .config import settings


class StratChatbotBot(commands.Bot):
    """Discord bot for the rules chatbot."""

    def __init__(self):
        """Initialize the bot with default intents."""
        intents = discord.Intents.default()
        intents.message_content = True
        super().__init__(command_prefix="!", intents=intents)

        self.api_base_url: Optional[str] = None
        self.session: Optional[aiohttp.ClientSession] = None

    async def setup_hook(self):
        """Set up the bot's HTTP session and sync commands."""
        self.session = aiohttp.ClientSession()
        # Sync slash commands with Discord
        if settings.discord_guild_id:
            guild = discord.Object(id=int(settings.discord_guild_id))
            self.tree.copy_global_to(guild=guild)
            await self.tree.sync(guild=guild)
            print(f"Slash commands synced to guild {settings.discord_guild_id}")
        else:
            await self.tree.sync()
            print("Slash commands synced globally")

    async def close(self):
        """Clean up on shutdown."""
        if self.session:
            await self.session.close()
        await super().close()

    async def query_chat_api(
        self,
        message: str,
        user_id: str,
        channel_id: str,
        conversation_id: Optional[str] = None,
        parent_message_id: Optional[str] = None,
    ) -> dict:
        """Send a request to the FastAPI chat endpoint."""
        payload = {
            "message": message,
            "user_id": user_id,
            "channel_id": channel_id,
            "conversation_id": conversation_id,
            "parent_message_id": parent_message_id,
        }

        async with self.session.post(
            f"{self.api_base_url}/chat",
            json=payload,
            timeout=aiohttp.ClientTimeout(total=120),
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"API error {response.status}: {error_text}")
            return await response.json()


bot = StratChatbotBot()


@bot.event
async def on_ready():
    """Called when the bot is ready."""
    print(f"🤖 Bot logged in as {bot.user} (ID: {bot.user.id})")
    print("Ready to answer Strat-O-Matic rules questions!")


@bot.tree.command(
    name="ask", description="Ask a question about Strat-O-Matic league rules"
)
@app_commands.describe(
    question="Your rules question (e.g., 'Can a runner steal on a 2-2 count?')"
)
async def ask_command(interaction: discord.Interaction, question: str):
    """Handle the /ask command."""
    await interaction.response.defer(ephemeral=False)

    try:
        result = await bot.query_chat_api(
            message=question,
            user_id=str(interaction.user.id),
            channel_id=str(interaction.channel_id),
            conversation_id=None,  # New conversation
            parent_message_id=None,
        )

        # Build response embed
        embed = discord.Embed(
            title="Rules Answer",
            description=result["response"][:4000],  # Discord limit
            color=discord.Color.blue(),
        )

        # Add cited rules if any
        if result.get("cited_rules"):
            embed.add_field(
                name="📋 Cited Rules",
                value=", ".join([f"`{rid}`" for rid in result["cited_rules"]]),
                inline=False,
            )

        # Add confidence indicator
        confidence = result.get("confidence", 0.0)
        if confidence < 0.4:
            embed.add_field(
                name="⚠️ Confidence",
                value=f"Low ({confidence:.0%}) - A human review has been requested",
                inline=False,
            )

        # Add conversation ID for follow-ups
        embed.set_footer(
            text=f"Conversation: {result['conversation_id'][:8]}... "
            f"| Reply to ask a follow-up"
        )

        await interaction.followup.send(embed=embed)

    except Exception as e:
        await interaction.followup.send(
            embed=discord.Embed(
                title="❌ Error",
                description=f"Failed to get answer: {str(e)}",
                color=discord.Color.red(),
            )
        )


@bot.event
async def on_message(message: discord.Message):
    """Handle follow-up messages via reply."""
    # Ignore bot messages
    if message.author.bot:
        return

    # Only handle replies to the bot's messages
    if not message.reference:
        return

    referenced = await message.channel.fetch_message(message.reference.message_id)

    # Check if the referenced message was from this bot
    if referenced.author != bot.user:
        return

    # Try to extract the conversation ID from the footer
    embed = referenced.embeds[0] if referenced.embeds else None
    if not embed or not embed.footer:
        await message.reply(
            "❓ I couldn't find this conversation. Please use `/ask` to start a new question.",
            mention_author=True,
        )
        return

    footer_text = embed.footer.text or ""
    if "Conversation:" not in footer_text:
        await message.reply(
            "❓ Could not determine conversation. Use `/ask` to start fresh.",
            mention_author=True,
        )
        return
    # Extract the conversation ID from the footer. NOTE: the footer only
    # stores the first 8 characters of the ID (see ask_command above).
    try:
        conversation_id = footer_text.split("Conversation:")[1].split("...")[0].strip()
    except IndexError:
        conversation_id = None

    # Get parent message ID (the original answer message)
    parent_message_id = str(referenced.id)

    # Forward to API
    await message.reply("🔍 Looking into that follow-up...", mention_author=True)

    try:
        result = await bot.query_chat_api(
            message=message.content,
            user_id=str(message.author.id),
            channel_id=str(message.channel.id),  # Message has no channel_id attribute
            conversation_id=conversation_id,
            parent_message_id=parent_message_id,
        )

        response_embed = discord.Embed(
            title="Follow-up Answer",
            description=result["response"][:4000],
            color=discord.Color.green(),
        )

        if result.get("cited_rules"):
            response_embed.add_field(
                name="📋 Cited Rules",
                value=", ".join([f"`{rid}`" for rid in result["cited_rules"]]),
                inline=False,
            )

        if result.get("confidence", 0.0) < 0.4:
            response_embed.add_field(
                name="⚠️ Confidence",
                value="Low - Human review requested",
                inline=False,
            )

        await message.reply(embed=response_embed, mention_author=True)

    except Exception as e:
        await message.reply(
            embed=discord.Embed(
                title="❌ Error",
                description=f"Failed to process follow-up: {str(e)}",
                color=discord.Color.red(),
            )
        )


def run_bot(api_base_url: str = "http://localhost:8000"):
    """Entry point to run the Discord bot."""
    bot.api_base_url = api_base_url

    if not settings.discord_bot_token:
        print("❌ DISCORD_BOT_TOKEN environment variable is required")
        raise SystemExit(1)

    bot.run(settings.discord_bot_token)
app/gitea.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""Gitea client for creating issues when questions need human review."""

import httpx
from typing import Optional
from .config import settings


class GiteaClient:
    """Client for Gitea API interactions."""

    def __init__(self):
        """Initialize the Gitea client with credentials."""
        self.token = settings.gitea_token
        self.owner = settings.gitea_owner
        self.repo = settings.gitea_repo
        self.base_url = settings.gitea_base_url.rstrip("/")
        self.headers = {
            "Authorization": f"token {self.token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        self._client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self):
        """Async context manager entry."""
        self._client = httpx.AsyncClient(timeout=30.0)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        if self._client:
            await self._client.aclose()

    async def create_issue(
        self,
        title: str,
        body: str,
        labels: Optional[list[str]] = None,
        assignee: Optional[str] = None,
    ) -> dict:
        """Create a new issue in the configured repository."""
        if not self._client:
            raise RuntimeError("GiteaClient must be used as an async context manager")

        url = f"{self.base_url}/repos/{self.owner}/{self.repo}/issues"

        payload = {"title": title, "body": body}

        if labels:
            payload["labels"] = labels

        if assignee:
            payload["assignee"] = assignee

        response = await self._client.post(url, headers=self.headers, json=payload)

        if response.status_code not in (200, 201):
            error_detail = response.text
            raise RuntimeError(
                f"Gitea API error creating issue: {response.status_code} - {error_detail}"
            )

        return response.json()

    async def create_unanswered_issue(
        self,
        question: str,
        user_id: str,
        channel_id: str,
        attempted_rules: list[str],
        conversation_id: str,
    ) -> str:
        """Create an issue for an unanswered question needing human review."""
        title = f"🤔 Unanswered rules question: {question[:80]}{'...' if len(question) > 80 else ''}"

        body = f"""## Unanswered Question

**User:** {user_id}

**Channel:** {channel_id}

**Conversation ID:** {conversation_id}

**Question:**
{question}

**Searched Rules:**
{', '.join(attempted_rules) if attempted_rules else 'None'}

**Additional Context:**
This question was asked in Discord and the bot could not provide a confident answer. The rules either don't cover this question or the information was ambiguous.

---

*This issue was automatically created by the Strat-Chatbot.*"""

        labels = ["rules-gap", "ai-generated", "needs-review"]

        issue = await self.create_issue(title=title, body=body, labels=labels)

        return issue.get("html_url", "")


def get_gitea_client() -> Optional[GiteaClient]:
    """Factory to get a Gitea client if a token is configured."""
    if settings.gitea_token:
        return GiteaClient()
    return None
179
app/llm.py
Normal file
179
app/llm.py
Normal file
@ -0,0 +1,179 @@
"""OpenRouter LLM integration for answering rules questions."""

import json
import re
from typing import Optional

import httpx

from .config import settings
from .models import RuleSearchResult, ChatResponse


SYSTEM_PROMPT = """You are a helpful assistant for a Strat-O-Matic baseball league.
Your job is to answer questions about league rules and procedures using the provided rule excerpts.

CRITICAL RULES:
1. ONLY use information from the provided rules. If the rules don't contain the answer, say so clearly.
2. ALWAYS cite rule IDs when referencing a rule (e.g., "Rule 5.2.1(b) states that...")
3. If multiple rules are relevant, cite all of them.
4. If you're uncertain or the rules are ambiguous, say so and suggest asking a league administrator.
5. Keep responses concise but complete. Use examples from the rules when helpful.
6. Do NOT make up rules or infer beyond what's explicitly stated.

When answering:
- Start with a direct answer to the question
- Support it with rule citations
- Include relevant details from the rules
- If no relevant rules are found, explicitly state: "I don't have a rule that addresses this question."

Response format (JSON):
{
  "answer": "Your response text",
  "cited_rules": ["rule_id_1", "rule_id_2"],
  "confidence": 0.0-1.0,
  "needs_human": boolean
}

Use higher confidence (0.8-1.0) when the rules clearly answer the question,
lower confidence (0.3-0.7) when the rules only partially address it or are ambiguous,
and very low confidence (0.0-0.2) when the rules don't address it at all.
"""


class OpenRouterClient:
    """Client for the OpenRouter chat completions API."""

    def __init__(self):
        """Initialize the client."""
        self.api_key = settings.openrouter_api_key
        if not self.api_key:
            raise ValueError("OPENROUTER_API_KEY is required")
        self.model = settings.openrouter_model
        self.base_url = "https://openrouter.ai/api/v1/chat/completions"

    async def generate_response(
        self,
        question: str,
        rules: list[RuleSearchResult],
        conversation_history: Optional[list[dict]] = None,
    ) -> ChatResponse:
        """Generate a response using the LLM with retrieved rules as context."""
        # Build context from the retrieved rules
        rules_context = "\n\n".join(
            f"Rule {r.rule_id}: {r.title}\n{r.content}" for r in rules
        )

        if rules:
            context_msg = (
                f"Here are the relevant rules for the question:\n\n{rules_context}"
            )
        else:
            context_msg = "No relevant rules were found in the knowledge base."

        # Build the message list, starting with the system prompt
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        if conversation_history:
            # Keep only the last 3 exchanges (user + assistant) to avoid token overflow
            messages.extend(conversation_history[-6:])

        # Add the current question with context
        user_message = (
            f"{context_msg}\n\nUser question: {question}\n\n"
            "Answer the question based on the rules provided."
        )
        messages.append({"role": "user", "content": user_message})

        # Call the OpenRouter API
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                self.base_url,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": self.model,
                    "messages": messages,
                    "temperature": 0.3,
                    "max_tokens": 1000,
                    "top_p": 0.9,
                },
            )

        if response.status_code != 200:
            raise RuntimeError(
                f"OpenRouter API error: {response.status_code} - {response.text}"
            )

        result = response.json()
        content = result["choices"][0]["message"]["content"]

        # Parse the JSON response
        try:
            # Extract JSON from the response (the LLM may wrap it in markdown)
            if "```json" in content:
                json_str = content.split("```json")[1].split("```")[0].strip()
            else:
                json_str = content.strip()

            parsed = json.loads(json_str)

            cited_rules = parsed.get("cited_rules", [])
            if not cited_rules and rules:
                # Fallback: extract rule IDs from the answer text if the model
                # didn't return them in the structured field
                rule_ids = re.findall(
                    r"Rule\s+([\d\.\(\)a-z]+)", parsed.get("answer", "")
                )
                cited_rules = list(set(rule_ids))

            return ChatResponse(
                response=parsed["answer"],
                conversation_id="",  # Will be set by caller
                message_id="",  # Will be set by caller
                cited_rules=cited_rules,
                confidence=float(parsed.get("confidence", 0.5)),
                needs_human=bool(parsed.get("needs_human", False)),
            )
        except (json.JSONDecodeError, KeyError):
            # If parsing fails, return the raw content with neutral defaults
            return ChatResponse(
                response=content,
                conversation_id="",
                message_id="",
                cited_rules=[],
                confidence=0.5,
                needs_human=False,
            )


class MockLLMClient:
    """Mock LLM client for testing without API calls."""

    async def generate_response(
        self,
        question: str,
        rules: list[RuleSearchResult],
        conversation_history: Optional[list[dict]] = None,
    ) -> ChatResponse:
        """Return a canned response built from the retrieved rules."""
        if rules:
            rule_list = ", ".join(r.rule_id for r in rules)
            answer = f"Based on rule(s) {rule_list}, here's what you need to know..."
        else:
            answer = "I don't have a rule that addresses this question. You should ask a league administrator."

        return ChatResponse(
            response=answer,
            conversation_id="",
            message_id="",
            cited_rules=[r.rule_id for r in rules],
            confidence=1.0 if rules else 0.0,
            needs_human=not rules,
        )


def get_llm_client(use_mock: bool = False):
    """Factory to get the appropriate LLM client."""
    if use_mock or not settings.openrouter_api_key:
        return MockLLMClient()
    return OpenRouterClient()
198
app/main.py
Normal file
@ -0,0 +1,198 @@
"""FastAPI application for Strat-O-Matic rules chatbot."""

import uuid
from contextlib import asynccontextmanager

import sqlalchemy as sa
import uvicorn
from fastapi import FastAPI, HTTPException, Depends

from .config import settings
from .models import ChatRequest, ChatResponse
from .vector_store import VectorStore
from .database import ConversationManager, get_conversation_manager
from .llm import get_llm_client
from .gitea import GiteaClient


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan - startup and shutdown."""
    # Startup
    print("Initializing Strat-Chatbot...")

    # Initialize vector store
    chroma_dir = settings.data_dir / "chroma"
    vector_store = VectorStore(chroma_dir, settings.embedding_model)
    print(f"Vector store ready at {chroma_dir} ({vector_store.count()} rules loaded)")

    # Initialize database
    db_manager = ConversationManager(settings.db_url)
    await db_manager.init_db()
    print("Database initialized")

    # Initialize LLM client
    llm_client = get_llm_client(use_mock=not settings.openrouter_api_key)
    print(f"LLM client ready (model: {settings.openrouter_model})")

    # Initialize Gitea client (optional)
    gitea_client = GiteaClient() if settings.gitea_token else None

    # Store shared clients in app state
    app.state.vector_store = vector_store
    app.state.db_manager = db_manager
    app.state.llm_client = llm_client
    app.state.gitea_client = gitea_client

    print("Strat-Chatbot ready!")

    yield

    # Shutdown
    print("Shutting down...")


app = FastAPI(
    title="Strat-Chatbot",
    description="Strat-O-Matic rules Q&A API",
    version="0.1.0",
    lifespan=lifespan,
)


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    vector_store: VectorStore = app.state.vector_store
    stats = vector_store.get_stats()
    return {
        "status": "healthy",
        "rules_count": stats["total_rules"],
        "sections": stats["sections"],
    }


@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    db_manager: ConversationManager = Depends(get_conversation_manager),
):
    """Handle chat requests from Discord."""
    vector_store: VectorStore = app.state.vector_store
    llm_client = app.state.llm_client
    gitea_client = app.state.gitea_client

    # Fail fast with a clear message when no API key is configured
    if not settings.openrouter_api_key:
        return ChatResponse(
            response="⚠️ OpenRouter API key not configured. Set OPENROUTER_API_KEY environment variable.",
            conversation_id=request.conversation_id or str(uuid.uuid4()),
            message_id=str(uuid.uuid4()),
            cited_rules=[],
            confidence=0.0,
            needs_human=True,
        )

    # Get or create conversation
    conversation_id = await db_manager.get_or_create_conversation(
        user_id=request.user_id,
        channel_id=request.channel_id,
        conversation_id=request.conversation_id,
    )

    # Save user message
    user_message_id = await db_manager.add_message(
        conversation_id=conversation_id,
        content=request.message,
        is_user=True,
        parent_id=request.parent_message_id,
    )

    try:
        # Search for relevant rules
        search_results = vector_store.search(
            query=request.message, top_k=settings.top_k_rules
        )

        # Get conversation history for context
        history = await db_manager.get_conversation_history(conversation_id, limit=10)

        # Generate response from the LLM
        response = await llm_client.generate_response(
            question=request.message, rules=search_results, conversation_history=history
        )

        # Save assistant message
        assistant_message_id = await db_manager.add_message(
            conversation_id=conversation_id,
            content=response.response,
            is_user=False,
            parent_id=user_message_id,
        )

        # If the question needs a human or confidence is low, open a Gitea issue
        if gitea_client and (response.needs_human or response.confidence < 0.4):
            try:
                issue_url = await gitea_client.create_unanswered_issue(
                    question=request.message,
                    user_id=request.user_id,
                    channel_id=request.channel_id,
                    attempted_rules=[r.rule_id for r in search_results],
                    conversation_id=conversation_id,
                )
                print(f"Created Gitea issue: {issue_url}")
            except Exception as e:
                print(f"Failed to create Gitea issue: {e}")

        # Build final response
        return ChatResponse(
            response=response.response,
            conversation_id=conversation_id,
            message_id=assistant_message_id,
            parent_message_id=user_message_id,
            cited_rules=response.cited_rules,
            confidence=response.confidence,
            needs_human=response.needs_human,
        )

    except Exception as e:
        print(f"Error processing chat request: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/stats")
async def stats():
    """Get statistics about the knowledge base and system."""
    vector_store: VectorStore = app.state.vector_store
    db_manager: ConversationManager = app.state.db_manager

    # Vector store stats
    vs_stats = vector_store.get_stats()

    # Database stats
    async with db_manager.async_session() as session:
        conv_count = await session.execute(
            sa.text("SELECT COUNT(*) FROM conversations")
        )
        msg_count = await session.execute(sa.text("SELECT COUNT(*) FROM messages"))

        total_conversations = conv_count.scalar() or 0
        total_messages = msg_count.scalar() or 0

    return {
        "knowledge_base": vs_stats,
        "conversations": {
            "total": total_conversations,
            "total_messages": total_messages,
        },
        "config": {
            "openrouter_model": settings.openrouter_model,
            "top_k_rules": settings.top_k_rules,
            "embedding_model": settings.embedding_model,
        },
    }


if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
100
app/models.py
Normal file
@ -0,0 +1,100 @@
"""Data models for rules and conversations."""

from datetime import datetime
from typing import Optional

from pydantic import BaseModel, Field


class RuleMetadata(BaseModel):
    """Frontmatter metadata for a rule document."""

    rule_id: str = Field(..., description="Unique rule identifier, e.g. '5.2.1(b)'")
    title: str = Field(..., description="Rule title")
    section: str = Field(..., description="Section/category name")
    parent_rule: Optional[str] = Field(
        None, description="Parent rule ID for hierarchical rules"
    )
    last_updated: str = Field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%d"),
        description="Last update date",
    )
    page_ref: Optional[str] = Field(
        None, description="Reference to page number in rulebook"
    )


class RuleDocument(BaseModel):
    """Complete rule document with metadata and content."""

    metadata: RuleMetadata
    content: str = Field(..., description="Rule text and examples")
    source_file: str = Field(..., description="Source file path")
    embedding: Optional[list[float]] = None

    def to_chroma_metadata(self) -> dict:
        """Convert to ChromaDB metadata format."""
        return {
            "rule_id": self.metadata.rule_id,
            "title": self.metadata.title,
            "section": self.metadata.section,
            "parent_rule": self.metadata.parent_rule or "",
            "page_ref": self.metadata.page_ref or "",
            "last_updated": self.metadata.last_updated,
            "source_file": self.source_file,
        }


class Conversation(BaseModel):
    """Conversation session."""

    id: str
    user_id: str  # Discord user ID
    channel_id: str  # Discord channel ID
    created_at: datetime = Field(default_factory=datetime.now)
    last_activity: datetime = Field(default_factory=datetime.now)


class Message(BaseModel):
    """Individual message in a conversation."""

    id: str
    conversation_id: str
    content: str
    is_user: bool
    parent_id: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.now)


class ChatRequest(BaseModel):
    """Incoming chat request from Discord."""

    message: str
    conversation_id: Optional[str] = None
    parent_message_id: Optional[str] = None
    user_id: str
    channel_id: str


class ChatResponse(BaseModel):
    """Response to a chat request."""

    response: str
    conversation_id: str
    message_id: str
    parent_message_id: Optional[str] = None
    cited_rules: list[str] = Field(default_factory=list)
    confidence: float = Field(..., ge=0.0, le=1.0)
    needs_human: bool = Field(
        default=False,
        description="Whether the question needs human review (unanswered)",
    )


class RuleSearchResult(BaseModel):
    """Result from vector search."""

    rule_id: str
    title: str
    content: str
    section: str
    similarity: float = Field(..., ge=0.0, le=1.0)
166
app/vector_store.py
Normal file
@ -0,0 +1,166 @@
"""ChromaDB vector store for rule embeddings."""

from pathlib import Path
from typing import Optional

import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer

from .models import RuleDocument, RuleSearchResult


class VectorStore:
    """Wrapper around ChromaDB for rule retrieval."""

    def __init__(self, persist_dir: Path, embedding_model: str):
        """Initialize vector store with embedding model."""
        self.persist_dir = Path(persist_dir)
        self.persist_dir.mkdir(parents=True, exist_ok=True)

        chroma_settings = ChromaSettings(anonymized_telemetry=False)

        self.client = chromadb.PersistentClient(
            path=str(self.persist_dir), settings=chroma_settings
        )

        self.embedding_model = SentenceTransformer(embedding_model)

    def get_collection(self):
        """Get or create the rules collection."""
        return self.client.get_or_create_collection(
            name="rules", metadata={"hnsw:space": "cosine"}
        )

    def add_document(self, doc: RuleDocument) -> None:
        """Add a single rule document to the vector store."""
        embedding = self.embedding_model.encode(doc.content).tolist()

        collection = self.get_collection()
        collection.add(
            ids=[doc.metadata.rule_id],
            embeddings=[embedding],
            documents=[doc.content],
            metadatas=[doc.to_chroma_metadata()],
        )

    def add_documents(self, docs: list[RuleDocument]) -> None:
        """Add multiple documents in batch."""
        if not docs:
            return

        ids = [doc.metadata.rule_id for doc in docs]
        contents = [doc.content for doc in docs]
        embeddings = self.embedding_model.encode(contents).tolist()
        metadatas = [doc.to_chroma_metadata() for doc in docs]

        collection = self.get_collection()
        collection.add(
            ids=ids, embeddings=embeddings, documents=contents, metadatas=metadatas
        )

    def search(
        self, query: str, top_k: int = 10, section_filter: Optional[str] = None
    ) -> list[RuleSearchResult]:
        """Search for relevant rules using semantic similarity."""
        query_embedding = self.embedding_model.encode(query).tolist()

        collection = self.get_collection()

        where = None
        if section_filter:
            where = {"section": section_filter}

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        search_results = []
        if results and results["documents"] and results["documents"][0]:
            for i in range(len(results["documents"][0])):
                metadata = results["metadatas"][0][i]
                distance = results["distances"][0][i]
                # Convert cosine distance to similarity, clamped to [0, 1]
                # (cosine distance can exceed 1, which would violate the
                # RuleSearchResult bounds)
                similarity = max(0.0, 1.0 - distance)

                search_results.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=results["documents"][0][i],
                        section=metadata["section"],
                        similarity=similarity,
                    )
                )

        return search_results

    def delete_rule(self, rule_id: str) -> None:
        """Remove a rule by its ID."""
        collection = self.get_collection()
        collection.delete(ids=[rule_id])

    def clear_all(self) -> None:
        """Delete all rules from the collection."""
        self.client.delete_collection("rules")
        self.get_collection()  # Recreate empty collection

    def get_rule(self, rule_id: str) -> Optional[RuleSearchResult]:
        """Retrieve a specific rule by ID."""
        collection = self.get_collection()
        result = collection.get(ids=[rule_id], include=["documents", "metadatas"])

        if result and result["documents"] and result["documents"][0]:
            metadata = result["metadatas"][0][0]
            return RuleSearchResult(
                rule_id=metadata["rule_id"],
                title=metadata["title"],
                content=result["documents"][0][0],
                section=metadata["section"],
                similarity=1.0,
            )
        return None

    def list_all_rules(self) -> list[RuleSearchResult]:
        """Return all rules in the store."""
        collection = self.get_collection()
        result = collection.get(include=["documents", "metadatas"])

        all_rules = []
        if result and result["documents"]:
            for i in range(len(result["documents"])):
                metadata = result["metadatas"][i]
                all_rules.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=result["documents"][i],
                        section=metadata["section"],
                        similarity=1.0,
                    )
                )

        return all_rules

    def count(self) -> int:
        """Return the number of rules in the store."""
        collection = self.get_collection()
        return collection.count()

    def get_stats(self) -> dict:
        """Get statistics about the vector store."""
        all_rules = self.list_all_rules()
        sections = {}
        for rule in all_rules:
            sections[rule.section] = sections.get(rule.section, 0) + 1

        return {
            "total_rules": len(all_rules),
            "sections": sections,
            "persist_directory": str(self.persist_dir),
        }
20
data/rules/example_rule.md
Normal file
@ -0,0 +1,20 @@
---
rule_id: "5.2.1(b)"
title: "Stolen Base Attempts"
section: "Baserunning"
parent_rule: "5.2"
page_ref: "32"
---

When a runner attempts to steal a base:
1. Roll 2 six-sided dice.
2. Add the result to the runner's **Steal** rating.
3. Compare to the catcher's **Caught Stealing** (CS) column on the defensive chart.
4. If the total equals or exceeds the CS number, the runner is successful.

**Example**: Runner with SB-2 rolls a 7. Total = 7 + 2 = 9. Catcher's CS is 11. Since 9 < 11, the total falls short and the runner is caught stealing.

**Important notes**:
- Runners can only steal if they are on base and there are no outs.
- Do not attempt to steal when the pitcher has a **Pickoff** rating of 5 or higher.
- A failed steal results in an out and advances any other runners only if instructed by the result.
87
docker-compose.yml
Normal file
@ -0,0 +1,87 @@
version: '3.8'

services:
  chroma:
    image: chromadb/chroma:latest
    volumes:
      - ./data/chroma:/chroma/chroma_storage
    ports:
      - "8001:8000"
    environment:
      - CHROMA_SERVER_HOST=0.0.0.0
      - CHROMA_SERVER_PORT=8000
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/api/v1/heartbeat"]
      interval: 10s
      timeout: 5s
      retries: 5

  api:
    build:
      context: .
      dockerfile: Dockerfile
    volumes:
      - ./data:/app/data
      - ./app:/app/app
    ports:
      - "8000:8000"
    environment:
      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
      - OPENROUTER_MODEL=${OPENROUTER_MODEL:-stepfun/step-3.5-flash:free}
      - GITEA_TOKEN=${GITEA_TOKEN:-}
      - GITEA_OWNER=${GITEA_OWNER:-cal}
      - GITEA_REPO=${GITEA_REPO:-strat-chatbot}
      - DATA_DIR=/app/data
      - RULES_DIR=/app/data/rules
      - CHROMA_DIR=/app/data/chroma
      - DB_URL=sqlite+aiosqlite:///./data/conversations.db
      - CONVERSATION_TTL=1800
      - TOP_K_RULES=10
      - EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
    depends_on:
      chroma:
        condition: service_healthy
    # Initialize the database if it doesn't exist, then start the API.
    # No "#" comments inside the script: "command: >" folds newlines to
    # spaces, so an inline comment would swallow the rest of the command.
    command: >
      sh -c "
      sleep 2 &&
      python -c 'import asyncio; from app.database import ConversationManager; mgr = ConversationManager(\"sqlite+aiosqlite:///./data/conversations.db\"); asyncio.run(mgr.init_db())' || true &&
      uvicorn app.main:app --host 0.0.0.0 --port 8000
      "
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 15s
      timeout: 10s
      retries: 3
      start_period: 30s

  discord-bot:
    build:
      context: .
      dockerfile: Dockerfile
    volumes:
      - ./data:/app/data
      - ./app:/app/app
    environment:
      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
      - OPENROUTER_MODEL=${OPENROUTER_MODEL:-stepfun/step-3.5-flash:free}
      - DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
      - DISCORD_GUILD_ID=${DISCORD_GUILD_ID:-}
      - API_BASE_URL=http://api:8000
    depends_on:
      api:
        condition: service_healthy
    # Override the default command to run the Discord bot
    command: >
      sh -c "
      echo 'Waiting for API to be ready...' &&
      while ! curl -s http://api:8000/health > /dev/null; do sleep 2; done &&
      echo 'API ready, starting Discord bot...' &&
      python -m app.discord_bot
      "
    restart: unless-stopped

volumes:
  chroma_data:
  app_data:
44
pyproject.toml
Normal file
@ -0,0 +1,44 @@
[project]
name = "strat-chatbot"
version = "0.1.0"
description = "Strat-O-Matic rules Q&A chatbot"
requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.115.0",
    "uvicorn[standard]>=0.30.0",
    "discord.py>=2.5.0",
    "chromadb>=0.5.0",
    "sentence-transformers>=3.0.0",
    "openai>=1.0.0",
    "python-dotenv>=1.0.0",
    "sqlalchemy>=2.0.0",
    "aiosqlite>=0.20.0",
    "pydantic>=2.0.0",
    "pydantic-settings>=2.0.0",
    "httpx>=0.27.0",
    "pyyaml>=6.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-asyncio>=0.23.0",
    "black>=24.0.0",
    "ruff>=0.5.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.black]
line-length = 88
target-version = ['py311']

[tool.ruff]
line-length = 88
select = ["E", "F", "B", "I"]
target-version = "py311"

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
144
scripts/ingest_rules.py
Normal file
@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
Ingest rule documents from markdown files into ChromaDB.

The script reads all markdown files from the rules directory and adds them
to the vector store. Each file should have YAML frontmatter with metadata
fields matching RuleMetadata.

Example frontmatter:
---
rule_id: "5.2.1(b)"
title: "Stolen Base Attempts"
section: "Baserunning"
parent_rule: "5.2"
page_ref: "32"
---

Rule content here...
"""

import argparse
import re
import sys
from pathlib import Path
from typing import Optional

import yaml

from app.config import settings
from app.vector_store import VectorStore
from app.models import RuleDocument, RuleMetadata


def parse_frontmatter(content: str) -> tuple[dict, str]:
    """Parse YAML frontmatter from markdown content."""
    pattern = r"^---\s*\n(.*?)\n---\s*\n(.*)$"
    match = re.match(pattern, content, re.DOTALL)

    if match:
        frontmatter_str = match.group(1)
        body_content = match.group(2).strip()
        metadata = yaml.safe_load(frontmatter_str) or {}
        return metadata, body_content
    else:
        raise ValueError("No valid YAML frontmatter found")


def load_markdown_file(filepath: Path) -> Optional[RuleDocument]:
    """Load a single markdown file and convert it to a RuleDocument."""
    try:
        content = filepath.read_text(encoding="utf-8")
        metadata_dict, body = parse_frontmatter(content)

        # Validate and create metadata
        metadata = RuleMetadata(**metadata_dict)

        # Use the filename as the source reference
        source_file = str(filepath.relative_to(Path.cwd()))

        return RuleDocument(metadata=metadata, content=body, source_file=source_file)
    except Exception as e:
        print(f"Error loading {filepath}: {e}", file=sys.stderr)
        return None


def ingest_rules(
    rules_dir: Path, vector_store: VectorStore, clear_existing: bool = False
) -> None:
    """Ingest all markdown rule files into the vector store."""
    if not rules_dir.exists():
        print(f"Rules directory does not exist: {rules_dir}")
        sys.exit(1)

    if clear_existing:
        print("Clearing existing vector store...")
        vector_store.clear_all()

    # Find all markdown files
    md_files = list(rules_dir.rglob("*.md"))
    if not md_files:
        print(f"No markdown files found in {rules_dir}")
        sys.exit(1)

    print(f"Found {len(md_files)} markdown files to ingest")

    # Load and validate documents
    documents = []
    for filepath in md_files:
        doc = load_markdown_file(filepath)
        if doc:
            documents.append(doc)
            print(f"  Loaded: {doc.metadata.rule_id} - {doc.metadata.title}")

    print(f"Successfully loaded {len(documents)} documents")

    # Add to the vector store
    print("Adding to vector store (this may take a moment)...")
    vector_store.add_documents(documents)

    print("\nIngestion complete!")
    print(f"Total rules in store: {vector_store.count()}")
    stats = vector_store.get_stats()
    print("Sections:", ", ".join(f"{k}: {v}" for k, v in stats["sections"].items()))


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Ingest rule documents into ChromaDB")
    parser.add_argument(
        "--rules-dir",
        type=Path,
        default=settings.rules_dir,
        help="Directory containing markdown rule files",
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=settings.data_dir,
        help="Data directory (chroma will be stored in data/chroma)",
    )
    parser.add_argument(
        "--clear",
        action="store_true",
        help="Clear existing vector store before ingesting",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default=settings.embedding_model,
        help="Sentence transformer model name",
    )

    args = parser.parse_args()

    chroma_dir = args.data_dir / "chroma"
    print(f"Initializing vector store at: {chroma_dir}")
    print(f"Using embedding model: {args.embedding_model}")

    vector_store = VectorStore(chroma_dir, args.embedding_model)
    ingest_rules(args.rules_dir, vector_store, clear_existing=args.clear)


if __name__ == "__main__":
    main()
58
setup.sh
Executable file
@ -0,0 +1,58 @@
#!/usr/bin/env bash
# Setup script for Strat-Chatbot

set -e

echo "=== Strat-Chatbot Setup ==="

# Check for .env file
if [ ! -f .env ]; then
    echo "Creating .env from template..."
    cp .env.example .env
    echo "⚠️  Please edit .env and add your OpenRouter API key (and optionally Discord/Gitea keys)"
    exit 1
fi

# Create necessary directories
mkdir -p data/rules
mkdir -p data/chroma

# Check if uv is installed
if ! command -v uv &> /dev/null; then
    echo "Installing uv package manager..."
    curl -LsSf https://astral.sh/uv/install.sh | sh
    export PATH="$HOME/.local/bin:$PATH"
fi

# Install dependencies
echo "Installing Python dependencies..."
uv sync

# Initialize database
echo "Initializing database..."
uv run python -c "from app.database import ConversationManager; import asyncio; mgr = ConversationManager('sqlite+aiosqlite:///./data/conversations.db'); asyncio.run(mgr.init_db())"

# Check if rules exist
if ! ls data/rules/*.md 1> /dev/null 2>&1; then
    echo "⚠️  No rule files found in data/rules/"
    echo "   Please add your markdown rule files to data/rules/"
    exit 1
fi

# Ingest rules
echo "Ingesting rules into vector store..."
uv run python scripts/ingest_rules.py

echo "✅ Setup complete!"
echo ""
echo "Next steps:"
echo "1. Ensure your .env file has OPENROUTER_API_KEY set"
echo "2. (Optional) Set DISCORD_BOT_TOKEN to enable Discord bot"
echo "3. Start the API:"
echo "   uv run python -m app.main"
echo ""
echo "Or use Docker Compose:"
echo "  docker compose up -d"
echo ""
echo "API will be at: http://localhost:8000"
echo "Docs at: http://localhost:8000/docs"
63
tests/test_basic.py
Normal file
@@ -0,0 +1,63 @@
"""Basic test to verify the vector store and ingestion."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent / "app"))

from app.config import settings
from app.vector_store import VectorStore
from app.models import RuleDocument, RuleMetadata


def test_ingest_example_rule():
    """Test ingesting the example rule and searching."""
    # Override settings for test
    test_data_dir = Path(__file__).parent.parent / "data"
    test_chroma_dir = test_data_dir / "chroma_test"
    test_rules_dir = test_data_dir / "rules"

    vs = VectorStore(test_chroma_dir, settings.embedding_model)
    vs.clear_all()

    # Load example rule
    example_rule_path = test_rules_dir / "example_rule.md"
    if not example_rule_path.exists():
        print(f"Example rule not found at {example_rule_path}, skipping test")
        return

    content = example_rule_path.read_text(encoding="utf-8")
    import re
    import yaml

    pattern = r"^---\s*\n(.*?)\n---\s*\n(.*)$"
    match = re.match(pattern, content, re.DOTALL)
    if match:
        metadata_dict = yaml.safe_load(match.group(1))
        body = match.group(2).strip()
        metadata = RuleMetadata(**metadata_dict)
        doc = RuleDocument(
            metadata=metadata, content=body, source_file=str(example_rule_path)
        )
        vs.add_document(doc)

    # Verify count
    assert vs.count() == 1, f"Expected 1 rule, got {vs.count()}"

    # Search for relevant content
    results = vs.search("runner steal base", top_k=5)
    assert len(results) > 0, "Expected at least one search result"
    assert (
        results[0].rule_id == "5.2.1(b)"
    ), f"Expected rule 5.2.1(b), got {results[0].rule_id}"

    print("✓ Test passed: Ingestion and search work correctly")
    print(f"  Found rule: {results[0].title}")
    print(f"  Similarity: {results[0].similarity:.2%}")

    # Cleanup
    vs.clear_all()


if __name__ == "__main__":
    test_ingest_example_rule()
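For reference, the front-matter split that the test performs can be sketched standalone. The file contents, field names, and the flat key-value parse below are illustrative only; the actual test uses `yaml.safe_load` on real rule files.

```python
import re

# Hypothetical rule-file content; field names mirror those asserted in the test
sample = """---
rule_id: 5.2.1(b)
title: Stealing a Base
---
A runner may attempt to steal second base while the ball is live.
"""

# Same front-matter pattern used in tests/test_basic.py
pattern = r"^---\s*\n(.*?)\n---\s*\n(.*)$"
match = re.match(pattern, sample, re.DOTALL)
front_matter, body = match.group(1), match.group(2).strip()

# Flat "key: value" parse stands in for yaml.safe_load here
metadata = dict(line.split(": ", 1) for line in front_matter.splitlines())
print(metadata["rule_id"])  # 5.2.1(b)
```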