- Add vector store with sentence-transformers for semantic search
- FastAPI backend with /chat and /health endpoints
- Conversation state persistence via SQLite
- OpenRouter integration with structured JSON responses
- Discord bot with /ask slash command and reply-based follow-ups
- Automated Gitea issue creation for unanswered questions
- Docker support with docker-compose for easy deployment
- Example rule file and ingestion script
- Comprehensive documentation in README
162 lines
5.9 KiB
Python
162 lines
5.9 KiB
Python
"""SQLAlchemy-based conversation state management with aiosqlite."""
|
|
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional
|
|
import uuid
|
|
import sqlalchemy as sa
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
from sqlalchemy.orm import sessionmaker, declarative_base
|
|
from sqlalchemy import Column, String, DateTime, Boolean, ForeignKey, select
|
|
from .config import settings
|
|
|
|
# Shared declarative base: the ORM models in this module subclass it, and
# ConversationManager.init_db() uses Base.metadata to create their tables.
Base = declarative_base()
|
|
|
|
|
|
class ConversationTable(Base):
    """ORM mapping for a row of the ``conversations`` table."""

    __tablename__ = "conversations"

    # String primary key; ConversationManager fills it with str(uuid.uuid4()).
    id = sa.Column(sa.String, primary_key=True)

    # Required identifiers supplied by the caller when the row is created.
    user_id = sa.Column(sa.String, nullable=False)
    channel_id = sa.Column(sa.String, nullable=False)

    # Naive UTC timestamps (datetime.utcnow). ``last_activity`` is refreshed on
    # every ORM update of the row and is what TTL cleanup compares against.
    created_at = sa.Column(sa.DateTime, default=datetime.utcnow)
    last_activity = sa.Column(
        sa.DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )
|
|
|
|
|
|
class MessageTable(Base):
    """ORM mapping for a row of the ``messages`` table."""

    __tablename__ = "messages"

    # String primary key; ConversationManager fills it with str(uuid.uuid4()).
    id = sa.Column(sa.String, primary_key=True)

    # Owning conversation (required).
    conversation_id = sa.Column(
        sa.String,
        sa.ForeignKey("conversations.id"),
        nullable=False,
    )

    # Message text, plus a flag telling user messages (True) apart from
    # assistant messages (False).
    content = sa.Column(sa.String, nullable=False)
    is_user = sa.Column(sa.Boolean, nullable=False)

    # Optional self-reference to another message row.
    parent_id = sa.Column(
        sa.String,
        sa.ForeignKey("messages.id"),
        nullable=True,
    )

    # Naive UTC creation time; conversation history is ordered by this column.
    created_at = sa.Column(sa.DateTime, default=datetime.utcnow)
|
|
|
|
|
|
class ConversationManager:
    """Manages conversation state in SQLite.

    Wraps an async SQLAlchemy engine and session factory. Every public method
    opens its own short-lived session and commits before returning.
    """

    # Upper bound on ids bound into a single ``IN (...)`` clause during
    # cleanup. SQLite caps the number of bound parameters per statement
    # (999 by default on older builds), so expired ids are deleted in batches.
    _DELETE_BATCH_SIZE = 500

    def __init__(self, db_url: str):
        """Initialize database engine and session factory.

        Args:
            db_url: SQLAlchemy database URL passed to ``create_async_engine``.
        """
        self.engine = create_async_engine(db_url, echo=False)
        # expire_on_commit=False keeps attributes such as ``conv.id`` readable
        # after ``commit()`` without triggering another (async) load.
        self.async_session = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )

    async def init_db(self):
        """Create tables if they don't exist."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def get_or_create_conversation(
        self, user_id: str, channel_id: str, conversation_id: Optional[str] = None
    ) -> str:
        """Get existing conversation or create a new one.

        Args:
            user_id: Stored on the row when a new conversation is created.
            channel_id: Stored on the row when a new conversation is created.
            conversation_id: Id of a conversation to resume, if any.

        Returns:
            The id of the resumed conversation, or a freshly generated id when
            ``conversation_id`` is missing or does not match a stored row.

        Note:
            The lookup matches on ``conversation_id`` only; it does not check
            that the stored row belongs to ``user_id`` or ``channel_id``.
        """
        async with self.async_session() as session:
            if conversation_id:
                result = await session.execute(
                    select(ConversationTable).where(
                        ConversationTable.id == conversation_id
                    )
                )
                conv = result.scalar_one_or_none()
                if conv:
                    # Resuming counts as activity and postpones TTL cleanup.
                    conv.last_activity = datetime.utcnow()
                    await session.commit()
                    return conv.id

            # Unknown or absent id: start a new conversation.
            new_id = str(uuid.uuid4())
            conv = ConversationTable(id=new_id, user_id=user_id, channel_id=channel_id)
            session.add(conv)
            await session.commit()
            return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Add a message to a conversation.

        Args:
            conversation_id: Id of the conversation the message belongs to.
            content: Message text.
            is_user: True for a user message, False for an assistant message.
            parent_id: Optional id of the message this one refers to.

        Returns:
            The generated id of the stored message.

        Note:
            The message is stored even when no conversation row matches
            ``conversation_id``; in that case only the activity refresh is
            skipped.
        """
        message_id = str(uuid.uuid4())
        async with self.async_session() as session:
            msg = MessageTable(
                id=message_id,
                conversation_id=conversation_id,
                content=content,
                is_user=is_user,
                parent_id=parent_id,
            )
            session.add(msg)

            # Refresh the conversation's activity timestamp in the same
            # transaction as the insert.
            result = await session.execute(
                select(ConversationTable).where(ConversationTable.id == conversation_id)
            )
            conv = result.scalar_one_or_none()
            if conv:
                conv.last_activity = datetime.utcnow()

            await session.commit()
            return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict]:
        """Get recent messages from a conversation in OpenAI format.

        Args:
            conversation_id: Id of the conversation to read.
            limit: Maximum number of messages to return.

        Returns:
            Up to ``limit`` of the most recent messages, oldest first, each as
            ``{"role": "user" | "assistant", "content": <text>}``.
        """
        async with self.async_session() as session:
            # Newest-first so LIMIT keeps the most recent messages.
            result = await session.execute(
                select(MessageTable)
                .where(MessageTable.conversation_id == conversation_id)
                .order_by(MessageTable.created_at.desc())
                .limit(limit)
            )
            messages = result.scalars().all()

            # Flip back to chronological order and convert to API format.
            return [
                {"role": "user" if msg.is_user else "assistant", "content": msg.content}
                for msg in reversed(messages)
            ]

    async def cleanup_old_conversations(self, ttl_seconds: int = 1800) -> int:
        """Delete conversations older than TTL to free up storage.

        A conversation is expired when its ``last_activity`` is more than
        ``ttl_seconds`` in the past. Its messages are deleted with it.

        Args:
            ttl_seconds: Maximum idle time in seconds before a conversation is
                removed.

        Returns:
            The number of conversations deleted (0 when nothing had expired).
        """
        cutoff = datetime.utcnow() - timedelta(seconds=ttl_seconds)
        async with self.async_session() as session:
            # Only the ids are needed, so avoid loading full ORM rows.
            result = await session.execute(
                select(ConversationTable.id).where(
                    ConversationTable.last_activity < cutoff
                )
            )
            conv_ids = list(result.scalars().all())
            if not conv_ids:
                return 0

            batch_size = self._DELETE_BATCH_SIZE
            for start in range(0, len(conv_ids), batch_size):
                batch = conv_ids[start:start + batch_size]
                # Messages must be deleted explicitly and first: the foreign
                # key declares no cascade, and the rows reference the
                # conversations removed next.
                await session.execute(
                    sa.delete(MessageTable).where(
                        MessageTable.conversation_id.in_(batch)
                    )
                )
                await session.execute(
                    sa.delete(ConversationTable).where(
                        ConversationTable.id.in_(batch)
                    )
                )

            # One commit for all batches keeps the cleanup atomic.
            await session.commit()
            print(f"Cleaned up {len(conv_ids)} old conversations")
            return len(conv_ids)
|
|
|
|
|
|
# One manager (and therefore one engine / connection pool) per database URL.
# Building a new engine on every dependency resolution would create a fresh
# pool per request that is never disposed.
_conversation_managers: dict[str, ConversationManager] = {}


async def get_conversation_manager() -> ConversationManager:
    """Dependency for FastAPI to get a ConversationManager instance.

    Returns:
        The ConversationManager for ``settings.db_url``. The instance is
        created on first use and reused on later calls with the same URL; a
        different URL gets its own manager.

    Note:
        The cached engine is shared by all callers, so it should be used from
        a single event loop.
    """
    db_url = settings.db_url
    manager = _conversation_managers.get(db_url)
    if manager is None:
        manager = ConversationManager(db_url)
        _conversation_managers[db_url] = manager
    return manager
|