strat-chatbot/app/database.py
Cal Corum c2c7f7d3c2 fix: resolve 4 critical bugs found in code review
- Discord bot: store full conversation UUID in footer instead of truncated
  8-char prefix, fixing completely broken follow-up threading. Add footer
  to follow-up embeds so conversation chains work beyond depth 1. Edit
  loading message in-place instead of leaving ghost messages. Replace bare
  except with specific exception types. Fix channel_id attribute access.
- GiteaClient: remove broken async context manager pattern that caused
  every create_unanswered_issue call to raise RuntimeError. Use per-request
  httpx.AsyncClient instead.
- Database: return singleton ConversationManager from app.state instead of
  creating a new SQLAlchemy engine (and connection pool) on every request.
- Vector store: clamp cosine similarity to [0, 1] to prevent Pydantic
  ValidationError crashes when ChromaDB returns distances > 1.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 15:31:11 -05:00

163 lines
5.9 KiB
Python

"""SQLAlchemy-based conversation state management with aiosqlite."""
from datetime import datetime, timedelta
from typing import Optional
import uuid
import sqlalchemy as sa
from fastapi import Request
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, DateTime, Boolean, ForeignKey, select
from .config import settings
# Declarative base shared by the ORM models below; its ``metadata`` is what
# ConversationManager.init_db() hands to create_all() to build the tables.
Base = declarative_base()
class ConversationTable(Base):
    """ORM mapping for the ``conversations`` table.

    One row per conversation thread, keyed by a string ID and tagged with the
    user and channel it was created for.
    """

    __tablename__ = "conversations"

    id = sa.Column(sa.String, primary_key=True)
    user_id = sa.Column(sa.String, nullable=False)
    channel_id = sa.Column(sa.String, nullable=False)
    # Both timestamps default to naive UTC "now" on insert; last_activity is
    # additionally refreshed on UPDATE statements that don't set it explicitly.
    created_at = sa.Column(sa.DateTime, default=datetime.utcnow)
    last_activity = sa.Column(
        sa.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
    )
class MessageTable(Base):
    """SQLAlchemy model for messages.

    Each row is one turn of a conversation, authored either by the user
    (``is_user=True``) or by the assistant (``is_user=False``).
    """

    __tablename__ = "messages"

    id = Column(String, primary_key=True)
    # Indexed: history lookups and TTL cleanup both filter on this column, and
    # SQLite does not create indexes for foreign-key columns automatically.
    # create_all() only builds the index when it creates the table, so existing
    # databases keep working unchanged (just without the index).
    conversation_id = Column(
        String, ForeignKey("conversations.id"), nullable=False, index=True
    )
    content = Column(String, nullable=False)
    # True for user-authored messages, False for assistant replies.
    is_user = Column(Boolean, nullable=False)
    # Self-referential link; presumably the message this one replies to.
    parent_id = Column(String, ForeignKey("messages.id"), nullable=True)
    # Naive UTC timestamp assigned on insert; used to order history.
    created_at = Column(DateTime, default=datetime.utcnow)
class ConversationManager:
    """Manages conversation state in SQLite.

    One instance owns a single async SQLAlchemy engine (and its connection
    pool); every public coroutine opens its own short-lived session.
    """

    def __init__(self, db_url: str):
        """Initialize database engine and session factory.

        Args:
            db_url: SQLAlchemy async database URL (e.g. an aiosqlite URL).
        """
        self.engine = create_async_engine(db_url, echo=False)
        # expire_on_commit=False keeps loaded attributes readable after commit;
        # get_or_create_conversation reads conv.id after committing, which would
        # otherwise trigger an implicit (and, under asyncio, illegal) refresh.
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )

    async def init_db(self):
        """Create tables if they don't exist.

        Existing tables are left as they are; create_all() does not add new
        columns or indexes to tables that already exist.
        """
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def get_or_create_conversation(
        self, user_id: str, channel_id: str, conversation_id: Optional[str] = None
    ) -> str:
        """Get existing conversation or create a new one.

        Args:
            user_id: Owner recorded on a newly created conversation.
            channel_id: Channel recorded on a newly created conversation.
            conversation_id: Conversation to resume. If it is falsy or no row
                with that ID exists, a brand-new conversation is created instead.

        Returns:
            The ID of the resumed conversation (whose ``last_activity`` is
            bumped) or of the newly created one.
        """
        async with self.async_session() as session:
            if conversation_id:
                # NOTE(review): user_id/channel_id are not compared against the
                # stored row when resuming an existing conversation.
                conv = await session.get(ConversationTable, conversation_id)
                if conv:
                    conv.last_activity = datetime.utcnow()
                    await session.commit()
                    return conv.id

            new_id = str(uuid.uuid4())
            session.add(
                ConversationTable(id=new_id, user_id=user_id, channel_id=channel_id)
            )
            await session.commit()
            return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Add a message to a conversation.

        Args:
            conversation_id: Conversation the message belongs to.
            content: Message text.
            is_user: True for a user message, False for an assistant reply.
            parent_id: Optional ID of a related earlier message.

        Returns:
            The generated ID of the new message.
        """
        message_id = str(uuid.uuid4())
        async with self.async_session() as session:
            session.add(
                MessageTable(
                    id=message_id,
                    conversation_id=conversation_id,
                    content=content,
                    is_user=is_user,
                    parent_id=parent_id,
                )
            )
            # Bump the conversation's activity so TTL cleanup keeps it alive.
            # NOTE(review): a missing conversation is not treated as an error;
            # the message is still inserted unless the connection enforces
            # foreign keys.
            conv = await session.get(ConversationTable, conversation_id)
            if conv:
                conv.last_activity = datetime.utcnow()
            await session.commit()
        return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict]:
        """Get recent messages from a conversation in OpenAI format.

        Args:
            conversation_id: Conversation to read.
            limit: Maximum number of most recent messages to return.

        Returns:
            ``[{"role": "user" | "assistant", "content": str}, ...]`` ordered
            oldest to newest.
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(MessageTable)
                .where(MessageTable.conversation_id == conversation_id)
                .order_by(MessageTable.created_at.desc())
                .limit(limit)
            )
            newest_first = result.scalars().all()
            # The query selects the newest `limit` rows; reverse them so the
            # history reads in chronological order.
            return [
                {"role": "user" if msg.is_user else "assistant", "content": msg.content}
                for msg in reversed(newest_first)
            ]

    async def cleanup_old_conversations(self, ttl_seconds: int = 1800):
        """Delete conversations older than TTL to free up storage.

        A conversation is stale when its ``last_activity`` is more than
        ``ttl_seconds`` before the current UTC time. Its messages are deleted
        with it.

        Args:
            ttl_seconds: Inactivity window in seconds (default 1800 = 30 min).
        """
        cutoff = datetime.utcnow() - timedelta(seconds=ttl_seconds)
        stale_ids = select(ConversationTable.id).where(
            ConversationTable.last_activity < cutoff
        )
        async with self.async_session() as session:
            # Delete by predicate/subquery instead of a Python list of IDs:
            # - no one-bound-parameter-per-ID IN (...) list, which can exceed
            #   SQLite's host-parameter limit when many conversations expire;
            # - the staleness check runs inside the same transaction as the
            #   deletes, rather than acting on IDs read by an earlier SELECT.
            # synchronize_session=False: this fresh session holds no loaded
            # objects that would need to be kept in sync.
            await session.execute(
                sa.delete(MessageTable)
                .where(MessageTable.conversation_id.in_(stale_ids))
                .execution_options(synchronize_session=False)
            )
            result = await session.execute(
                sa.delete(ConversationTable)
                .where(ConversationTable.last_activity < cutoff)
                .execution_options(synchronize_session=False)
            )
            deleted = result.rowcount
            await session.commit()
        if deleted:
            print(f"Cleaned up {deleted} old conversations")
async def get_conversation_manager(request: Request) -> ConversationManager:
    """FastAPI dependency yielding the shared ConversationManager.

    The manager is read from ``app.state.db_manager`` (assigned elsewhere,
    presumably at application startup), so every request reuses the same
    engine and connection pool instead of building a new one.
    """
    app_state = request.app.state
    return app_state.db_manager