strat-chatbot/app/database.py
Cal Corum c42fea66ba feat: initial chatbot implementation with FastAPI, ChromaDB, Discord bot, and Gitea integration
- Add vector store with sentence-transformers for semantic search
- FastAPI backend with /chat and /health endpoints
- Conversation state persistence via SQLite
- OpenRouter integration with structured JSON responses
- Discord bot with /ask slash command and reply-based follow-ups
- Automated Gitea issue creation for unanswered questions
- Docker support with docker-compose for easy deployment
- Example rule file and ingestion script
- Comprehensive documentation in README
2026-03-08 15:19:26 -05:00

162 lines
5.9 KiB
Python

"""SQLAlchemy-based conversation state management with aiosqlite."""
import logging
import uuid
from datetime import datetime, timedelta
from typing import Optional

import sqlalchemy as sa
from sqlalchemy import Column, String, DateTime, Boolean, ForeignKey, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base

from .config import settings
# Declarative base shared by the ORM models in this module; its metadata
# (Base.metadata) is what ConversationManager.init_db passes to create_all.
Base = declarative_base()
class ConversationTable(Base):
    """ORM mapping for the ``conversations`` table (one row per chat thread)."""

    __tablename__ = "conversations"

    # UUID string assigned by the application, not by the database.
    id = sa.Column(sa.String, primary_key=True)
    # Owner identifiers recorded when the conversation row is created.
    user_id = sa.Column(sa.String, nullable=False)
    channel_id = sa.Column(sa.String, nullable=False)
    # Timestamps are naive UTC values produced by datetime.utcnow.
    created_at = sa.Column(sa.DateTime, default=datetime.utcnow)
    # onupdate refreshes this whenever the row is UPDATEd without an
    # explicit value for the column.
    last_activity = sa.Column(
        sa.DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )
class MessageTable(Base):
    """ORM mapping for the ``messages`` table (one row per chat message).

    Each message belongs to a conversation via ``conversation_id`` and may
    reference the message it replies to via ``parent_id``.
    """

    __tablename__ = "messages"

    # UUID string assigned by the application.
    id = Column(String, primary_key=True)
    # Indexed because history lookups and TTL cleanup both filter messages by
    # conversation. Note: create_all only builds this index when it creates
    # the table; an existing database needs a manual CREATE INDEX.
    conversation_id = Column(
        String, ForeignKey("conversations.id"), nullable=False, index=True
    )
    content = Column(String, nullable=False)
    # True for user-authored messages, False for assistant replies.
    is_user = Column(Boolean, nullable=False)
    # Optional self-reference to the message this one replies to.
    parent_id = Column(String, ForeignKey("messages.id"), nullable=True)
    # Naive UTC timestamp; used to order conversation history.
    created_at = Column(DateTime, default=datetime.utcnow)
logger = logging.getLogger(__name__)


class ConversationManager:
    """Manage conversation state (conversations and their messages) in SQLite.

    Every public coroutine opens its own short-lived ``AsyncSession`` from the
    factory built in ``__init__`` and commits its work before returning.
    """

    def __init__(self, db_url: str):
        """Create the async engine and session factory.

        Args:
            db_url: SQLAlchemy async database URL (presumably an
                ``sqlite+aiosqlite://`` URL — confirm against settings).
        """
        self.engine = create_async_engine(db_url, echo=False)
        # expire_on_commit=False keeps loaded attributes readable after
        # commit (get_or_create_conversation returns conv.id post-commit).
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )

    async def init_db(self) -> None:
        """Create every table registered on ``Base.metadata`` that is missing."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def close(self) -> None:
        """Release the connections pooled by the engine (call on shutdown)."""
        await self.engine.dispose()

    async def get_or_create_conversation(
        self, user_id: str, channel_id: str, conversation_id: Optional[str] = None
    ) -> str:
        """Return the id of an existing conversation, or create a new one.

        If *conversation_id* names an existing row, its ``last_activity`` is
        bumped and that id is returned. Otherwise (no id given, or the id is
        unknown, e.g. removed by ``cleanup_old_conversations``) a new
        conversation is created for *user_id* / *channel_id*.

        Returns:
            The conversation id (a UUID string for newly created rows).
        """
        async with self.async_session() as session:
            if conversation_id:
                conv = await session.get(ConversationTable, conversation_id)
                if conv is not None:
                    conv.last_activity = datetime.utcnow()
                    await session.commit()
                    return conv.id

            new_id = str(uuid.uuid4())
            session.add(
                ConversationTable(id=new_id, user_id=user_id, channel_id=channel_id)
            )
            await session.commit()
            return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Store a message and bump the owning conversation's activity time.

        Args:
            conversation_id: Id of the conversation the message belongs to.
            content: Message text.
            is_user: True for a user message, False for an assistant reply.
            parent_id: Optional id of the message this one replies to.

        Returns:
            The new message's UUID string.
        """
        message_id = str(uuid.uuid4())
        async with self.async_session() as session:
            session.add(
                MessageTable(
                    id=message_id,
                    conversation_id=conversation_id,
                    content=content,
                    is_user=is_user,
                    parent_id=parent_id,
                )
            )
            # One UPDATE statement instead of SELECT-then-modify; it is a
            # no-op when no conversation with this id exists.
            await session.execute(
                sa.update(ConversationTable)
                .where(ConversationTable.id == conversation_id)
                .values(last_activity=datetime.utcnow())
                .execution_options(synchronize_session=False)
            )
            await session.commit()
        return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict]:
        """Return the newest messages of a conversation in OpenAI chat format.

        Args:
            conversation_id: Conversation to read.
            limit: Maximum number of (most recent) messages to return.
                Non-positive values yield an empty list.

        Returns:
            ``[{"role": "user" | "assistant", "content": str}, ...]`` in
            chronological (oldest-first) order.
        """
        if limit <= 0:
            # SQLite treats a negative LIMIT as "no upper bound", which would
            # silently return the entire conversation.
            return []
        async with self.async_session() as session:
            result = await session.execute(
                select(MessageTable)
                .where(MessageTable.conversation_id == conversation_id)
                .order_by(MessageTable.created_at.desc())
                .limit(limit)
            )
            newest_first = result.scalars().all()
            # Fetched newest-first so LIMIT keeps the latest messages;
            # reverse back into chronological order for the prompt.
            return [
                {
                    "role": "user" if msg.is_user else "assistant",
                    "content": msg.content,
                }
                for msg in reversed(newest_first)
            ]

    async def cleanup_old_conversations(self, ttl_seconds: int = 1800) -> int:
        """Delete conversations idle longer than the TTL, with their messages.

        Args:
            ttl_seconds: Conversations whose ``last_activity`` is older than
                this many seconds ago are removed.

        Returns:
            Number of conversations deleted.
        """
        cutoff = datetime.utcnow() - timedelta(seconds=ttl_seconds)
        stale = ConversationTable.last_activity < cutoff
        async with self.async_session() as session:
            # messages.conversation_id declares no ON DELETE CASCADE, so the
            # children are deleted explicitly first. A subquery (instead of
            # IN over a fetched id list) keeps the statement size constant,
            # so a large backlog cannot exceed SQLite's bound-parameter limit.
            # Nothing is loaded into this session, so no ORM sync is needed.
            await session.execute(
                sa.delete(MessageTable)
                .where(
                    MessageTable.conversation_id.in_(
                        select(ConversationTable.id).where(stale)
                    )
                )
                .execution_options(synchronize_session=False)
            )
            result = await session.execute(
                sa.delete(ConversationTable)
                .where(stale)
                .execution_options(synchronize_session=False)
            )
            await session.commit()
        deleted = result.rowcount or 0
        if deleted:
            logger.info("Cleaned up %d old conversations", deleted)
        return deleted
# Process-wide manager, created lazily by get_conversation_manager().
_conversation_manager: Optional[ConversationManager] = None


async def get_conversation_manager() -> ConversationManager:
    """FastAPI dependency returning the shared ConversationManager.

    The manager, and with it the engine and connection pool, is created on
    first use and reused afterwards. Constructing a new manager per request
    would create a new engine each time whose pool is never disposed.
    """
    global _conversation_manager
    # There is no await between the check and the assignment, so concurrent
    # coroutines on one event loop cannot create two managers.
    if _conversation_manager is None:
        _conversation_manager = ConversationManager(settings.db_url)
    return _conversation_manager