"""SQLAlchemy-based conversation state management with aiosqlite."""

import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

import sqlalchemy as sa
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker

from .config import settings

logger = logging.getLogger(__name__)

Base = declarative_base()


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The time is taken timezone-aware and the tzinfo is stripped,
    so the columns keep storing naive UTC values exactly as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class ConversationTable(Base):
    """SQLAlchemy model for conversations."""

    __tablename__ = "conversations"

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    channel_id = Column(String, nullable=False)
    created_at = Column(DateTime, default=_utcnow)
    # Set on insert and refreshed by the ORM on every UPDATE of the row;
    # cleanup_old_conversations() expires conversations based on this column.
    last_activity = Column(DateTime, default=_utcnow, onupdate=_utcnow)


class MessageTable(Base):
    """SQLAlchemy model for messages."""

    __tablename__ = "messages"

    id = Column(String, primary_key=True)
    conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False)
    content = Column(String, nullable=False)
    # True for user messages, False for assistant replies (mapped to the
    # "user"/"assistant" roles in get_conversation_history()).
    is_user = Column(Boolean, nullable=False)
    # Optional self-referential link to another message.
    parent_id = Column(String, ForeignKey("messages.id"), nullable=True)
    created_at = Column(DateTime, default=_utcnow)


class ConversationManager:
    """Manages conversation state in SQLite."""

    def __init__(self, db_url: str):
        """Initialize database engine and session factory.

        Args:
            db_url: SQLAlchemy async database URL (this module targets
                aiosqlite).
        """
        self.engine = create_async_engine(db_url, echo=False)
        # expire_on_commit=False keeps loaded attributes readable after
        # commit; get_or_create_conversation() returns ``conv.id`` post-commit.
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )

    async def init_db(self) -> None:
        """Create tables if they don't exist."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def close(self) -> None:
        """Dispose of the engine's connection pool.

        The engine remains usable afterwards; a new pool is created lazily on
        the next checkout.
        """
        await self.engine.dispose()

    async def get_or_create_conversation(
        self, user_id: str, channel_id: str, conversation_id: Optional[str] = None
    ) -> str:
        """Get existing conversation or create a new one.

        If ``conversation_id`` is given and names an existing conversation,
        that conversation's ``last_activity`` is refreshed and its id is
        returned. Otherwise (no id, empty id, or unknown id) a new
        conversation with a fresh UUID is created for ``user_id`` /
        ``channel_id``, and the new id is returned. The caller's unknown id is
        not reused.

        NOTE(review): an existing conversation is returned without checking
        that its ``user_id``/``channel_id`` match the arguments. Confirm that
        callers only pass ids the requesting user owns.

        Returns:
            The id of the existing or newly created conversation.
        """
        async with self.async_session() as session:
            if conversation_id:
                result = await session.execute(
                    select(ConversationTable).where(
                        ConversationTable.id == conversation_id
                    )
                )
                conv = result.scalar_one_or_none()
                if conv is not None:
                    conv.last_activity = _utcnow()
                    await session.commit()
                    return conv.id

            # No usable existing conversation: create a new one.
            new_id = str(uuid.uuid4())
            session.add(
                ConversationTable(id=new_id, user_id=user_id, channel_id=channel_id)
            )
            await session.commit()
            return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Add a message to a conversation.

        Inserts the message and, if the conversation exists, refreshes its
        ``last_activity``. Both happen in the same transaction.

        NOTE(review): there is no explicit existence check for
        ``conversation_id``. Whether an unknown id is rejected depends on
        SQLite foreign-key enforcement (``PRAGMA foreign_keys``), which
        nothing in this file enables.

        Returns:
            The id (a new UUID string) of the inserted message.
        """
        message_id = str(uuid.uuid4())
        async with self.async_session() as session:
            session.add(
                MessageTable(
                    id=message_id,
                    conversation_id=conversation_id,
                    content=content,
                    is_user=is_user,
                    parent_id=parent_id,
                )
            )

            # Update conversation activity
            result = await session.execute(
                select(ConversationTable).where(ConversationTable.id == conversation_id)
            )
            conv = result.scalar_one_or_none()
            if conv is not None:
                conv.last_activity = _utcnow()

            await session.commit()
        return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict]:
        """Get recent messages from a conversation in OpenAI format.

        Fetches the ``limit`` most recent messages (by ``created_at``). They
        are returned oldest-first as ``{"role": ..., "content": ...}`` dicts,
        where role is ``"user"`` or ``"assistant"`` depending on ``is_user``.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(MessageTable)
                .where(MessageTable.conversation_id == conversation_id)
                .order_by(MessageTable.created_at.desc())
                .limit(limit)
            )
            messages = result.scalars().all()

            # The query yields newest-first; reverse into chronological order.
            return [
                {
                    "role": "user" if msg.is_user else "assistant",
                    "content": msg.content,
                }
                for msg in reversed(messages)
            ]

    async def cleanup_old_conversations(self, ttl_seconds: int = 1800) -> int:
        """Delete conversations older than TTL to free up storage.

        A conversation is expired when its ``last_activity`` is more than
        ``ttl_seconds`` in the past. Its messages are deleted with it.

        Returns:
            The number of conversations deleted. Earlier versions returned
            ``None``, so callers that ignore the result are unaffected.
        """
        cutoff = _utcnow() - timedelta(seconds=ttl_seconds)
        expired_ids = select(ConversationTable.id).where(
            ConversationTable.last_activity < cutoff
        )
        async with self.async_session() as session:
            # Delete messages first. The schema declares no ON DELETE CASCADE,
            # so deleting only the conversations would orphan their messages
            # (or violate the FK when SQLite enforcement is on). The subquery
            # avoids binding one SQL parameter per id, which can exceed
            # SQLite's bound-variable limit.
            await session.execute(
                sa.delete(MessageTable).where(
                    MessageTable.conversation_id.in_(expired_ids)
                )
            )
            result = await session.execute(
                sa.delete(ConversationTable).where(
                    ConversationTable.last_activity < cutoff
                )
            )
            deleted = result.rowcount or 0
            await session.commit()

        if deleted:
            logger.info("Cleaned up %d old conversations", deleted)
        return deleted


# One manager per database URL, so requests share a single engine and
# connection pool instead of creating a new, never-disposed one per call.
_managers: dict[str, ConversationManager] = {}


async def get_conversation_manager() -> ConversationManager:
    """Dependency for FastAPI to get a ConversationManager instance.

    Managers are cached per ``settings.db_url``. The lookup-and-insert below
    contains no ``await``, so concurrent calls on one event loop cannot
    create duplicates.

    NOTE(review): the cached engine's connections are used from the event
    loop that first checks them out. Confirm that tests which run separate
    event loops call ``close()`` (or clear ``_managers``) between loops.
    """
    db_url = settings.db_url
    manager = _managers.get(db_url)
    if manager is None:
        manager = _managers[db_url] = ConversationManager(db_url)
    return manager