strat-chatbot/app/main.py
Cal Corum c42fea66ba feat: initial chatbot implementation with FastAPI, ChromaDB, Discord bot, and Gitea integration
- Add vector store with sentence-transformers for semantic search
- FastAPI backend with /chat and /health endpoints
- Conversation state persistence via SQLite
- OpenRouter integration with structured JSON responses
- Discord bot with /ask slash command and reply-based follow-ups
- Automated Gitea issue creation for unanswered questions
- Docker support with docker-compose for easy deployment
- Example rule file and ingestion script
- Comprehensive documentation in README
2026-03-08 15:19:26 -05:00

199 lines
6.2 KiB
Python

"""FastAPI application for Strat-O-Matic rules chatbot."""
from contextlib import asynccontextmanager
from typing import Optional
import uuid
from fastapi import FastAPI, HTTPException, Depends
import uvicorn
import sqlalchemy as sa
from .config import settings
from .models import ChatRequest, ChatResponse
from .vector_store import VectorStore
from .database import ConversationManager, get_conversation_manager
from .llm import get_llm_client
from .gitea import GiteaClient
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build shared services at startup, log at shutdown."""
    print("Initializing Strat-Chatbot...")

    # Semantic-search index over the ingested rulebook.
    chroma_path = settings.data_dir / "chroma"
    store = VectorStore(chroma_path, settings.embedding_model)
    print(f"Vector store ready at {chroma_path} ({store.count()} rules loaded)")

    # Conversation persistence layer (SQLite via SQLAlchemy).
    conversations = ConversationManager(settings.db_url)
    await conversations.init_db()
    print("Database initialized")

    # Fall back to a mock client when no OpenRouter key is configured.
    llm = get_llm_client(use_mock=not settings.openrouter_api_key)
    print(f"LLM client ready (model: {settings.openrouter_model})")

    # Gitea escalation is optional; enabled only when a token is present.
    gitea = GiteaClient() if settings.gitea_token else None

    # Expose the shared services to request handlers via app.state.
    app.state.vector_store = store
    app.state.db_manager = conversations
    app.state.llm_client = llm
    app.state.gitea_client = gitea
    print("Strat-Chatbot ready!")

    yield

    print("Shutting down...")
# Module-level FastAPI application; `lifespan` wires shared services into
# app.state on startup (vector store, DB manager, LLM and Gitea clients).
app = FastAPI(
    title="Strat-Chatbot",
    description="Strat-O-Matic rules Q&A API",
    version="0.1.0",
    lifespan=lifespan,
)
@app.get("/health")
async def health_check():
    """Report service liveness plus basic knowledge-base statistics."""
    kb_stats = app.state.vector_store.get_stats()
    payload = {
        "status": "healthy",
        "rules_count": kb_stats["total_rules"],
        "sections": kb_stats["sections"],
    }
    return payload
@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    db_manager: ConversationManager = Depends(get_conversation_manager),
):
    """Handle chat requests from Discord.

    Flow: validate LLM configuration -> persist the user message -> retrieve
    relevant rules -> generate an answer -> persist the answer -> optionally
    escalate to a Gitea issue when the model is unsure.

    Raises:
        HTTPException: 500 when retrieval, generation, or persistence fails.
    """
    vector_store: VectorStore = app.state.vector_store
    llm_client = app.state.llm_client
    gitea_client = app.state.gitea_client

    # Fail fast with a user-visible notice when no real LLM is configured.
    if not settings.openrouter_api_key:
        return ChatResponse(
            response="⚠️ OpenRouter API key not configured. Set OPENROUTER_API_KEY environment variable.",
            conversation_id=request.conversation_id or str(uuid.uuid4()),
            message_id=str(uuid.uuid4()),
            cited_rules=[],
            confidence=0.0,
            needs_human=True,
        )

    # Resume an existing conversation when an id is supplied, else start one.
    conversation_id = await db_manager.get_or_create_conversation(
        user_id=request.user_id,
        channel_id=request.channel_id,
        conversation_id=request.conversation_id,
    )

    # Persist the user's message before generation so it survives LLM failures.
    user_message_id = await db_manager.add_message(
        conversation_id=conversation_id,
        content=request.message,
        is_user=True,
        parent_id=request.parent_message_id,
    )

    try:
        # Semantic search for the rules most relevant to the question.
        search_results = vector_store.search(
            query=request.message, top_k=settings.top_k_rules
        )
        # Recent turns give the LLM context for reply-based follow-ups.
        history = await db_manager.get_conversation_history(conversation_id, limit=10)
        response = await llm_client.generate_response(
            question=request.message, rules=search_results, conversation_history=history
        )
        # Persist the assistant's reply, threaded under the user's message.
        assistant_message_id = await db_manager.add_message(
            conversation_id=conversation_id,
            content=response.response,
            is_user=False,
            parent_id=user_message_id,
        )
        # Escalate low-confidence or explicitly-flagged answers to Gitea.
        # Issue creation is best-effort: a failure must not break the reply.
        if gitea_client and (response.needs_human or response.confidence < 0.4):
            try:
                issue_url = await gitea_client.create_unanswered_issue(
                    question=request.message,
                    user_id=request.user_id,
                    channel_id=request.channel_id,
                    attempted_rules=[r.rule_id for r in search_results],
                    conversation_id=conversation_id,
                )
                print(f"Created Gitea issue: {issue_url}")
            except Exception as e:
                print(f"Failed to create Gitea issue: {e}")
        return ChatResponse(
            response=response.response,
            conversation_id=conversation_id,
            message_id=assistant_message_id,
            parent_message_id=user_message_id,
            cited_rules=response.cited_rules,
            confidence=response.confidence,
            needs_human=response.needs_human,
        )
    except Exception as e:
        print(f"Error processing chat request: {e}")
        # Chain the original exception (`from e`) so server tracebacks show
        # the root cause instead of an opaque 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/stats")
async def stats():
    """Get statistics about the knowledge base and system."""
    store: VectorStore = app.state.vector_store
    manager: ConversationManager = app.state.db_manager

    # Knowledge-base figures come straight from the vector store.
    kb_stats = store.get_stats()

    # Raw row counts from the conversation tables.
    async with manager.async_session() as session:
        conv_result = await session.execute(
            sa.text("SELECT COUNT(*) FROM conversations")
        )
        msg_result = await session.execute(sa.text("SELECT COUNT(*) FROM messages"))
        conversation_total = conv_result.scalar() or 0
        message_total = msg_result.scalar() or 0

    return {
        "knowledge_base": kb_stats,
        "conversations": {
            "total": conversation_total,
            "total_messages": message_total,
        },
        "config": {
            "openrouter_model": settings.openrouter_model,
            "top_k_rules": settings.top_k_rules,
            "embedding_model": settings.embedding_model,
        },
    }
if __name__ == "__main__":
    # Dev entry point: auto-reload is enabled, so this is not suitable for
    # production — run via uvicorn/docker-compose there instead.
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)