Domain layer (zero framework imports): - domain/models.py: pure dataclasses (RuleDocument, RuleSearchResult, Conversation, ChatMessage, LLMResponse, ChatResult) - domain/ports.py: ABC interfaces (RuleRepository, LLMPort, ConversationStore, IssueTracker) - domain/services.py: ChatService orchestrates Q&A flow using only ports Outbound adapters (implement domain ports): - adapters/outbound/openrouter.py: OpenRouterLLM with persistent httpx client, robust JSON parsing, regex citation fallback - adapters/outbound/sqlite_convos.py: SQLiteConversationStore with async_sessionmaker, timezone-aware datetimes, cleanup support - adapters/outbound/gitea_issues.py: GiteaIssueTracker with markdown injection protection (fenced code blocks) - adapters/outbound/chroma_rules.py: ChromaRuleRepository with clamped similarity scores Inbound adapter: - adapters/inbound/api.py: thin FastAPI router with input validation (max_length constraints), proper HTTP status codes (503 for missing LLM) Configuration & wiring: - config/settings.py: Pydantic v2 SettingsConfigDict (no module-level singleton) - config/container.py: create_app() factory with lifespan-managed DI - main.py: minimal entry point Test infrastructure (90 tests, all passing): - tests/fakes/: in-memory implementations of all 4 ports - tests/domain/: 26 tests for models and ChatService - tests/adapters/: 64 tests for all adapters using fakes/mocks - No real API calls, no model downloads, no disk I/O in fast tests Also fixes: aiosqlite version constraint (>=0.19.0), adds hatch build targets for new package layout. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
278 lines
9.6 KiB
Python
278 lines
9.6 KiB
Python
"""SQLite outbound adapter implementing the ConversationStore port.
|
|
|
|
Uses SQLAlchemy 2.x async API with aiosqlite as the driver. Designed to be
|
|
instantiated once at application startup; call `await init_db()` before use.
|
|
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, select
|
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
|
from sqlalchemy.orm import declarative_base
|
|
|
|
from domain.ports import ConversationStore
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ORM table definitions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _ConversationRow(Base):
    """ORM mapping for one conversation session row."""

    __tablename__ = "conversations"

    # Primary key is a UUID4 string generated by the adapter, not the DB.
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    channel_id = Column(String, nullable=False)
    # Timezone-aware timestamps. Defaults are lambdas so each insert gets a
    # fresh "now" rather than the value captured at import time.
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        nullable=False,
    )
    last_activity = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
        nullable=False,
    )
|
|
|
|
|
|
class _MessageRow(Base):
    """ORM mapping for one chat message row."""

    __tablename__ = "messages"

    # Primary key is a UUID4 string generated by the adapter, not the DB.
    id = Column(String, primary_key=True)
    conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False)
    content = Column(String, nullable=False)
    # True for user-authored messages, False for assistant replies.
    is_user = Column(Boolean, nullable=False)
    # Optional threading pointer to the message this one replies to.
    parent_id = Column(String, ForeignKey("messages.id"), nullable=True)
    # Timezone-aware insert timestamp; lambda yields a fresh "now" per row.
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        nullable=False,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Adapter
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class SQLiteConversationStore(ConversationStore):
    """Persists conversation state to a SQLite database via SQLAlchemy async.

    Lifecycle: instantiate once at application startup, ``await init_db()``
    before first use, and ``await aclose()`` at shutdown so the engine's
    connection pool (and SQLite file handles) are released promptly.

    Parameters
    ----------
    db_url:
        SQLAlchemy async connection URL, e.g.
        ``"sqlite+aiosqlite:///path/to/conversations.db"`` or
        ``"sqlite+aiosqlite://"`` for an in-memory database.
    """

    def __init__(self, db_url: str) -> None:
        self._engine = create_async_engine(db_url, echo=False)
        # async_sessionmaker is the modern (SQLAlchemy 2.0) replacement for
        # sessionmaker(class_=AsyncSession, ...). expire_on_commit=False lets
        # callers read ORM attributes (e.g. row.id) after the session commits.
        self._session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker(
            self._engine, expire_on_commit=False
        )

    async def init_db(self) -> None:
        """Create database tables if they do not already exist.

        Must be called before any other method.
        """
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.debug("Database tables initialised")

    async def aclose(self) -> None:
        """Dispose of the engine's connection pool.

        Call once at application shutdown. Without this the pool's
        connections linger until interpreter exit, which can keep the
        SQLite file locked longer than necessary.
        """
        await self._engine.dispose()

    # ------------------------------------------------------------------
    # ConversationStore implementation
    # ------------------------------------------------------------------

    async def get_or_create_conversation(
        self,
        user_id: str,
        channel_id: str,
        conversation_id: Optional[str] = None,
    ) -> str:
        """Return *conversation_id* if it exists in the DB, otherwise create a
        new conversation row and return its fresh ID.

        If *conversation_id* is supplied but not found (e.g. after a restart
        with a clean in-memory DB), a new conversation is created transparently
        rather than raising an error.
        """
        async with self._session_factory() as session:
            if conversation_id:
                result = await session.execute(
                    select(_ConversationRow).where(
                        _ConversationRow.id == conversation_id
                    )
                )
                row = result.scalar_one_or_none()
                if row is not None:
                    # Resuming: just bump the activity timestamp.
                    row.last_activity = datetime.now(timezone.utc)
                    await session.commit()
                    logger.debug(
                        "Resumed existing conversation %s for user %s",
                        conversation_id,
                        user_id,
                    )
                    return row.id
                logger.warning(
                    "Conversation %s not found; creating a new one", conversation_id
                )

            new_id = str(uuid.uuid4())
            session.add(
                _ConversationRow(
                    id=new_id,
                    user_id=user_id,
                    channel_id=channel_id,
                    created_at=datetime.now(timezone.utc),
                    last_activity=datetime.now(timezone.utc),
                )
            )
            await session.commit()
            logger.debug(
                "Created conversation %s for user %s in channel %s",
                new_id,
                user_id,
                channel_id,
            )
            return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Append a message to *conversation_id* and update last_activity.

        Returns the new message's UUID string. If the conversation row is
        missing the message is still stored (best-effort) and a warning is
        logged rather than raising.
        """
        message_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)

        async with self._session_factory() as session:
            session.add(
                _MessageRow(
                    id=message_id,
                    conversation_id=conversation_id,
                    content=content,
                    is_user=is_user,
                    parent_id=parent_id,
                    created_at=now,
                )
            )

            # Bump last_activity on the parent conversation.
            result = await session.execute(
                select(_ConversationRow).where(_ConversationRow.id == conversation_id)
            )
            conv = result.scalar_one_or_none()
            if conv is not None:
                conv.last_activity = now
            else:
                logger.warning(
                    "add_message: conversation %s not found; message stored orphaned",
                    conversation_id,
                )

            await session.commit()

        logger.debug(
            "Added message %s to conversation %s (is_user=%s)",
            message_id,
            conversation_id,
            is_user,
        )
        return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict[str, str]]:
        """Return the most-recent *limit* messages in chronological order.

        The query fetches the newest rows (ORDER BY created_at DESC LIMIT n)
        then reverses the list so callers receive oldest-first ordering,
        which is what LLM APIs expect in the ``messages`` array.

        NOTE: messages inserted within the same clock tick share a
        created_at value, so their relative order among ties is undefined.

        Returns a list of ``{"role": "user"|"assistant", "content": "..."}``
        dicts compatible with the OpenAI chat-completion format.
        """
        async with self._session_factory() as session:
            result = await session.execute(
                select(_MessageRow)
                .where(_MessageRow.conversation_id == conversation_id)
                .order_by(_MessageRow.created_at.desc())
                .limit(limit)
            )
            rows = result.scalars().all()

            # Reverse so the list is oldest → newest (chronological).
            history: list[dict[str, str]] = [
                {
                    "role": "user" if row.is_user else "assistant",
                    "content": row.content,
                }
                for row in reversed(rows)
            ]
            logger.debug(
                "Retrieved %d messages for conversation %s (limit=%d)",
                len(history),
                conversation_id,
                limit,
            )
            return history

    # ------------------------------------------------------------------
    # Housekeeping
    # ------------------------------------------------------------------

    async def cleanup_old_conversations(self, ttl_seconds: int = 1800) -> None:
        """Delete conversations (and their messages) older than *ttl_seconds*.

        Useful as a periodic background task to keep the database small.
        Messages are deleted first to satisfy the foreign-key constraint even
        in databases where cascade deletes are not configured.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=ttl_seconds)

        async with self._session_factory() as session:
            # Select only the primary keys — loading full ORM rows here
            # would waste memory, since we only need the IDs to delete by.
            result = await session.execute(
                select(_ConversationRow.id).where(
                    _ConversationRow.last_activity < cutoff
                )
            )
            conv_ids = list(result.scalars().all())

            if not conv_ids:
                logger.debug("cleanup_old_conversations: nothing to remove")
                return

            # Children first so the messages.conversation_id FK is never
            # violated mid-transaction.
            await session.execute(
                sa.delete(_MessageRow).where(_MessageRow.conversation_id.in_(conv_ids))
            )
            await session.execute(
                sa.delete(_ConversationRow).where(_ConversationRow.id.in_(conv_ids))
            )
            await session.commit()

        logger.info(
            "Cleaned up %d stale conversations (TTL=%ds)", len(conv_ids), ttl_seconds
        )
|