"""SQLite outbound adapter implementing the ConversationStore port.

Uses SQLAlchemy 2.x async API with aiosqlite as the driver. Designed to be
instantiated once at application startup; call ``await init_db()`` before use.
"""

import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

import sqlalchemy as sa
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, select
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import declarative_base

from domain.ports import ConversationStore

logger = logging.getLogger(__name__)

Base = declarative_base()


# ---------------------------------------------------------------------------
# ORM table definitions
# ---------------------------------------------------------------------------
class _ConversationRow(Base):
    """SQLAlchemy table model for a conversation session."""

    __tablename__ = "conversations"

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    channel_id = Column(String, nullable=False)
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        # Callable default: "now" is evaluated per row at insert time,
        # not once at import time.
        default=lambda: datetime.now(timezone.utc),
    )
    # Indexed: cleanup_old_conversations filters on this column; without an
    # index every cleanup pass scans the whole table.
    # NOTE: create_all() only builds the index for freshly created tables;
    # a pre-existing database needs a manual migration to pick it up.
    last_activity = Column(
        DateTime(timezone=True),
        nullable=False,
        index=True,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )


class _MessageRow(Base):
    """SQLAlchemy table model for a single chat message."""

    __tablename__ = "messages"

    id = Column(String, primary_key=True)
    # Indexed: get_conversation_history filters on conversation_id, which
    # would otherwise degrade to a full table scan as messages accumulate.
    conversation_id = Column(
        String, ForeignKey("conversations.id"), nullable=False, index=True
    )
    content = Column(String, nullable=False)
    # True for a user-authored message, False for an assistant reply.
    is_user = Column(Boolean, nullable=False)
    # Optional self-reference for threaded replies; NULL for top-level messages.
    parent_id = Column(String, ForeignKey("messages.id"), nullable=True)
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )


# ---------------------------------------------------------------------------
# Adapter
# ---------------------------------------------------------------------------
class SQLiteConversationStore(ConversationStore):
    """Persists conversation state to a SQLite database via SQLAlchemy async.

    Parameters
    ----------
    db_url:
        SQLAlchemy async connection URL, e.g.
        ``"sqlite+aiosqlite:///path/to/conversations.db"`` or
        ``"sqlite+aiosqlite://"`` for an in-memory database.
    """

    def __init__(self, db_url: str) -> None:
        self._engine = create_async_engine(db_url, echo=False)
        # async_sessionmaker is the modern (SQLAlchemy 2.0) replacement for
        # sessionmaker(class_=AsyncSession, ...). expire_on_commit=False lets
        # us keep reading row attributes after commit without a re-fetch.
        self._session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker(
            self._engine, expire_on_commit=False
        )

    async def init_db(self) -> None:
        """Create database tables if they do not already exist.

        Must be called before any other method.
        """
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.debug("Database tables initialised")

    # ------------------------------------------------------------------
    # ConversationStore implementation
    # ------------------------------------------------------------------

    async def get_or_create_conversation(
        self,
        user_id: str,
        channel_id: str,
        conversation_id: Optional[str] = None,
    ) -> str:
        """Return *conversation_id* if it exists in the DB, otherwise create
        a new conversation row and return its fresh ID.

        If *conversation_id* is supplied but not found (e.g. after a restart
        with a clean in-memory DB), a new conversation is created
        transparently rather than raising an error.
        """
        async with self._session_factory() as session:
            if conversation_id:
                result = await session.execute(
                    select(_ConversationRow).where(
                        _ConversationRow.id == conversation_id
                    )
                )
                row = result.scalar_one_or_none()
                if row is not None:
                    # Touch the conversation so TTL-based cleanup keeps it.
                    row.last_activity = datetime.now(timezone.utc)
                    await session.commit()
                    logger.debug(
                        "Resumed existing conversation %s for user %s",
                        conversation_id,
                        user_id,
                    )
                    return row.id
                logger.warning(
                    "Conversation %s not found; creating a new one",
                    conversation_id,
                )

            new_id = str(uuid.uuid4())
            session.add(
                _ConversationRow(
                    id=new_id,
                    user_id=user_id,
                    channel_id=channel_id,
                    created_at=datetime.now(timezone.utc),
                    last_activity=datetime.now(timezone.utc),
                )
            )
            await session.commit()
            logger.debug(
                "Created conversation %s for user %s in channel %s",
                new_id,
                user_id,
                channel_id,
            )
            return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Append a message to *conversation_id* and update last_activity.

        Returns the new message's UUID string.
        """
        message_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        async with self._session_factory() as session:
            session.add(
                _MessageRow(
                    id=message_id,
                    conversation_id=conversation_id,
                    content=content,
                    is_user=is_user,
                    parent_id=parent_id,
                    created_at=now,
                )
            )
            # Bump last_activity on the parent conversation. A missing parent
            # is deliberately tolerated (best-effort): SQLite does not enforce
            # foreign keys unless PRAGMA foreign_keys=ON, so the insert above
            # succeeds and we merely log the orphan.
            result = await session.execute(
                select(_ConversationRow).where(
                    _ConversationRow.id == conversation_id
                )
            )
            conv = result.scalar_one_or_none()
            if conv is not None:
                conv.last_activity = now
            else:
                logger.warning(
                    "add_message: conversation %s not found; message stored orphaned",
                    conversation_id,
                )
            await session.commit()
        logger.debug(
            "Added message %s to conversation %s (is_user=%s)",
            message_id,
            conversation_id,
            is_user,
        )
        return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict[str, str]]:
        """Return the most-recent *limit* messages in chronological order.

        The query fetches the newest rows (ORDER BY created_at DESC LIMIT n)
        then reverses the list so callers receive oldest-first ordering,
        which is what LLM APIs expect in the ``messages`` array.

        Returns a list of ``{"role": "user"|"assistant", "content": "..."}``
        dicts compatible with the OpenAI chat-completion format.
        """
        async with self._session_factory() as session:
            result = await session.execute(
                select(_MessageRow)
                .where(_MessageRow.conversation_id == conversation_id)
                # Secondary sort on id breaks created_at ties (e.g. a
                # user/assistant pair written within the same clock tick),
                # making the returned order deterministic across calls.
                .order_by(_MessageRow.created_at.desc(), _MessageRow.id.desc())
                .limit(limit)
            )
            rows = result.scalars().all()

        # Reverse so the list is oldest -> newest (chronological).
        history: list[dict[str, str]] = [
            {
                "role": "user" if row.is_user else "assistant",
                "content": row.content,
            }
            for row in reversed(rows)
        ]
        logger.debug(
            "Retrieved %d messages for conversation %s (limit=%d)",
            len(history),
            conversation_id,
            limit,
        )
        return history

    # ------------------------------------------------------------------
    # Housekeeping
    # ------------------------------------------------------------------

    async def cleanup_old_conversations(self, ttl_seconds: int = 1800) -> None:
        """Delete conversations (and their messages) older than *ttl_seconds*.

        Useful as a periodic background task to keep the database small.
        Messages are deleted first to satisfy the foreign-key constraint
        even in databases where cascade deletes are not configured.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=ttl_seconds)
        async with self._session_factory() as session:
            # Fetch only the ids — full ORM rows are not needed here.
            result = await session.execute(
                select(_ConversationRow.id).where(
                    _ConversationRow.last_activity < cutoff
                )
            )
            conv_ids = list(result.scalars().all())
            if not conv_ids:
                logger.debug("cleanup_old_conversations: nothing to remove")
                return

            # Children first, parents second (FK-safe even without cascades).
            await session.execute(
                sa.delete(_MessageRow).where(
                    _MessageRow.conversation_id.in_(conv_ids)
                )
            )
            await session.execute(
                sa.delete(_ConversationRow).where(
                    _ConversationRow.id.in_(conv_ids)
                )
            )
            await session.commit()
        logger.info(
            "Cleaned up %d stale conversations (TTL=%ds)", len(conv_ids), ttl_seconds
        )