refactor: hexagonal architecture with ports & adapters, DI, and test-first development

Domain layer (zero framework imports):
- domain/models.py: pure dataclasses (RuleDocument, RuleSearchResult,
  Conversation, ChatMessage, LLMResponse, ChatResult)
- domain/ports.py: ABC interfaces (RuleRepository, LLMPort,
  ConversationStore, IssueTracker)
- domain/services.py: ChatService orchestrates Q&A flow using only ports

Outbound adapters (implement domain ports):
- adapters/outbound/openrouter.py: OpenRouterLLM with persistent httpx
  client, robust JSON parsing, regex citation fallback
- adapters/outbound/sqlite_convos.py: SQLiteConversationStore with
  async_sessionmaker, timezone-aware datetimes, cleanup support
- adapters/outbound/gitea_issues.py: GiteaIssueTracker with markdown
  injection protection (fenced code blocks)
- adapters/outbound/chroma_rules.py: ChromaRuleRepository with clamped
  similarity scores

Inbound adapter:
- adapters/inbound/api.py: thin FastAPI router with input validation
  (max_length constraints), proper HTTP status codes (503 for missing LLM)

Configuration & wiring:
- config/settings.py: Pydantic v2 SettingsConfigDict (no module-level singleton)
- config/container.py: create_app() factory with lifespan-managed DI
- main.py: minimal entry point

Test infrastructure (90 tests, all passing):
- tests/fakes/: in-memory implementations of all 4 ports
- tests/domain/: 26 tests for models and ChatService
- tests/adapters/: 64 tests for all adapters using fakes/mocks
- No real API calls, no model downloads, no disk I/O in fast tests

Also fixes: aiosqlite version constraint (>=0.19.0), adds hatch build
targets for new package layout.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Cal Corum 2026-03-08 15:51:16 -05:00
parent c2c7f7d3c2
commit c3218f70c4
34 changed files with 7855 additions and 1 deletions

0
adapters/__init__.py Normal file
View File

View File

165
adapters/inbound/api.py Normal file
View File

@ -0,0 +1,165 @@
"""FastAPI inbound adapter — thin HTTP layer over ChatService.
This module contains only routing / serialisation logic. All business rules
live in domain.services.ChatService; all storage / LLM calls live in outbound
adapters. The router reads ChatService and RuleRepository from app.state so
that the container (config/container.py) remains the single wiring point and
tests can substitute fakes without monkey-patching.
"""
import logging
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from domain.ports import RuleRepository
from domain.services import ChatService
logger = logging.getLogger(__name__)
router = APIRouter()
# ---------------------------------------------------------------------------
# Request / response Pydantic models
# ---------------------------------------------------------------------------
class ChatRequest(BaseModel):
    """Payload accepted by POST /chat.

    Field constraints are enforced by Pydantic before the request reaches the
    domain layer; violations are rejected with HTTP 422.
    """

    # The user's natural-language question.
    message: str = Field(
        ...,
        min_length=1,
        max_length=4000,
        # Fixed garbled description: the range separator had been lost,
        # leaving the misleading "(14000 characters)".
        description="The user's question (1-4000 characters).",
    )
    # Opaque caller identity; stored with the conversation, never interpreted.
    user_id: str = Field(
        ...,
        min_length=1,
        max_length=64,
        description="Opaque caller identifier, e.g. Discord snowflake.",
    )
    channel_id: str = Field(
        ...,
        min_length=1,
        max_length=64,
        description="Opaque channel identifier, e.g. Discord channel snowflake.",
    )
    # Optional continuation of an earlier conversation.
    conversation_id: Optional[str] = Field(
        default=None,
        description="Continue an existing conversation; omit to start a new one.",
    )
    parent_message_id: Optional[str] = Field(
        default=None,
        description="Thread parent message ID for Discord thread replies.",
    )
class ChatResponse(BaseModel):
    """Payload returned by POST /chat."""

    # The assistant's answer text.
    response: str
    # ID of the (possibly newly created) conversation this answer belongs to.
    conversation_id: str
    # ID of the assistant message persisted in the conversation store.
    message_id: str
    # Echoed thread parent, when the request supplied one.
    parent_message_id: Optional[str] = None
    # Rule IDs cited in the answer; may be empty.
    cited_rules: list[str]
    # Model-reported confidence score (expected in [0, 1] — set by the LLM).
    confidence: float
    # True when the service flagged the question for human follow-up.
    needs_human: bool
# ---------------------------------------------------------------------------
# Dependency helpers — read from app.state set by the container
# ---------------------------------------------------------------------------
def _get_chat_service(request: Request) -> ChatService:
    """Fetch the container-wired ChatService instance from ``app.state``."""
    app_state = request.app.state
    return app_state.chat_service
def _get_rule_repository(request: Request) -> RuleRepository:
    """Fetch the container-wired RuleRepository instance from ``app.state``."""
    app_state = request.app.state
    return app_state.rule_repository
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/chat", response_model=ChatResponse)
async def chat(
body: ChatRequest,
service: Annotated[ChatService, Depends(_get_chat_service)],
rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
) -> ChatResponse:
"""Handle a rules Q&A request.
Delegates entirely to ChatService.answer_question no business logic here.
Returns HTTP 503 when the LLM adapter cannot be constructed (missing API key)
rather than producing a fake success response, so callers can distinguish
genuine answers from configuration errors.
"""
# The container raises at startup if the API key is required but absent;
# however if the service was created without a real LLM (e.g. missing key
# detected at request time), surface a clear service-unavailable rather than
# leaking a misleading 200 OK.
if not hasattr(service, "llm") or service.llm is None:
raise HTTPException(
status_code=503,
detail="LLM service is not available — check OPENROUTER_API_KEY configuration.",
)
try:
result = await service.answer_question(
message=body.message,
user_id=body.user_id,
channel_id=body.channel_id,
conversation_id=body.conversation_id,
parent_message_id=body.parent_message_id,
)
except Exception as exc:
logger.exception("Unhandled error in ChatService.answer_question")
raise HTTPException(status_code=500, detail=str(exc)) from exc
return ChatResponse(
response=result.response,
conversation_id=result.conversation_id,
message_id=result.message_id,
parent_message_id=result.parent_message_id,
cited_rules=result.cited_rules,
confidence=result.confidence,
needs_human=result.needs_human,
)
@router.get("/health")
async def health(
rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
) -> dict:
"""Return service health and a summary of the loaded knowledge base."""
stats = rules.get_stats()
return {
"status": "healthy",
"rules_count": stats.get("total_rules", 0),
"sections": stats.get("sections", {}),
}
@router.get("/stats")
async def stats(
rules: Annotated[RuleRepository, Depends(_get_rule_repository)],
request: Request,
) -> dict:
"""Return extended statistics about the knowledge base and configuration."""
kb_stats = rules.get_stats()
# Pull optional config snapshot from app.state (set by container).
config_snapshot: dict = getattr(request.app.state, "config_snapshot", {})
return {
"knowledge_base": kb_stats,
"config": config_snapshot,
}

View File

View File

@ -0,0 +1,203 @@
"""ChromaDB outbound adapter implementing the RuleRepository port."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer
from domain.models import RuleDocument, RuleSearchResult
from domain.ports import RuleRepository
logger = logging.getLogger(__name__)
_COLLECTION_NAME = "rules"
class ChromaRuleRepository(RuleRepository):
    """Persist and search rules in a ChromaDB vector store.

    Documents are embedded with a sentence-transformers model and stored in
    one cosine-space collection; queries are embedded with the same model so
    distances are comparable.

    Parameters
    ----------
    persist_dir:
        Directory that ChromaDB uses for on-disk persistence. Created
        automatically if it does not exist.
    embedding_model:
        HuggingFace / sentence-transformers model name used to encode
        documents and queries (e.g. ``"all-MiniLM-L6-v2"``).
    """

    def __init__(self, persist_dir: Path, embedding_model: str) -> None:
        self.persist_dir = Path(persist_dir)
        self.persist_dir.mkdir(parents=True, exist_ok=True)
        chroma_settings = ChromaSettings(
            anonymized_telemetry=False,
            # NOTE(review): not a documented ChromaDB Settings field — verify
            # the pinned chromadb version accepts it; unknown settings fields
            # may raise a validation error at construction time.
            is_persist_directory_actually_writable=True,
        )
        self._client = chromadb.PersistentClient(
            path=str(self.persist_dir),
            settings=chroma_settings,
        )
        # Loading the model may download weights on first use, so log before
        # the potentially long stall.
        logger.info("Loading embedding model '%s'", embedding_model)
        self._encoder = SentenceTransformer(embedding_model)
        logger.info("ChromaRuleRepository ready (persist_dir=%s)", self.persist_dir)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _get_collection(self):
        """Return the rules collection, creating it if absent."""
        # "hnsw:space": "cosine" selects cosine distance; the clamping logic
        # in _distance_to_similarity depends on this choice.
        return self._client.get_or_create_collection(
            name=_COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    @staticmethod
    def _distance_to_similarity(distance: float) -> float:
        """Convert a cosine distance in [0, 2] to a similarity in [0.0, 1.0].

        ChromaDB stores cosine *distance* (0 = identical, 2 = opposite).
        The conversion is ``similarity = 1 - distance``, but floating-point
        noise can push the result slightly outside [0, 1], so we clamp.
        """
        return max(0.0, min(1.0, 1.0 - distance))

    # ------------------------------------------------------------------
    # RuleRepository port implementation
    # ------------------------------------------------------------------
    def add_documents(self, docs: list[RuleDocument]) -> None:
        """Embed and store a batch of RuleDocuments.

        Calling with an empty list is a no-op.
        """
        if not docs:
            return
        logger.debug("Encoding %d document(s)", len(docs))
        ids = [doc.rule_id for doc in docs]
        contents = [doc.content for doc in docs]
        metadatas = [doc.to_metadata() for doc in docs]
        # SentenceTransformer.encode returns a numpy array; .tolist() gives
        # a plain Python list which ChromaDB accepts.
        embeddings = self._encoder.encode(contents).tolist()
        collection = self._get_collection()
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=contents,
            metadatas=metadatas,
        )
        logger.info("Stored %d rule(s) in ChromaDB", len(docs))

    def search(
        self,
        query: str,
        top_k: int = 10,
        section_filter: Optional[str] = None,
    ) -> list[RuleSearchResult]:
        """Return the *top_k* most semantically similar rules for *query*.

        Parameters
        ----------
        query:
            Natural-language question or keyword string.
        top_k:
            Maximum number of results to return.
        section_filter:
            When provided, only documents whose ``section`` metadata field
            equals this value are considered.

        Returns
        -------
        list[RuleSearchResult]
            Sorted by descending similarity (best match first). Returns an
            empty list if the collection is empty.
        """
        collection = self._get_collection()
        doc_count = collection.count()
        if doc_count == 0:
            return []
        # Clamp top_k so we never ask ChromaDB for more results than exist.
        effective_k = min(top_k, doc_count)
        query_embedding = self._encoder.encode(query).tolist()
        where = {"section": section_filter} if section_filter else None
        logger.debug(
            "Querying ChromaDB: top_k=%d, section_filter=%r",
            effective_k,
            section_filter,
        )
        raw = collection.query(
            query_embeddings=[query_embedding],
            n_results=effective_k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )
        results: list[RuleSearchResult] = []
        # query() returns parallel per-query lists; index [0] selects the
        # single query we issued. Guard against empty/missing payloads.
        if raw and raw["documents"] and raw["documents"][0]:
            for i, doc_content in enumerate(raw["documents"][0]):
                metadata = raw["metadatas"][0][i]
                distance = raw["distances"][0][i]
                similarity = self._distance_to_similarity(distance)
                results.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=doc_content,
                        section=metadata["section"],
                        similarity=similarity,
                    )
                )
        logger.debug("Search returned %d result(s)", len(results))
        return results

    def count(self) -> int:
        """Return the total number of rule documents in the collection."""
        return self._get_collection().count()

    def clear_all(self) -> None:
        """Delete all documents by dropping and recreating the collection."""
        logger.info(
            "Clearing all rules from ChromaDB collection '%s'", _COLLECTION_NAME
        )
        self._client.delete_collection(_COLLECTION_NAME)
        self._get_collection()  # Recreate so subsequent calls do not fail.

    def get_stats(self) -> dict:
        """Return a summary dict with total rule count, per-section counts, and path.

        Returns
        -------
        dict with keys:
            ``total_rules`` (int), ``sections`` (dict[str, int]),
            ``persist_directory`` (str)
        """
        collection = self._get_collection()
        raw = collection.get(include=["metadatas"])
        # Tally documents per section from their stored metadata.
        sections: dict[str, int] = {}
        for metadata in raw.get("metadatas") or []:
            section = metadata.get("section", "")
            sections[section] = sections.get(section, 0) + 1
        return {
            "total_rules": collection.count(),
            "sections": sections,
            "persist_directory": str(self.persist_dir),
        }

View File

@ -0,0 +1,167 @@
"""Outbound adapter: Gitea issue tracker.
Implements the IssueTracker port using the Gitea REST API. A single
httpx.AsyncClient is shared across all calls (connection pool reuse); callers
must await close() when the adapter is no longer needed, typically in an
application lifespan handler.
"""
import logging
from typing import Optional
import httpx
from domain.ports import IssueTracker
logger = logging.getLogger(__name__)
_LABELS: list[str] = ["rules-gap", "ai-generated", "needs-review"]
_TITLE_MAX_QUESTION_LEN = 80
class GiteaIssueTracker(IssueTracker):
    """Outbound adapter that files Gitea issues for unanswered questions.

    One ``httpx.AsyncClient`` is shared across all calls so the connection
    pool is reused; call :meth:`close` (typically from an application
    lifespan handler) when the adapter is no longer needed.

    Args:
        token: Personal access token with issue-write permission.
        owner: Repository owner (user or org name).
        repo: Repository slug.
        base_url: Base URL of the Gitea instance, e.g.
            "https://gitea.example.com". Trailing slashes are stripped
            automatically.
    """

    def __init__(
        self,
        token: str,
        owner: str,
        repo: str,
        base_url: str,
    ) -> None:
        self._token = token
        self._owner = owner
        self._repo = repo
        self._base_url = base_url.rstrip("/")
        self._headers = {
            "Authorization": f"token {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        self._client = httpx.AsyncClient(headers=self._headers, timeout=30.0)

    # ------------------------------------------------------------------
    # IssueTracker port implementation
    # ------------------------------------------------------------------
    async def create_unanswered_issue(
        self,
        question: str,
        user_id: str,
        channel_id: str,
        attempted_rules: list[str],
        conversation_id: str,
    ) -> str:
        """File a Gitea issue for a question the bot could not answer.

        The user-supplied question is embedded in a fenced code block to
        prevent markdown injection — a crafted question containing headers,
        links, or other markdown syntax cannot corrupt the issue layout.

        Returns:
            The HTML URL of the newly created issue.

        Raises:
            RuntimeError: If the Gitea API responds with a non-2xx status code.
        """
        logger.info(
            "Creating Gitea issue for unanswered question from user=%s channel=%s",
            user_id,
            channel_id,
        )
        payload: dict = {
            "title": self._build_title(question),
            "body": self._build_body(
                question, user_id, channel_id, attempted_rules, conversation_id
            ),
            "labels": _LABELS,
        }
        # NOTE(review): the path carries no "/api/v1" prefix — confirm that
        # base_url is expected to include it, otherwise requests will 404.
        endpoint = f"{self._base_url}/repos/{self._owner}/{self._repo}/issues"
        response = await self._client.post(endpoint, json=payload)
        if response.status_code not in (200, 201):
            detail = response.text
            logger.error(
                "Gitea API returned %s creating issue: %s",
                response.status_code,
                detail,
            )
            raise RuntimeError(
                f"Gitea API error {response.status_code} creating issue: {detail}"
            )
        issue = response.json()
        html_url: str = issue.get("html_url", "")
        logger.info("Created Gitea issue: %s", html_url)
        return html_url

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    async def close(self) -> None:
        """Release the underlying HTTP connection pool.

        Call this in an application shutdown handler (e.g. FastAPI lifespan)
        to avoid ResourceWarning on interpreter exit.
        """
        await self._client.aclose()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _build_title(question: str) -> str:
        """Return a short, human-readable issue title."""
        if len(question) > _TITLE_MAX_QUESTION_LEN:
            return (
                f"Unanswered rules question: "
                f"{question[:_TITLE_MAX_QUESTION_LEN]}..."
            )
        return f"Unanswered rules question: {question}"

    @staticmethod
    def _build_body(
        question: str,
        user_id: str,
        channel_id: str,
        attempted_rules: list[str],
        conversation_id: str,
    ) -> str:
        """Compose the Gitea issue body with all triage context.

        The question is fenced so that markdown special characters in user
        input cannot alter the issue structure.
        """
        searched: str = ", ".join(attempted_rules) if attempted_rules else "None"
        parts = [
            "## Unanswered Question\n\n",
            f"**User:** {user_id}\n\n",
            f"**Channel:** {channel_id}\n\n",
            f"**Conversation ID:** {conversation_id}\n\n",
            "**Question:**\n",
            f"```\n{question}\n```\n\n",
            f"**Searched Rules:** {searched}\n\n",
            "**Additional Context:**\n",
            "This question was asked in Discord and the bot could not provide "
            "a confident answer. The rules either don't cover this question or "
            "the information was ambiguous.\n\n",
            "---\n\n",
            "*This issue was automatically created by the Strat-Chatbot.*",
        ]
        return "".join(parts)

View File

@ -0,0 +1,253 @@
"""OpenRouter outbound adapter — implements LLMPort via the OpenRouter API.
This module is the sole owner of:
- The SYSTEM_PROMPT for the Strat-O-Matic rules assistant
- All JSON parsing / extraction logic for LLM responses
- The persistent httpx.AsyncClient connection pool
It returns domain.models.LLMResponse exclusively; no legacy app.* types leak
through this boundary.
"""
from __future__ import annotations
import json
import logging
import re
from typing import Optional
import httpx
from domain.models import LLMResponse, RuleSearchResult
from domain.ports import LLMPort
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------
# Instructs the model to answer only from the supplied rule excerpts, cite
# rule IDs, and reply in the JSON envelope that _parse_content expects
# ("answer", "cited_rules", "confidence", "needs_human").
SYSTEM_PROMPT = """You are a helpful assistant for a Strat-O-Matic baseball league.
Your job is to answer questions about league rules and procedures using the provided rule excerpts.
CRITICAL RULES:
1. ONLY use information from the provided rules. If the rules don't contain the answer, say so clearly.
2. ALWAYS cite rule IDs when referencing a rule (e.g., "Rule 5.2.1(b) states that...")
3. If multiple rules are relevant, cite all of them.
4. If you're uncertain or the rules are ambiguous, say so and suggest asking a league administrator.
5. Keep responses concise but complete. Use examples when helpful from the rules.
6. Do NOT make up rules or infer beyond what's explicitly stated.
When answering:
- Start with a direct answer to the question
- Support with rule citations
- Include relevant details from the rules
- If no relevant rules found, explicitly state: "I don't have a rule that addresses this question."
Response format (JSON):
{
"answer": "Your response text",
"cited_rules": ["rule_id_1", "rule_id_2"],
"confidence": 0.0-1.0,
"needs_human": boolean
}
Higher confidence (0.8-1.0) when rules clearly answer the question.
Lower confidence (0.3-0.7) when rules partially address the question or are ambiguous.
Very low confidence (0.0-0.2) when rules don't address the question at all.
"""
# Regex for extracting rule IDs from free-text answers when cited_rules is empty.
# Matches patterns like "Rule 5.2.1(b)" or "Rule 7.4".
# The character class includes '.' so a sentence-ending period may be captured
# (e.g. "Rule 7.4." → raw match "7.4."). Matches are stripped of a trailing
# dot at the extraction site to normalise IDs like "7.4." → "7.4".
# Bug fix: the class previously used "a-b", which matches ONLY the letters
# 'a' and 'b' — a citation such as "Rule 5.2.1(c)" was truncated mid-match,
# yielding the malformed ID "5.2.1(". "a-z" accepts any lowercase sub-rule
# letter and is a strict superset, so all previously matched IDs still match.
_RULE_ID_PATTERN = re.compile(r"Rule\s+([\d\.\(\)a-z]+)")
# ---------------------------------------------------------------------------
# Adapter
# ---------------------------------------------------------------------------
class OpenRouterLLM(LLMPort):
    """Outbound adapter that calls the OpenRouter chat completions API.

    A single httpx.AsyncClient is reused across all calls (connection pooling).
    Call ``await adapter.close()`` when tearing down to release the pool.

    Args:
        api_key: Bearer token for the OpenRouter API.
        model: OpenRouter model identifier, e.g. ``"openai/gpt-4o-mini"``.
        base_url: Full URL for the chat completions endpoint.
        http_client: Optional pre-built httpx.AsyncClient (useful for testing).
            When *None* a new client is created with a 120-second timeout.

    Raises:
        ValueError: If *api_key* is empty.
    """

    def __init__(
        self,
        api_key: str,
        model: str,
        base_url: str = "https://openrouter.ai/api/v1/chat/completions",
        http_client: Optional[httpx.AsyncClient] = None,
    ) -> None:
        if not api_key:
            raise ValueError("api_key must not be empty")
        self._api_key = api_key
        self._model = model
        self._base_url = base_url
        self._http: httpx.AsyncClient = http_client or httpx.AsyncClient(timeout=120.0)

    # ------------------------------------------------------------------
    # LLMPort implementation
    # ------------------------------------------------------------------
    async def generate_response(
        self,
        question: str,
        rules: list[RuleSearchResult],
        conversation_history: Optional[list[dict[str, str]]] = None,
    ) -> LLMResponse:
        """Call the OpenRouter API and return a structured LLMResponse.

        Args:
            question: The user's natural-language question.
            rules: Relevant rule excerpts retrieved from the knowledge base.
            conversation_history: Optional list of prior ``{"role": ..., "content": ...}``
                dicts. At most the last 6 messages are forwarded to stay within
                token budgets.

        Returns:
            LLMResponse with ``answer``, ``cited_rules``, ``confidence``, and
            ``needs_human`` populated from the LLM's JSON reply. On parse
            failure ``confidence=0.0`` and ``needs_human=True`` signal that
            the raw response could not be structured reliably.

        Raises:
            RuntimeError: When the API returns a non-200 HTTP status.
        """
        messages = self._build_messages(question, rules, conversation_history)
        logger.debug(
            "Sending request to OpenRouter model=%s messages=%d",
            self._model,
            len(messages),
        )
        response = await self._http.post(
            self._base_url,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self._model,
                "messages": messages,
                "temperature": 0.3,
                "max_tokens": 1000,
                "top_p": 0.9,
            },
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"OpenRouter API error: {response.status_code} - {response.text}"
            )
        result = response.json()
        content: str = result["choices"][0]["message"]["content"]
        logger.debug("Received response content length=%d", len(content))
        return self._parse_content(content, rules)

    async def close(self) -> None:
        """Release the underlying HTTP connection pool.

        Should be called when the adapter is no longer needed (e.g. on
        application shutdown) to avoid resource leaks.
        """
        await self._http.aclose()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _build_messages(
        self,
        question: str,
        rules: list[RuleSearchResult],
        conversation_history: Optional[list[dict[str, str]]],
    ) -> list[dict[str, str]]:
        """Assemble the messages list for the API request."""
        if rules:
            rules_context = "\n\n".join(
                f"Rule {r.rule_id}: {r.title}\n{r.content}" for r in rules
            )
            context_msg = (
                f"Here are the relevant rules for the question:\n\n{rules_context}"
            )
        else:
            context_msg = "No relevant rules were found in the knowledge base."
        messages: list[dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}]
        if conversation_history:
            # Limit to last 6 messages (3 exchanges) to avoid token overflow
            messages.extend(conversation_history[-6:])
        user_message = (
            f"{context_msg}\n\nUser question: {question}\n\n"
            "Answer the question based on the rules provided."
        )
        messages.append({"role": "user", "content": user_message})
        return messages

    def _parse_content(
        self, content: str, rules: list[RuleSearchResult]
    ) -> LLMResponse:
        """Parse the raw LLM content string into an LLMResponse.

        Handles these cases in order:
        1. JSON wrapped in a ```json ... ``` markdown fence.
        2. Bare JSON string.
        3. Anything else — non-JSON text, JSON that is not an object, or an
           object missing the required "answer" key — falls back to returning
           the raw content with confidence=0.0 and needs_human=True.
        """
        try:
            json_str = self._extract_json_string(content)
            parsed = json.loads(json_str)
        except (json.JSONDecodeError, KeyError, IndexError) as exc:
            logger.warning("Failed to parse LLM response as JSON: %s", exc)
            return self._fallback_response(content)
        # Bug fix: a syntactically valid reply that is not a JSON object
        # (e.g. a bare list) or lacks "answer" previously escaped this method
        # as an uncaught AttributeError / KeyError instead of degrading into
        # the designed low-confidence fallback.
        if not isinstance(parsed, dict) or "answer" not in parsed:
            logger.warning("LLM JSON reply is missing the expected structure")
            return self._fallback_response(content)
        cited_rules: list[str] = parsed.get("cited_rules", [])
        # Regex fallback: if the model omitted cited_rules but mentioned rule
        # IDs inline, extract them from the answer text so callers have
        # attribution without losing information.
        if not cited_rules and rules:
            answer_text: str = parsed.get("answer", "")
            # Strip a trailing dot from each match to handle sentence-ending
            # punctuation (e.g. "Rule 7.4." → "7.4").
            matches = [m.rstrip(".") for m in _RULE_ID_PATTERN.findall(answer_text)]
            cited_rules = list(dict.fromkeys(matches))  # deduplicate, preserve order
        return LLMResponse(
            answer=parsed["answer"],
            cited_rules=cited_rules,
            confidence=float(parsed.get("confidence", 0.5)),
            needs_human=bool(parsed.get("needs_human", False)),
        )

    @staticmethod
    def _fallback_response(content: str) -> LLMResponse:
        """Build the low-confidence response used whenever parsing fails."""
        return LLMResponse(
            answer=content,
            cited_rules=[],
            confidence=0.0,
            needs_human=True,
        )

    @staticmethod
    def _extract_json_string(content: str) -> str:
        """Strip optional markdown fences and return the raw JSON string."""
        if "```json" in content:
            return content.split("```json")[1].split("```")[0].strip()
        return content.strip()

View File

@ -0,0 +1,277 @@
"""SQLite outbound adapter implementing the ConversationStore port.
Uses SQLAlchemy 2.x async API with aiosqlite as the driver. Designed to be
instantiated once at application startup; call `await init_db()` before use.
"""
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import declarative_base
from domain.ports import ConversationStore
logger = logging.getLogger(__name__)
Base = declarative_base()
# ---------------------------------------------------------------------------
# ORM table definitions
# ---------------------------------------------------------------------------
class _ConversationRow(Base):
    """SQLAlchemy table model for a conversation session."""

    __tablename__ = "conversations"

    # String primary key (UUIDs are generated by the adapter, not the DB).
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    channel_id = Column(String, nullable=False)
    # Timezone-aware creation timestamp (UTC).
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # Bumped on resume and on new messages; stale rows (by this column) are
    # reaped by cleanup_old_conversations.
    last_activity = Column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
class _MessageRow(Base):
    """SQLAlchemy table model for a single chat message."""

    __tablename__ = "messages"

    # String primary key (UUIDs are generated by the adapter).
    id = Column(String, primary_key=True)
    # Owning conversation; rows are deleted together with it during cleanup.
    conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False)
    content = Column(String, nullable=False)
    # True for user-authored messages, False for assistant replies.
    is_user = Column(Boolean, nullable=False)
    # Optional self-reference used for threaded (Discord-style) replies.
    parent_id = Column(String, ForeignKey("messages.id"), nullable=True)
    # Timezone-aware creation timestamp (UTC).
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
# ---------------------------------------------------------------------------
# Adapter
# ---------------------------------------------------------------------------
class SQLiteConversationStore(ConversationStore):
    """Persists conversation state to a SQLite database via SQLAlchemy async.

    Designed to be instantiated once at application startup; call
    ``await init_db()`` before any other method.

    Parameters
    ----------
    db_url:
        SQLAlchemy async connection URL, e.g.
        ``"sqlite+aiosqlite:///path/to/conversations.db"`` or
        ``"sqlite+aiosqlite://"`` for an in-memory database.
    """

    def __init__(self, db_url: str) -> None:
        self._engine = create_async_engine(db_url, echo=False)
        # async_sessionmaker is the modern (SQLAlchemy 2.0) replacement for
        # sessionmaker(class_=AsyncSession, ...). expire_on_commit=False lets
        # callers read ORM attributes after the session has committed.
        self._session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker(
            self._engine, expire_on_commit=False
        )

    async def init_db(self) -> None:
        """Create database tables if they do not already exist.

        Must be called before any other method.
        """
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.debug("Database tables initialised")

    # ------------------------------------------------------------------
    # ConversationStore implementation
    # ------------------------------------------------------------------
    async def get_or_create_conversation(
        self,
        user_id: str,
        channel_id: str,
        conversation_id: Optional[str] = None,
    ) -> str:
        """Return *conversation_id* if it exists in the DB, otherwise create a
        new conversation row and return its fresh ID.

        If *conversation_id* is supplied but not found (e.g. after a restart
        with a clean in-memory DB), a new conversation is created transparently
        rather than raising an error.
        """
        async with self._session_factory() as session:
            if conversation_id:
                result = await session.execute(
                    select(_ConversationRow).where(
                        _ConversationRow.id == conversation_id
                    )
                )
                row = result.scalar_one_or_none()
                if row is not None:
                    # Resuming an existing conversation: bump last_activity so
                    # the cleanup task does not reap an active conversation.
                    row.last_activity = datetime.now(timezone.utc)
                    await session.commit()
                    logger.debug(
                        "Resumed existing conversation %s for user %s",
                        conversation_id,
                        user_id,
                    )
                    return row.id
                logger.warning(
                    "Conversation %s not found; creating a new one", conversation_id
                )
            # Either no ID was supplied or the supplied ID was stale — mint a
            # fresh conversation row.
            new_id = str(uuid.uuid4())
            session.add(
                _ConversationRow(
                    id=new_id,
                    user_id=user_id,
                    channel_id=channel_id,
                    created_at=datetime.now(timezone.utc),
                    last_activity=datetime.now(timezone.utc),
                )
            )
            await session.commit()
            logger.debug(
                "Created conversation %s for user %s in channel %s",
                new_id,
                user_id,
                channel_id,
            )
            return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Append a message to *conversation_id* and update last_activity.

        Returns the new message's UUID string.
        """
        message_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        async with self._session_factory() as session:
            session.add(
                _MessageRow(
                    id=message_id,
                    conversation_id=conversation_id,
                    content=content,
                    is_user=is_user,
                    parent_id=parent_id,
                    created_at=now,
                )
            )
            # Bump last_activity on the parent conversation.
            result = await session.execute(
                select(_ConversationRow).where(_ConversationRow.id == conversation_id)
            )
            conv = result.scalar_one_or_none()
            if conv is not None:
                conv.last_activity = now
            else:
                # Best effort: keep the message even without a parent row so
                # no user content is lost, but flag the inconsistency.
                logger.warning(
                    "add_message: conversation %s not found; message stored orphaned",
                    conversation_id,
                )
            # Single commit covers both the insert and the activity bump.
            await session.commit()
        logger.debug(
            "Added message %s to conversation %s (is_user=%s)",
            message_id,
            conversation_id,
            is_user,
        )
        return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict[str, str]]:
        """Return the most-recent *limit* messages in chronological order.

        The query fetches the newest rows (ORDER BY created_at DESC LIMIT n)
        then reverses the list so callers receive oldest-first ordering,
        which is what LLM APIs expect in the ``messages`` array.

        Returns a list of ``{"role": "user"|"assistant", "content": "..."}``
        dicts compatible with the OpenAI chat-completion format.
        """
        async with self._session_factory() as session:
            result = await session.execute(
                select(_MessageRow)
                .where(_MessageRow.conversation_id == conversation_id)
                .order_by(_MessageRow.created_at.desc())
                .limit(limit)
            )
            rows = result.scalars().all()
        # Reverse so the list is oldest → newest (chronological).
        history: list[dict[str, str]] = [
            {
                "role": "user" if row.is_user else "assistant",
                "content": row.content,
            }
            for row in reversed(rows)
        ]
        logger.debug(
            "Retrieved %d messages for conversation %s (limit=%d)",
            len(history),
            conversation_id,
            limit,
        )
        return history

    # ------------------------------------------------------------------
    # Housekeeping
    # ------------------------------------------------------------------
    async def cleanup_old_conversations(self, ttl_seconds: int = 1800) -> None:
        """Delete conversations (and their messages) older than *ttl_seconds*.

        Useful as a periodic background task to keep the database small.
        Messages are deleted first to satisfy the foreign-key constraint even
        in databases where cascade deletes are not configured.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=ttl_seconds)
        async with self._session_factory() as session:
            result = await session.execute(
                select(_ConversationRow).where(_ConversationRow.last_activity < cutoff)
            )
            old_rows = result.scalars().all()
            conv_ids = [row.id for row in old_rows]
            if not conv_ids:
                logger.debug("cleanup_old_conversations: nothing to remove")
                return
            # Children first, then parents — keeps FK integrity without
            # relying on ON DELETE CASCADE being configured.
            await session.execute(
                sa.delete(_MessageRow).where(_MessageRow.conversation_id.in_(conv_ids))
            )
            await session.execute(
                sa.delete(_ConversationRow).where(_ConversationRow.id.in_(conv_ids))
            )
            await session.commit()
        logger.info(
            "Cleaned up %d stale conversations (TTL=%ds)", len(conv_ids), ttl_seconds
        )

0
config/__init__.py Normal file
View File

163
config/container.py Normal file
View File

@ -0,0 +1,163 @@
"""Dependency wiring — the single place that constructs all adapters.
create_app() is the composition root for the production runtime. Tests use
the make_test_app() factory in tests/adapters/test_api.py instead (which
wires fakes directly into app.state, bypassing this module entirely).
Why a factory function instead of module-level globals
-------------------------------------------------------
- Makes the lifespan scope explicit: adapters are created inside the lifespan
context manager and torn down on exit.
- Avoids the singleton-state problems that plague import-time construction:
tests can call create_app() in isolation without shared state.
- Follows the hexagonal architecture principle that wiring is infrastructure,
not domain logic.
"""
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import FastAPI
from adapters.inbound.api import router
from adapters.outbound.chroma_rules import ChromaRuleRepository
from adapters.outbound.gitea_issues import GiteaIssueTracker
from adapters.outbound.openrouter import OpenRouterLLM
from adapters.outbound.sqlite_convos import SQLiteConversationStore
from config.settings import Settings
from domain.services import ChatService
logger = logging.getLogger(__name__)
def _make_lifespan(settings: Settings):
    """Return an async context manager that owns the adapter lifecycle.

    Accepts Settings so the lifespan closure captures a specific configuration
    instance rather than reading from module-level state. Everything before
    ``yield`` runs at application startup; everything after runs at shutdown.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[None]:
        # ------------------------------------------------------------------
        # Startup
        # ------------------------------------------------------------------
        # NOTE(review): print() mirrors the logger here, presumably so bare
        # console runs show progress without logging config — confirm intent.
        logger.info("Initialising Strat-Chatbot...")
        print("Initialising Strat-Chatbot...")
        # Ensure required directories exist
        settings.data_dir.mkdir(parents=True, exist_ok=True)
        settings.chroma_dir.mkdir(parents=True, exist_ok=True)
        # Vector store (synchronous adapter — no async init needed)
        chroma_repo = ChromaRuleRepository(
            persist_dir=settings.chroma_dir,
            embedding_model=settings.embedding_model,
        )
        rule_count = chroma_repo.count()
        print(f"ChromaDB ready at {settings.chroma_dir} ({rule_count} rules loaded)")
        # SQLite conversation store
        conv_store = SQLiteConversationStore(db_url=settings.db_url)
        await conv_store.init_db()
        print("SQLite conversation store initialised")
        # LLM adapter — only constructed when an API key is present.
        # With no key, llm stays None and POST /chat answers 503 (see router).
        llm: OpenRouterLLM | None = None
        if settings.openrouter_api_key:
            llm = OpenRouterLLM(
                api_key=settings.openrouter_api_key,
                model=settings.openrouter_model,
            )
            print(f"OpenRouter LLM ready (model: {settings.openrouter_model})")
        else:
            logger.warning(
                "OPENROUTER_API_KEY not set — LLM adapter disabled. "
                "POST /chat will return HTTP 503."
            )
            print(
                "WARNING: OPENROUTER_API_KEY not set — "
                "POST /chat will return HTTP 503."
            )
        # Gitea issue tracker — optional; ChatService tolerates issues=None.
        gitea: GiteaIssueTracker | None = None
        if settings.gitea_token:
            gitea = GiteaIssueTracker(
                token=settings.gitea_token,
                owner=settings.gitea_owner,
                repo=settings.gitea_repo,
                base_url=settings.gitea_base_url,
            )
            print(
                f"Gitea issue tracker ready "
                f"({settings.gitea_owner}/{settings.gitea_repo})"
            )
        # Compose the service from its ports
        service = ChatService(
            rules=chroma_repo,
            llm=llm,  # type: ignore[arg-type]  # None is handled at the router level
            conversations=conv_store,
            issues=gitea,
            top_k_rules=settings.top_k_rules,
        )
        # Expose via app.state for the router's Depends helpers
        app.state.chat_service = service
        app.state.rule_repository = chroma_repo
        app.state.config_snapshot = {
            "openrouter_model": settings.openrouter_model,
            "top_k_rules": settings.top_k_rules,
            "embedding_model": settings.embedding_model,
        }
        print("Strat-Chatbot ready!")
        logger.info("Strat-Chatbot ready")
        yield
        # ------------------------------------------------------------------
        # Shutdown — release HTTP connection pools
        # ------------------------------------------------------------------
        logger.info("Shutting down Strat-Chatbot...")
        print("Shutting down...")
        if llm is not None:
            await llm.close()
            logger.debug("OpenRouterLLM HTTP client closed")
        if gitea is not None:
            await gitea.close()
            logger.debug("GiteaIssueTracker HTTP client closed")
        logger.info("Shutdown complete")

    return lifespan
def create_app(settings: Settings | None = None) -> FastAPI:
    """Build and return the production FastAPI application.

    Args:
        settings: Pre-built Settings to wire in. When omitted (the common
            case), a fresh Settings() is constructed, which pulls values
            from environment variables and the .env file automatically.

    Returns:
        A fully-wired FastAPI app, ready to be served by uvicorn, whose
        lifespan owns adapter construction and teardown.
    """
    resolved = Settings() if settings is None else settings
    application = FastAPI(
        title="Strat-Chatbot",
        description="Strat-O-Matic rules Q&A API",
        version="0.1.0",
        lifespan=_make_lifespan(resolved),
    )
    application.include_router(router)
    return application

74
config/settings.py Normal file
View File

@ -0,0 +1,74 @@
"""Application settings — Pydantic v2 style, no module-level singleton.
The container (config/container.py) instantiates Settings once at startup
and passes it down to adapters. This keeps tests free of singleton state.
"""
from pathlib import Path
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """All runtime configuration, sourced from environment variables or a .env file.

    Fields use explicit ``alias=`` environment-variable names so the variable
    names are immediately visible and grep-able without needing to know
    Pydantic's casing rules.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        # Allow unknown env vars — avoids breakage when the .env file has
        # variables that belong to other services (Discord bot, scripts, etc.).
        extra="ignore",
    )

    # ------------------------------------------------------------------
    # OpenRouter / LLM
    # ------------------------------------------------------------------
    # Empty key means "LLM disabled"; the container checks truthiness.
    openrouter_api_key: str = Field(default="", alias="OPENROUTER_API_KEY")
    openrouter_model: str = Field(
        default="stepfun/step-3.5-flash:free", alias="OPENROUTER_MODEL"
    )

    # ------------------------------------------------------------------
    # Discord
    # ------------------------------------------------------------------
    discord_bot_token: str = Field(default="", alias="DISCORD_BOT_TOKEN")
    discord_guild_id: Optional[str] = Field(default=None, alias="DISCORD_GUILD_ID")

    # ------------------------------------------------------------------
    # Gitea issue tracker
    # ------------------------------------------------------------------
    # Empty token disables the issue-tracker adapter in the container.
    gitea_token: str = Field(default="", alias="GITEA_TOKEN")
    gitea_owner: str = Field(default="cal", alias="GITEA_OWNER")
    gitea_repo: str = Field(default="strat-chatbot", alias="GITEA_REPO")
    gitea_base_url: str = Field(
        default="https://git.manticorum.com/api/v1", alias="GITEA_BASE_URL"
    )

    # ------------------------------------------------------------------
    # File-system paths
    # ------------------------------------------------------------------
    data_dir: Path = Field(default=Path("./data"), alias="DATA_DIR")
    rules_dir: Path = Field(default=Path("./data/rules"), alias="RULES_DIR")
    chroma_dir: Path = Field(default=Path("./data/chroma"), alias="CHROMA_DIR")

    # ------------------------------------------------------------------
    # Database
    # ------------------------------------------------------------------
    # Async SQLAlchemy URL; the default uses the aiosqlite driver.
    db_url: str = Field(
        default="sqlite+aiosqlite:///./data/conversations.db", alias="DB_URL"
    )

    # ------------------------------------------------------------------
    # Conversation / retrieval tuning
    # ------------------------------------------------------------------
    # Seconds of inactivity before a conversation may be cleaned up (30 min).
    conversation_ttl: int = Field(default=1800, alias="CONVERSATION_TTL")
    top_k_rules: int = Field(default=10, alias="TOP_K_RULES")
    embedding_model: str = Field(
        default="sentence-transformers/all-MiniLM-L6-v2", alias="EMBEDDING_MODEL"
    )

0
domain/__init__.py Normal file
View File

92
domain/models.py Normal file
View File

@ -0,0 +1,92 @@
"""Pure domain models — no framework imports (no FastAPI, SQLAlchemy, httpx, etc.)."""
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional
@dataclass
class RuleDocument:
    """A single rules-knowledge-base entry plus its descriptive metadata."""

    rule_id: str
    title: str
    section: str
    content: str
    source_file: str
    parent_rule: Optional[str] = None
    page_ref: Optional[str] = None

    def to_metadata(self) -> dict[str, str]:
        """Flatten to a str→str dict for vector-store metadata.

        Optional fields are coerced to "" so the mapping never contains None.
        """
        parent = self.parent_rule or ""
        page = self.page_ref or ""
        return {
            "rule_id": self.rule_id,
            "title": self.title,
            "section": self.section,
            "parent_rule": parent,
            "page_ref": page,
            "source_file": self.source_file,
        }
@dataclass
class RuleSearchResult:
    """A semantic-search hit: rule text plus a normalized similarity score."""

    rule_id: str
    title: str
    content: str
    section: str
    similarity: float

    def __post_init__(self):
        # Scores must land in [0, 1]; NaN fails the chained comparison and
        # therefore also raises.
        in_range = 0.0 <= self.similarity <= 1.0
        if not in_range:
            raise ValueError(
                f"similarity must be between 0.0 and 1.0, got {self.similarity}"
            )
@dataclass
class Conversation:
    """One chat session, keyed by user and channel, with activity timestamps."""

    id: str
    user_id: str
    channel_id: str
    # Both timestamps are timezone-aware UTC, assigned at construction time.
    created_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    last_activity: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
@dataclass
class ChatMessage:
    """One turn in a conversation, authored by either the user or the bot."""

    id: str
    conversation_id: str
    content: str
    is_user: bool
    # Optional threading pointer to the message this one replies to.
    parent_id: Optional[str] = None
    # Timezone-aware UTC timestamp, assigned at construction time.
    created_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
@dataclass
class LLMResponse:
    """Parsed LLM output: answer text plus citation and quality signals."""

    answer: str
    # Rule IDs the model referenced; each instance gets its own fresh list.
    cited_rules: list[str] = field(default_factory=list)
    # Self-reported confidence; defaults to a neutral 0.5.
    confidence: float = 0.5
    # True when the model flags the question for human follow-up.
    needs_human: bool = False
@dataclass
class ChatResult:
    """What ChatService hands back to inbound adapters after one Q&A turn."""

    response: str
    conversation_id: str
    # ID of the stored assistant message.
    message_id: str
    cited_rules: list[str]
    confidence: float
    needs_human: bool
    # ID of the user turn the assistant reply answers, when threading applies.
    parent_message_id: Optional[str] = None

79
domain/ports.py Normal file
View File

@ -0,0 +1,79 @@
"""Port interfaces — abstract contracts the domain needs from the outside world.
No framework imports allowed. Adapters implement these ABCs.
"""
from abc import ABC, abstractmethod
from typing import Optional
from .models import RuleDocument, RuleSearchResult, LLMResponse
class RuleRepository(ABC):
    """Port for storing and semantically searching the rules knowledge base."""

    @abstractmethod
    def add_documents(self, docs: list[RuleDocument]) -> None:
        """Index the given documents into the store."""

    @abstractmethod
    def search(
        self, query: str, top_k: int = 10, section_filter: Optional[str] = None
    ) -> list[RuleSearchResult]:
        """Return up to *top_k* rules most similar to *query*."""

    @abstractmethod
    def count(self) -> int:
        """Return the number of indexed documents."""

    @abstractmethod
    def clear_all(self) -> None:
        """Remove every document from the store."""

    @abstractmethod
    def get_stats(self) -> dict:
        """Return implementation-defined statistics about the store."""
class LLMPort(ABC):
    """Port for turning a question plus retrieved rules into an answer."""

    @abstractmethod
    async def generate_response(
        self,
        question: str,
        rules: list[RuleSearchResult],
        conversation_history: Optional[list[dict[str, str]]] = None,
    ) -> LLMResponse:
        """Produce a structured answer grounded in *rules* and optional history."""
class ConversationStore(ABC):
    """Port for persisting conversations and their messages."""

    @abstractmethod
    async def get_or_create_conversation(
        self, user_id: str, channel_id: str, conversation_id: Optional[str] = None
    ) -> str:
        """Return an existing conversation id, or create and return a new one."""

    @abstractmethod
    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Append a message to a conversation and return its id."""

    @abstractmethod
    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict[str, str]]:
        """Return up to *limit* messages as role/content dicts, oldest first."""
class IssueTracker(ABC):
    """Port for escalating unanswerable questions to a human-reviewed tracker."""

    @abstractmethod
    async def create_unanswered_issue(
        self,
        question: str,
        user_id: str,
        channel_id: str,
        attempted_rules: list[str],
        conversation_id: str,
    ) -> str:
        """File an issue for *question* and return the tracker's identifier/URL."""

106
domain/services.py Normal file
View File

@ -0,0 +1,106 @@
"""Domain services — core business logic with no framework dependencies.
ChatService orchestrates the Q&A flow using only domain ports.
"""
import logging
from typing import Optional
from .models import ChatResult
from .ports import RuleRepository, LLMPort, ConversationStore, IssueTracker
logger = logging.getLogger(__name__)
CONFIDENCE_THRESHOLD = 0.4
class ChatService:
    """Orchestrates the rules Q&A use case.

    All external dependencies are injected via ports — this class has zero
    knowledge of ChromaDB, OpenRouter, SQLite, or Gitea.
    """

    def __init__(
        self,
        rules: RuleRepository,
        llm: LLMPort,
        conversations: ConversationStore,
        issues: Optional[IssueTracker] = None,
        top_k_rules: int = 10,
    ):
        """Inject the ports the service orchestrates.

        Args:
            rules: Semantic search over the rules knowledge base.
            llm: Generates answers from a question plus retrieved rules.
            conversations: Persistence for conversations and messages.
            issues: Optional escalation tracker; when None, low-confidence
                answers are returned without filing an issue.
            top_k_rules: Number of rules retrieved per question.
        """
        self.rules = rules
        self.llm = llm
        self.conversations = conversations
        self.issues = issues
        self.top_k_rules = top_k_rules

    async def answer_question(
        self,
        message: str,
        user_id: str,
        channel_id: str,
        conversation_id: Optional[str] = None,
        parent_message_id: Optional[str] = None,
    ) -> ChatResult:
        """Full Q&A flow: search rules → get history → call LLM → persist → maybe create issue.

        Args:
            message: The user's question text.
            user_id: Identifies the asker (used for conversation lookup and issues).
            channel_id: Channel the question arrived on.
            conversation_id: Resume this conversation when provided.
            parent_message_id: Optional threading parent for the stored user turn.

        Returns:
            ChatResult whose ``parent_message_id`` is the *user turn's* id,
            so inbound adapters can thread the assistant reply under it.
        """
        # Get or create conversation
        conv_id = await self.conversations.get_or_create_conversation(
            user_id=user_id,
            channel_id=channel_id,
            conversation_id=conversation_id,
        )
        # Save user message
        user_msg_id = await self.conversations.add_message(
            conversation_id=conv_id,
            content=message,
            is_user=True,
            parent_id=parent_message_id,
        )
        # Search for relevant rules
        search_results = self.rules.search(query=message, top_k=self.top_k_rules)
        # Get conversation history for context
        # NOTE(review): the user turn was saved above, so `history` already
        # contains `message`; the LLM adapter also receives the question
        # separately — confirm it de-duplicates or tolerates the repeat.
        history = await self.conversations.get_conversation_history(conv_id, limit=10)
        # Generate response from LLM
        llm_response = await self.llm.generate_response(
            question=message,
            rules=search_results,
            conversation_history=history,
        )
        # Save assistant message
        assistant_msg_id = await self.conversations.add_message(
            conversation_id=conv_id,
            content=llm_response.answer,
            is_user=False,
            parent_id=user_msg_id,
        )
        # Create issue if confidence is low or human review needed
        if self.issues and (
            llm_response.needs_human or llm_response.confidence < CONFIDENCE_THRESHOLD
        ):
            try:
                await self.issues.create_unanswered_issue(
                    question=message,
                    user_id=user_id,
                    channel_id=channel_id,
                    attempted_rules=[r.rule_id for r in search_results],
                    conversation_id=conv_id,
                )
            except Exception:
                # Escalation is best-effort: a tracker outage must never
                # fail the user's answer.
                logger.exception("Failed to create issue for unanswered question")
        return ChatResult(
            response=llm_response.answer,
            conversation_id=conv_id,
            message_id=assistant_msg_id,
            parent_message_id=user_msg_id,
            cited_rules=llm_response.cited_rules,
            confidence=llm_response.confidence,
            needs_human=llm_response.needs_human,
        )

27
main.py Normal file
View File

@ -0,0 +1,27 @@
"""Application entry point for the hexagonal-architecture refactor.
Run directly:
uv run python main.py
Or via uvicorn:
uv run uvicorn main:app --reload --host 0.0.0.0 --port 8000
The old entry point (app/main.py) remains in place for reference until the
migration is complete.
"""
import uvicorn
from config.container import create_app
# create_app() reads Settings from env / .env and wires all adapters.
# The lifespan (startup/shutdown) is attached to the returned FastAPI instance.
app = create_app()

if __name__ == "__main__":
    # Dev entry point: listens on all interfaces, port 8000, reload enabled.
    # Production deployments should invoke uvicorn against "main:app" directly.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )

View File

@ -12,7 +12,7 @@ dependencies = [
"openai>=1.0.0",
"python-dotenv>=1.0.0",
"sqlalchemy>=2.0.0",
"aiosqlite>=2.0.0",
"aiosqlite>=0.19.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"httpx>=0.27.0",
@ -30,6 +30,9 @@ dev = [
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["domain", "adapters", "config", "app"]
[tool.black]
line-length = 88
target-version = ['py311']

7
pyrightconfig.json Normal file
View File

@ -0,0 +1,7 @@
{
"include": ["domain", "adapters", "config", "tests"],
"extraPaths": ["."],
"pythonVersion": "3.11",
"typeCheckingMode": "basic",
"reportMissingImports": "warning"
}

0
tests/__init__.py Normal file
View File

View File

438
tests/adapters/test_api.py Normal file
View File

@ -0,0 +1,438 @@
"""Tests for the FastAPI inbound adapter (adapters/inbound/api.py).
Strategy
--------
We build a minimal FastAPI app in each fixture by wiring fakes into app.state,
then drive it with httpx.AsyncClient using ASGITransport so no real HTTP server
is needed. This means:
- No real ChromaDB, SQLite, OpenRouter, or Gitea calls.
- Tests are fast, deterministic, and isolated.
- The test app mirrors exactly what the production container does the only
difference is which objects sit in app.state.
What is tested
--------------
- POST /chat returns 200 and a well-formed ChatResponse for a normal message.
- POST /chat stores the conversation and returns a stable conversation_id on a
second call with the same conversation_id (conversation continuation).
- GET /health returns {"status": "healthy", ...} with rule counts.
- GET /stats returns a knowledge_base sub-dict and a config sub-dict.
- POST /chat with missing required fields returns HTTP 422 (Unprocessable Entity).
- POST /chat with a message that exceeds 4000 characters returns HTTP 422.
- POST /chat with a user_id that exceeds 64 characters returns HTTP 422.
- POST /chat when ChatService.answer_question raises returns HTTP 500.
"""
from __future__ import annotations
import pytest
import httpx
from fastapi import FastAPI
from httpx import ASGITransport
from domain.models import RuleDocument
from domain.services import ChatService
from adapters.inbound.api import router
from tests.fakes import (
FakeRuleRepository,
FakeLLM,
FakeConversationStore,
FakeIssueTracker,
)
# ---------------------------------------------------------------------------
# Test app factory
# ---------------------------------------------------------------------------
def make_test_app(
    *,
    rules: FakeRuleRepository | None = None,
    llm: FakeLLM | None = None,
    conversations: FakeConversationStore | None = None,
    issues: FakeIssueTracker | None = None,
    top_k_rules: int = 5,
) -> FastAPI:
    """Assemble an isolated FastAPI app whose ports are in-memory fakes.

    Mirrors the wiring done by config/container.py, minus the lifespan: every
    adapter is a fake, so no external service is ever touched. Pass pre-built
    fakes to share or inspect state across requests; omit them for a fresh,
    fully isolated set per call.
    """
    rule_repo = rules or FakeRuleRepository()
    llm_port = llm or FakeLLM()
    convo_store = conversations or FakeConversationStore()
    tracker = issues or FakeIssueTracker()

    app = FastAPI()
    app.include_router(router)
    app.state.chat_service = ChatService(
        rules=rule_repo,
        llm=llm_port,
        conversations=convo_store,
        issues=tracker,
        top_k_rules=top_k_rules,
    )
    app.state.rule_repository = rule_repo
    app.state.config_snapshot = {
        "openrouter_model": "fake-model",
        "top_k_rules": top_k_rules,
        "embedding_model": "fake-embeddings",
    }
    return app
# ---------------------------------------------------------------------------
# Shared fixture: an async client backed by the test app
# ---------------------------------------------------------------------------
@pytest.fixture()
async def client() -> httpx.AsyncClient:
    """Yield an AsyncClient bound to a brand-new test app.

    A fresh app — and therefore a fresh set of fakes — per test function
    keeps every test fully isolated from its neighbours.
    """
    transport = ASGITransport(app=make_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
# ---------------------------------------------------------------------------
# POST /chat — successful response
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_returns_200_with_valid_payload(client: httpx.AsyncClient):
    """POST /chat with a complete payload yields HTTP 200 and every
    ChatResponse field with the expected type.

    Checking each field individually means a structural drift in ChatResult
    or ChatResponse fails loudly instead of silently producing a wrong value.
    """
    resp = await client.post(
        "/chat",
        json={
            "message": "How many strikes to strike out?",
            "user_id": "user-001",
            "channel_id": "channel-001",
        },
    )
    assert resp.status_code == 200
    body = resp.json()
    expected_types = {
        "response": str,
        "conversation_id": str,
        "message_id": str,
        "cited_rules": list,
        "confidence": float,
        "needs_human": bool,
    }
    for key, expected in expected_types.items():
        assert isinstance(body[key], expected)
    assert len(body["response"]) > 0
@pytest.mark.asyncio
async def test_chat_uses_rules_when_available():
    """Seeding the fake repository with a matching rule exercises the full
    retrieval path through the inbound adapter: the FakeLLM sees the hit and
    answers with citations and above-baseline confidence.
    """
    seeded_repo = FakeRuleRepository()
    seeded_repo.add_documents(
        [
            RuleDocument(
                rule_id="1.1",
                title="Batting Order",
                section="Batting",
                content="A batter gets three strikes before striking out.",
                source_file="rules.pdf",
            )
        ]
    )
    app = make_test_app(rules=seeded_repo)
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.post(
            "/chat",
            json={
                "message": "How many strikes before a batter strikes out?",
                "user_id": "user-abc",
                "channel_id": "ch-xyz",
            },
        )
    assert resp.status_code == 200
    body = resp.json()
    # FakeLLM returns cited_rules when rules are found
    assert len(body["cited_rules"]) > 0
    assert body["confidence"] > 0.5
# ---------------------------------------------------------------------------
# POST /chat — conversation continuation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_continues_existing_conversation():
    """Echoing conversation_id back on a second request must resume the same
    conversation rather than opening a new one.

    Two requests: the first creates the conversation; the second supplies its
    id and must get the same id back. This pins down the behaviour the router
    expects from FakeConversationStore (and the real SQLite adapter alike).
    """
    store = FakeConversationStore()
    app = make_test_app(conversations=store)
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        # First turn — no conversation_id
        first = await ac.post(
            "/chat",
            json={
                "message": "First question",
                "user_id": "user-42",
                "channel_id": "ch-1",
            },
        )
        assert first.status_code == 200
        conv_id = first.json()["conversation_id"]
        # Second turn — same conversation
        follow_up = await ac.post(
            "/chat",
            json={
                "message": "Follow-up question",
                "user_id": "user-42",
                "channel_id": "ch-1",
                "conversation_id": conv_id,
            },
        )
    assert follow_up.status_code == 200
    assert follow_up.json()["conversation_id"] == conv_id
# ---------------------------------------------------------------------------
# GET /health
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_health_returns_healthy_status(client: httpx.AsyncClient):
    """GET /health reports status "healthy" plus an integer rule count and a
    sections dict. The default FakeRuleRepository starts empty."""
    resp = await client.get("/health")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["status"] == "healthy"
    assert isinstance(payload["rules_count"], int)
    assert isinstance(payload["sections"], dict)
@pytest.mark.asyncio
async def test_health_reflects_loaded_rules():
    """GET /health must show the live rule count after documents are added,
    proving the router holds a live reference to the repository rather than
    a snapshot taken at startup."""
    seeded_repo = FakeRuleRepository()
    seeded_repo.add_documents(
        [
            RuleDocument(
                rule_id="2.1",
                title="Pitching",
                section="Pitching",
                content="The pitcher throws the ball.",
                source_file="rules.pdf",
            )
        ]
    )
    app = make_test_app(rules=seeded_repo)
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.get("/health")
    assert resp.status_code == 200
    assert resp.json()["rules_count"] == 1
# ---------------------------------------------------------------------------
# GET /stats
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stats_returns_knowledge_base_and_config(client: httpx.AsyncClient):
    """GET /stats exposes both the repository stats (knowledge_base) and the
    container's config snapshot, so an operator can confirm which model and
    retrieval settings are active."""
    resp = await client.get("/stats")
    assert resp.status_code == 200
    payload = resp.json()
    for section in ("knowledge_base", "config"):
        assert section in payload
    assert "total_rules" in payload["knowledge_base"]
# ---------------------------------------------------------------------------
# POST /chat — validation errors (HTTP 422)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_missing_message_returns_422(client: httpx.AsyncClient):
    """Leaving out the required 'message' field is a client error: Pydantic
    validation must answer 422 Unprocessable Entity, never a 500."""
    resp = await client.post("/chat", json={"user_id": "u1", "channel_id": "ch1"})
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_chat_missing_user_id_returns_422(client: httpx.AsyncClient):
    """A payload without 'user_id' must be rejected with HTTP 422."""
    resp = await client.post("/chat", json={"message": "Hello", "channel_id": "ch1"})
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_chat_missing_channel_id_returns_422(client: httpx.AsyncClient):
    """A payload without 'channel_id' must be rejected with HTTP 422."""
    resp = await client.post("/chat", json={"message": "Hello", "user_id": "u1"})
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_chat_message_too_long_returns_422(client: httpx.AsyncClient):
    """4001 characters breaches ChatRequest.message's max_length=4000 cap, so
    the request dies in field validation before reaching the service layer."""
    oversized = "x" * 4001
    resp = await client.post(
        "/chat",
        json={"message": oversized, "user_id": "u1", "channel_id": "ch1"},
    )
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_chat_user_id_too_long_returns_422(client: httpx.AsyncClient):
    """65 characters breaches the 64-char cap on user_id. Discord snowflakes
    top out near 20 digits, so 64 is already a generous limit that still
    keeps runaway strings away from the database layer."""
    oversized = "u" * 65
    resp = await client.post(
        "/chat",
        json={"message": "Hello", "user_id": oversized, "channel_id": "ch1"},
    )
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_chat_channel_id_too_long_returns_422(client: httpx.AsyncClient):
    """65 characters breaches the 64-char cap on channel_id."""
    oversized = "c" * 65
    resp = await client.post(
        "/chat",
        json={"message": "Hello", "user_id": "u1", "channel_id": oversized},
    )
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_chat_empty_message_returns_422(client: httpx.AsyncClient):
    """The empty string fails min_length=1 on 'message' — an empty prompt
    must never reach the LLM, where it would only waste tokens."""
    resp = await client.post(
        "/chat", json={"message": "", "user_id": "u1", "channel_id": "ch1"}
    )
    assert resp.status_code == 422
# ---------------------------------------------------------------------------
# POST /chat — service-layer exception bubbles up as 500
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_service_exception_returns_500():
    """An unexpected ChatService failure must surface as HTTP 500 — the
    router catches it rather than letting the exception kill the worker.
    FakeLLM's force_error hook injects the failure deterministically."""
    app = make_test_app(llm=FakeLLM(force_error=RuntimeError("LLM exploded")))
    transport = ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.post(
            "/chat",
            json={"message": "Hello", "user_id": "u1", "channel_id": "ch1"},
        )
    assert resp.status_code == 500
    assert "LLM exploded" in resp.json()["detail"]
# ---------------------------------------------------------------------------
# POST /chat — parent_message_id thread reply
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_chat_with_parent_message_id_returns_200(client: httpx.AsyncClient):
    """The optional parent_message_id passes through without error.

    The response's parent_message_id is the *user turn's* id (the service's
    threading model), not the value supplied in the request, so this test
    only asserts presence; the service-layer tests cover the wiring.
    """
    resp = await client.post(
        "/chat",
        json={
            "message": "Thread reply",
            "user_id": "u1",
            "channel_id": "ch1",
            "parent_message_id": "some-parent-uuid",
        },
    )
    assert resp.status_code == 200
    assert resp.json()["parent_message_id"] is not None

View File

@ -0,0 +1,403 @@
"""Tests for the ChromaRuleRepository outbound adapter.
Uses ChromaDB's ephemeral (in-memory) client so no files are written to disk
and no cleanup is needed between runs.
All tests are marked ``slow`` because constructing a SentenceTransformer
downloads a ~100 MB model on a cold cache. Skip the entire module when the
sentence-transformers package is absent so the rest of the test suite still
passes in a minimal CI environment.
"""
from __future__ import annotations
import pytest
# ---------------------------------------------------------------------------
# Optional-import guard: skip the whole module if sentence-transformers is
# not installed (avoids a hard ImportError in minimal environments).
# ---------------------------------------------------------------------------
sentence_transformers = pytest.importorskip(
"sentence_transformers",
reason="sentence-transformers not installed; skipping ChromaDB adapter tests",
)
from unittest.mock import MagicMock, patch # noqa: E402
import chromadb # noqa: E402 (after importorskip guard)
from adapters.outbound.chroma_rules import ChromaRuleRepository # noqa: E402
from domain.models import RuleDocument, RuleSearchResult # noqa: E402
from domain.ports import RuleRepository # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
def _make_doc(
    rule_id: str = "1.0",
    title: str = "Test Rule",
    section: str = "Section 1",
    content: str = "This is the content of the rule.",
    source_file: str = "rules/test.md",
    parent_rule: str | None = None,
    page_ref: str | None = None,
) -> RuleDocument:
    """Build a RuleDocument fixture; every field is overridable per-test."""
    field_values = {
        "rule_id": rule_id,
        "title": title,
        "section": section,
        "content": content,
        "source_file": source_file,
        "parent_rule": parent_rule,
        "page_ref": page_ref,
    }
    return RuleDocument(**field_values)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def embedding_model_mock():
    """
    Stand-in for SentenceTransformer so tests never download the real model.

    ``encode`` always yields the same 32-dimensional vector (one per input
    when given a batch). Because every stored document therefore has cosine
    distance 0 to every query (similarity == 1), assertions can check
    similarity bounds without caring about ranking.
    """
    constant_vector = [0.1] * 32

    def _encode(texts, **kwargs):
        # A bare string means single-document encode; otherwise it's a batch.
        if isinstance(texts, str):
            return constant_vector
        return [constant_vector for _ in texts]

    fake_model = MagicMock()
    fake_model.encode.side_effect = _encode
    return fake_model
@pytest.fixture()
def repo(embedding_model_mock, tmp_path):
    """
    ChromaRuleRepository backed by an ephemeral (in-memory) ChromaDB client.

    We patch, inside the adapter module's namespace:
    - ``chromadb.PersistentClient`` to return a pre-built ``EphemeralClient``,
      so nothing is written under ``tmp_path``.
    - ``SentenceTransformer`` to return ``embedding_model_mock``, so no model
      download occurs.

    ``tmp_path`` is still passed to satisfy the constructor signature even
    though the ephemeral client ignores it.
    """
    ephemeral_client = chromadb.EphemeralClient()
    with (
        patch(
            "adapters.outbound.chroma_rules.chromadb.PersistentClient",
            return_value=ephemeral_client,
        ),
        patch(
            "adapters.outbound.chroma_rules.SentenceTransformer",
            return_value=embedding_model_mock,
        ),
    ):
        instance = ChromaRuleRepository(
            persist_dir=tmp_path / "chroma",
            embedding_model=EMBEDDING_MODEL,
        )
        # The patches only need to be active during construction; yielding
        # hands the fully wired repository to the test.
        yield instance
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.slow
class TestChromaRuleRepositoryContract:
    """Verify that ChromaRuleRepository satisfies the RuleRepository port."""

    def test_is_rule_repository_subclass(self):
        """The adapter must be a concrete implementation of the port ABC."""
        implements_port = issubclass(ChromaRuleRepository, RuleRepository)
        assert implements_port
@pytest.mark.slow
class TestAddDocuments:
    """Tests for add_documents()."""

    def test_add_single_document_increments_count(self, repo):
        """Storing one RuleDocument must leave count() at exactly 1.

        Confirms the adapter maps the domain model onto ChromaDB's add() API.
        """
        repo.add_documents(
            [_make_doc(rule_id="1.1", content="Single rule content.")]
        )
        assert repo.count() == 1

    def test_add_batch_all_stored(self, repo):
        """A batch of N documents must yield count() == N.

        Validates that batch encoding and bulk add() work end-to-end.
        """
        batch = []
        for i in range(5):
            batch.append(
                _make_doc(rule_id=f"2.{i}", content=f"Batch rule number {i}.")
            )
        repo.add_documents(batch)
        assert repo.count() == 5

    def test_add_empty_list_is_noop(self, repo):
        """add_documents([]) must neither raise nor change the count."""
        repo.add_documents([])
        assert repo.count() == 0

    def test_add_document_with_optional_fields(self, repo):
        """Documents carrying parent_rule and page_ref must be stored without
        error; optional fields must serialise via to_metadata()."""
        enriched = _make_doc(rule_id="3.1", parent_rule="3.0", page_ref="p.42")
        repo.add_documents([enriched])
        assert repo.count() == 1
@pytest.mark.slow
class TestSearch:
    """Tests for search()."""

    def test_search_returns_results(self, repo):
        """A populated collection must yield a non-empty list of
        RuleSearchResult objects."""
        repo.add_documents(
            [_make_doc(rule_id="10.1", content="A searchable rule about batting.")]
        )
        hits = repo.search("batting rules", top_k=5)
        assert len(hits) >= 1
        for hit in hits:
            assert isinstance(hit, RuleSearchResult)

    def test_search_result_fields_populated(self, repo):
        """rule_id, title, section, and content must round-trip unchanged
        through ChromaDB's metadata storage."""
        repo.add_documents(
            [
                _make_doc(
                    rule_id="11.1",
                    title="Fielding Rule",
                    section="Defense",
                    content="Rules for fielding plays.",
                )
            ]
        )
        hits = repo.search("fielding", top_k=1)
        assert len(hits) >= 1
        top = hits[0]
        assert (top.rule_id, top.title, top.section, top.content) == (
            "11.1",
            "Fielding Rule",
            "Defense",
            "Rules for fielding plays.",
        )

    def test_search_with_section_filter(self, repo):
        """section_filter must exclude documents from other sections even when
        they would otherwise score highly."""
        repo.add_documents(
            [
                _make_doc(rule_id="20.1", section="Pitching", content="Pitching rules."),
                _make_doc(rule_id="20.2", section="Batting", content="Batting rules."),
            ]
        )
        hits = repo.search("rules", top_k=10, section_filter="Pitching")
        assert len(hits) >= 1
        for hit in hits:
            assert hit.section == "Pitching"

    def test_search_top_k_respected(self, repo):
        """Result count must never exceed top_k even with more docs stored."""
        repo.add_documents(
            [
                _make_doc(rule_id=f"30.{i}", content=f"Rule number {i}.")
                for i in range(10)
            ]
        )
        assert len(repo.search("rule", top_k=3)) <= 3

    def test_search_empty_collection_returns_empty_list(self, repo):
        """Searching an empty collection must return [] without raising —
        ChromaDB errors when n_results exceeds the collection size, so the
        adapter must guard against that."""
        assert repo.search("anything", top_k=5) == []
@pytest.mark.slow
class TestSimilarityClamping:
    """Tests for the similarity score clamping behaviour."""

    def test_similarity_within_bounds(self, repo):
        """
        Every RuleSearchResult returned by search() must have a similarity
        value in [0.0, 1.0]. ChromaDB cosine distance can technically exceed
        1 for near-opposite vectors; the adapter must clamp the value before
        constructing RuleSearchResult (which validates the range in __post_init__).
        """
        docs = [_make_doc(rule_id="40.1", content="Content for similarity check.")]
        repo.add_documents(docs)
        results = repo.search("similarity check", top_k=5)
        # Assert the full contract on every hit, not just the top one.
        for r in results:
            assert (
                0.0 <= r.similarity <= 1.0
            ), f"similarity {r.similarity} is outside [0.0, 1.0]"

    def test_similarity_clamped_when_distance_exceeds_one(
        self, repo, embedding_model_mock
    ):
        """
        When ChromaDB returns a cosine distance > 1 (e.g. 1.5), the formula
        ``max(0.0, min(1.0, 1 - distance))`` must produce 0.0 rather than a
        negative value, preventing the RuleSearchResult validator from raising.

        We simulate this by patching the collection's query() to return a
        synthetic distance of 1.5.

        NOTE(review): ``embedding_model_mock`` is requested but never used in
        this body — the ``repo`` fixture already depends on it; confirm it can
        be dropped from the signature.
        """
        doc = _make_doc(rule_id="50.1", content="Edge case content.")
        repo.add_documents([doc])
        # Synthetic payload shaped like a real ChromaDB query() response:
        # outer lists are per-query, inner lists are per-hit.
        raw_results = {
            "documents": [["Edge case content."]],
            "metadatas": [
                [
                    {
                        "rule_id": "50.1",
                        "title": "Test Rule",
                        "section": "Section 1",
                        "parent_rule": "",
                        "page_ref": "",
                        "source_file": "rules/test.md",
                    }
                ]
            ],
            "distances": [[1.5]],  # distance > 1 → naive similarity would be negative
        }
        collection = repo._get_collection()
        with patch.object(collection, "query", return_value=raw_results):
            results = repo.search("edge case", top_k=1)
        assert len(results) == 1
        assert results[0].similarity == 0.0
@pytest.mark.slow
class TestCount:
    """Tests for count()."""

    def test_count_empty(self, repo):
        """A fresh collection reports zero documents."""
        assert repo.count() == 0

    def test_count_after_add(self, repo):
        """count() tracks the exact number of documents added."""
        repo.add_documents([_make_doc(rule_id=f"60.{i}") for i in range(3)])
        assert repo.count() == 3
@pytest.mark.slow
class TestClearAll:
    """Tests for clear_all()."""

    def test_clear_all_resets_count_to_zero(self, repo):
        """clear_all() after adds must bring count() back to 0.

        Also verifies the collection is recreated (not left deleted) so
        subsequent operations succeed without error.
        """
        repo.add_documents([_make_doc(rule_id=f"70.{i}") for i in range(4)])
        assert repo.count() == 4
        repo.clear_all()
        assert repo.count() == 0

    def test_operations_work_after_clear(self, repo):
        """The adapter must stay usable after clear_all() — the internal
        collection must be recreated so add_documents() and search() do not
        raise."""
        repo.add_documents([_make_doc(rule_id="80.1")])
        repo.clear_all()
        repo.add_documents(
            [_make_doc(rule_id="80.2", content="Post-clear document.")]
        )
        assert repo.count() == 1
@pytest.mark.slow
class TestGetStats:
    """Tests for get_stats()."""

    def test_get_stats_returns_dict(self, repo):
        """Structural sanity check: get_stats() yields a dict."""
        assert isinstance(repo.get_stats(), dict)

    def test_get_stats_contains_required_keys(self, repo):
        """get_stats() must include at minimum ``total_rules`` (int count),
        ``sections`` (per-section counts), and ``persist_directory`` (the
        path the client was configured with)."""
        repo.add_documents(
            [
                _make_doc(rule_id="90.1", section="Alpha"),
                _make_doc(rule_id="90.2", section="Alpha"),
                _make_doc(rule_id="90.3", section="Beta"),
            ]
        )
        stats = repo.get_stats()
        for required_key in ("total_rules", "sections", "persist_directory"):
            assert required_key in stats
        assert stats["total_rules"] == 3
        assert stats["sections"]["Alpha"] == 2
        assert stats["sections"]["Beta"] == 1

View File

@ -0,0 +1,411 @@
"""Tests for GiteaIssueTracker — the outbound adapter for the IssueTracker port.
Strategy: use httpx.MockTransport to intercept HTTP calls without a live Gitea
server. This exercises the real adapter code (headers, URL construction, JSON
serialisation, error handling) without any external network dependency.
We import GiteaIssueTracker from adapters.outbound.gitea_issues and verify it
against the IssueTracker ABC from domain.ports — confirming the adapter truly
satisfies the port contract.
"""
import json
import pytest
import httpx
from domain.ports import IssueTracker
from adapters.outbound.gitea_issues import GiteaIssueTracker
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_issue_response(
issue_number: int = 1,
title: str = "Test issue",
html_url: str = "https://gitea.example.com/owner/repo/issues/1",
) -> dict:
"""Return a minimal Gitea issue API response payload."""
return {
"id": issue_number,
"number": issue_number,
"title": title,
"html_url": html_url,
"state": "open",
}
class _MockTransport(httpx.AsyncBaseTransport):
    """Configurable httpx transport that replays a canned JSON response.

    The most recent outgoing request is captured in ``last_request`` so tests
    can assert on method, URL, headers, and body after the call.
    """

    def __init__(self, status_code: int = 201, body: dict | None = None):
        self.status_code = status_code
        # Falsy body (None / {}) falls back to the default issue payload,
        # matching the original ``body or _make_issue_response()`` semantics.
        self.body = body if body else _make_issue_response()
        self.last_request: httpx.Request | None = None

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        self.last_request = request
        return httpx.Response(
            status_code=self.status_code,
            headers={"Content-Type": "application/json"},
            content=json.dumps(self.body).encode(),
        )
def _make_tracker(transport: httpx.AsyncBaseTransport) -> GiteaIssueTracker:
    """Construct a GiteaIssueTracker wired to the given mock transport.

    The adapter's ``__init__`` does not accept a transport, so we rebuild its
    private ``_client`` around the mock, reusing the headers the adapter
    computed (``_headers`` carries the auth token).
    """
    tracker = GiteaIssueTracker(
        token="test-token-abc",
        owner="testowner",
        repo="testrepo",
        base_url="https://gitea.example.com",
    )
    # Replace the internal client's transport with our mock.
    # We recreate the client so we don't have to expose the transport in __init__.
    tracker._client = httpx.AsyncClient(
        transport=transport,
        headers=tracker._headers,
        timeout=30.0,
    )
    return tracker
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def good_transport():
    """Mock transport that returns a successful 201 issue response."""
    return _MockTransport(status_code=201)


@pytest.fixture
def error_transport():
    """Mock transport that simulates a Gitea API 422 error."""
    return _MockTransport(
        status_code=422,
        body={"message": "label does not exist"},
    )


@pytest.fixture
def good_tracker(good_transport):
    # Tracker whose HTTP layer always succeeds with a 201 + issue payload.
    return _make_tracker(good_transport)


@pytest.fixture
def error_tracker(error_transport):
    # Tracker whose HTTP layer always fails with a 422 error body.
    return _make_tracker(error_transport)
# ---------------------------------------------------------------------------
# Port contract test
# ---------------------------------------------------------------------------
class TestPortContract:
    """GiteaIssueTracker must be a concrete subclass of IssueTracker."""

    def test_is_subclass_of_issue_tracker_port(self):
        """The adapter satisfies the IssueTracker ABC — no missing abstract methods."""
        satisfies_port = issubclass(GiteaIssueTracker, IssueTracker)
        assert satisfies_port

    def test_instance_passes_isinstance_check(self, good_tracker):
        """An instantiated adapter is accepted wherever IssueTracker is expected."""
        assert isinstance(good_tracker, IssueTracker)
# ---------------------------------------------------------------------------
# Successful issue creation
# ---------------------------------------------------------------------------
class TestSuccessfulIssueCreation:
    """Happy-path behaviour when Gitea responds with 201."""

    async def test_returns_html_url(self, good_tracker):
        """create_unanswered_issue should return the html_url from the API response."""
        url = await good_tracker.create_unanswered_issue(
            question="Can I steal home?",
            user_id="user-42",
            channel_id="chan-99",
            attempted_rules=["5.2.1(b)", "5.2.2"],
            conversation_id="conv-abc",
        )
        assert url == "https://gitea.example.com/owner/repo/issues/1"

    async def test_posts_to_correct_endpoint(self, good_tracker, good_transport):
        """The adapter must POST to /repos/{owner}/{repo}/issues."""
        await good_tracker.create_unanswered_issue(
            question="Can I steal home?",
            user_id="user-42",
            channel_id="chan-99",
            attempted_rules=[],
            conversation_id="conv-abc",
        )
        req = good_transport.last_request
        assert req is not None
        assert req.method == "POST"
        assert "/repos/testowner/testrepo/issues" in str(req.url)

    async def test_sends_bearer_token(self, good_tracker, good_transport):
        """Authorization header must carry the configured token."""
        await good_tracker.create_unanswered_issue(
            question="test question",
            user_id="u",
            channel_id="c",
            attempted_rules=[],
            conversation_id="conv-1",
        )
        req = good_transport.last_request
        assert req.headers["Authorization"] == "token test-token-abc"

    async def test_content_type_is_json(self, good_tracker, good_transport):
        """The request must declare application/json content type."""
        await good_tracker.create_unanswered_issue(
            question="test",
            user_id="u",
            channel_id="c",
            attempted_rules=[],
            conversation_id="c1",
        )
        req = good_transport.last_request
        assert req.headers["Content-Type"] == "application/json"

    async def test_also_accepts_200_status(self):
        """Some Gitea instances return 200 on issue creation; both are valid.

        This test builds its own tracker around a 200-status transport, so it
        no longer requests the ``good_tracker`` fixture — the original asked
        for it but never used it, constructing an extra tracker and client.
        """
        transport_200 = _MockTransport(status_code=200)
        tracker = _make_tracker(transport_200)
        url = await tracker.create_unanswered_issue(
            question="Is 200 ok?",
            user_id="u",
            channel_id="c",
            attempted_rules=[],
            conversation_id="c1",
        )
        assert url == "https://gitea.example.com/owner/repo/issues/1"
# ---------------------------------------------------------------------------
# Issue body content
# ---------------------------------------------------------------------------
class TestIssueBodyContent:
    """The issue body must contain context needed for human triage."""

    async def _get_body(self, transport, **overrides) -> str:
        """Create an issue with default args (overridable) and return the POSTed body."""
        tracker = _make_tracker(transport)
        params = {
            "question": "Can I intentionally walk a batter?",
            "user_id": "user-99",
            "channel_id": "channel-7",
            "attempted_rules": ["4.1.1", "4.1.2"],
            "conversation_id": "conv-xyz",
        }
        params.update(overrides)
        await tracker.create_unanswered_issue(**params)
        return json.loads(transport.last_request.content)["body"]

    async def test_body_contains_question_in_code_block(self, good_transport):
        """The question must sit inside a fenced code block to prevent markdown
        injection — a crafted question containing headers, links, or other
        markdown must not be able to corrupt the issue layout."""
        body = await self._get_body(
            good_transport, question="Can I intentionally walk a batter?"
        )
        assert "```" in body
        assert "Can I intentionally walk a batter?" in body
        # The first fence must open before the question text appears.
        assert body.index("```") < body.index("Can I intentionally walk a batter?")

    async def test_body_contains_user_id(self, good_transport):
        """User ID must appear in the body so reviewers know who asked."""
        assert "user-99" in await self._get_body(good_transport, user_id="user-99")

    async def test_body_contains_channel_id(self, good_transport):
        """Channel ID must appear so reviewers can locate the conversation."""
        assert "channel-7" in await self._get_body(
            good_transport, channel_id="channel-7"
        )

    async def test_body_contains_conversation_id(self, good_transport):
        """Conversation ID must be present for traceability to the chat log."""
        assert "conv-xyz" in await self._get_body(
            good_transport, conversation_id="conv-xyz"
        )

    async def test_body_contains_attempted_rules(self, good_transport):
        """Searched rule IDs must be listed so reviewers know what was tried."""
        body = await self._get_body(good_transport, attempted_rules=["4.1.1", "4.1.2"])
        for rule_id in ("4.1.1", "4.1.2"):
            assert rule_id in body

    async def test_body_handles_empty_attempted_rules(self, good_transport):
        """An empty rules list should not crash; a non-empty body is still built."""
        body = await self._get_body(good_transport, attempted_rules=[])
        assert isinstance(body, str)
        assert len(body) > 0

    async def test_title_contains_truncated_question(self, good_transport):
        """The issue title must be truncated (~80 chars), never the full question."""
        tracker = _make_tracker(good_transport)
        await tracker.create_unanswered_issue(
            question="A" * 200,
            user_id="u",
            channel_id="c",
            attempted_rules=[],
            conversation_id="c1",
        )
        payload = json.loads(good_transport.last_request.content)
        assert len(payload["title"]) < 150
# ---------------------------------------------------------------------------
# Labels
# ---------------------------------------------------------------------------
class TestLabels:
    """Labels must be passed to the Gitea API in the request payload."""

    async def test_labels_present_in_request_payload(
        self, good_tracker, good_transport
    ):
        """The POST body must carry a non-empty 'labels' list."""
        await good_tracker.create_unanswered_issue(
            question="test",
            user_id="u",
            channel_id="c",
            attempted_rules=[],
            conversation_id="c1",
        )
        payload = json.loads(good_transport.last_request.content)
        assert "labels" in payload
        labels = payload["labels"]
        assert isinstance(labels, list)
        assert len(labels) > 0

    async def test_expected_label_values(self, good_tracker, good_transport):
        """Labels should identify the issue origin clearly.

        At least 'rules-gap' (or equivalent), 'ai-generated', and
        'needs-review' are required so Gitea project boards can filter
        automatically.
        """
        await good_tracker.create_unanswered_issue(
            question="test",
            user_id="u",
            channel_id="c",
            attempted_rules=[],
            conversation_id="c1",
        )
        labels = json.loads(good_transport.last_request.content)["labels"]
        for expected in ("rules-gap", "needs-review", "ai-generated"):
            assert expected in labels
# ---------------------------------------------------------------------------
# API error handling
# ---------------------------------------------------------------------------
class TestAPIErrorHandling:
    """Non-2xx responses from Gitea should raise a descriptive RuntimeError."""

    async def _expect_failure(self, tracker, question="test"):
        """Drive create_unanswered_issue and return the raised RuntimeError."""
        with pytest.raises(RuntimeError) as exc_info:
            await tracker.create_unanswered_issue(
                question=question,
                user_id="u",
                channel_id="c",
                attempted_rules=[],
                conversation_id="c1",
            )
        return exc_info.value

    async def test_raises_on_422(self, error_tracker):
        """422 Unprocessable Entity must raise RuntimeError with status info."""
        err = await self._expect_failure(error_tracker, question="bad label question")
        assert "422" in str(err)

    async def test_raises_on_401(self):
        """401 Unauthorized (bad token) must raise RuntimeError."""
        tracker = _make_tracker(
            _MockTransport(status_code=401, body={"message": "Unauthorized"})
        )
        err = await self._expect_failure(tracker)
        assert "401" in str(err)

    async def test_raises_on_500(self):
        """500 server errors must raise RuntimeError, not silently return empty."""
        tracker = _make_tracker(
            _MockTransport(status_code=500, body={"message": "Internal Server Error"})
        )
        err = await self._expect_failure(tracker)
        assert "500" in str(err)

    async def test_error_message_includes_response_body(self, error_tracker):
        """The RuntimeError message must embed the raw API error body to aid
        debugging — operators need to know whether the failure was a bad
        label, an auth issue, a quota error, etc."""
        err = await self._expect_failure(error_tracker)
        # The error transport returns {"message": "label does not exist"}
        assert "label does not exist" in str(err)
# ---------------------------------------------------------------------------
# Lifecycle — persistent client
# ---------------------------------------------------------------------------
class TestClientLifecycle:
    """The adapter must expose a close() coroutine for clean resource teardown."""

    async def test_close_is_callable(self, good_tracker):
        """close() exists and is awaitable (used in dependency teardown)."""
        await good_tracker.close()

    async def test_close_after_request_does_not_raise(self, good_tracker):
        """Closing after a real request must be clean."""
        await good_tracker.create_unanswered_issue(
            question="cleanup test",
            user_id="u",
            channel_id="c",
            attempted_rules=[],
            conversation_id="c1",
        )
        # Teardown after activity should not raise.
        await good_tracker.close()

View File

@ -0,0 +1,392 @@
"""Tests for the OpenRouterLLM outbound adapter.
Tests cover:
- Successful JSON response parsing from the LLM
- JSON embedded in markdown code fences (```json ... ```)
- Plain-text fallback when JSON parsing fails completely
- HTTP error status codes raising RuntimeError
- Regex fallback for cited_rules when the LLM omits them but mentions rules in text
- Conversation history is forwarded correctly to the API
- The adapter returns domain.models.LLMResponse, not any legacy type
- close() shuts down the underlying httpx client
All HTTP calls are intercepted via unittest.mock so no real API key is needed.
"""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from domain.models import LLMResponse, RuleSearchResult
from domain.ports import LLMPort
from adapters.outbound.openrouter import OpenRouterLLM
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_rules(*rule_ids: str) -> list[RuleSearchResult]:
    """Create minimal RuleSearchResult fixtures, one per rule id."""
    fixtures: list[RuleSearchResult] = []
    for rid in rule_ids:
        fixtures.append(
            RuleSearchResult(
                rule_id=rid,
                title=f"Title for {rid}",
                content=f"Content for rule {rid}.",
                section="General",
                similarity=0.9,
            )
        )
    return fixtures
def _api_payload(content: str) -> dict:
"""Wrap a content string in the OpenRouter / OpenAI response envelope."""
return {"choices": [{"message": {"content": content}}]}
def _mock_http_response(
status_code: int = 200, body: dict | str | None = None
) -> MagicMock:
"""Build a mock httpx.Response with the given status and JSON body."""
resp = MagicMock()
resp.status_code = status_code
if isinstance(body, dict):
resp.json.return_value = body
resp.text = json.dumps(body)
else:
resp.json.side_effect = ValueError("not JSON")
resp.text = body or ""
return resp
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def adapter() -> OpenRouterLLM:
    """Return an OpenRouterLLM with a mocked internal httpx.AsyncClient.

    We patch httpx.AsyncClient so the adapter's __init__ wires up a mock
    that we can control per-test through the returned instance.
    """
    mock_client = AsyncMock()
    with patch(
        "adapters.outbound.openrouter.httpx.AsyncClient", return_value=mock_client
    ):
        inst = OpenRouterLLM(api_key="test-key", model="test-model")
    # Re-assign explicitly so tests can always reach the mock via
    # ``adapter._http`` regardless of how __init__ stored its client.
    inst._http = mock_client
    return inst
# ---------------------------------------------------------------------------
# Interface compliance
# ---------------------------------------------------------------------------
def test_openrouter_llm_implements_port():
    """OpenRouterLLM must be a concrete implementation of LLMPort.

    Catches missing abstract-method overrides at class-definition time,
    not just at instantiation time.
    """
    implements_port = issubclass(OpenRouterLLM, LLMPort)
    assert implements_port
# ---------------------------------------------------------------------------
# Successful JSON response
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_successful_json_response(adapter: OpenRouterLLM):
    """A well-formed JSON body from the LLM is parsed into LLMResponse.

    answer, cited_rules, confidence, and needs_human must all be mapped
    correctly from the parsed JSON.
    """
    payload = {
        "answer": "The runner advances one base.",
        "cited_rules": ["5.2.1(b)", "5.2.2"],
        "confidence": 0.9,
        "needs_human": False,
    }
    envelope = _api_payload(json.dumps(payload))
    adapter._http.post = AsyncMock(return_value=_mock_http_response(200, envelope))

    response = await adapter.generate_response(
        "Can the runner advance?", _make_rules("5.2.1(b)", "5.2.2")
    )

    assert isinstance(response, LLMResponse)
    assert response.answer == "The runner advances one base."
    for rid in ("5.2.1(b)", "5.2.2"):
        assert rid in response.cited_rules
    assert response.confidence == pytest.approx(0.9)
    assert response.needs_human is False
# ---------------------------------------------------------------------------
# Markdown-fenced JSON response
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_markdown_fenced_json_response(adapter: OpenRouterLLM):
    """LLMs often wrap JSON in ```json ... ``` fences.

    The adapter must strip the fences before parsing so fenced responses are
    handled identically to bare JSON.
    """
    payload = {
        "answer": "No, the batter is out.",
        "cited_rules": ["3.1"],
        "confidence": 0.85,
        "needs_human": False,
    }
    fenced = f"```json\n{json.dumps(payload)}\n```"
    adapter._http.post = AsyncMock(
        return_value=_mock_http_response(200, _api_payload(fenced))
    )

    response = await adapter.generate_response("Is the batter out?", _make_rules("3.1"))

    assert isinstance(response, LLMResponse)
    assert response.answer == "No, the batter is out."
    assert response.cited_rules == ["3.1"]
    assert response.confidence == pytest.approx(0.85)
    assert response.needs_human is False
# ---------------------------------------------------------------------------
# Plain-text fallback (JSON parse failure)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plain_text_fallback_on_parse_failure(adapter: OpenRouterLLM):
    """Plain text that cannot be parsed as JSON triggers the graceful fallback:
    the raw content becomes the answer, citations are empty, confidence is 0.0
    (not 0.5 — signalling an unreliable parse), and needs_human is True (not
    False — flagging the reply for human review)."""
    raw_answer = "I'm not sure which rule covers this situation."
    adapter._http.post = AsyncMock(
        return_value=_mock_http_response(200, _api_payload(raw_answer))
    )

    response = await adapter.generate_response("Which rule applies?", [])

    assert isinstance(response, LLMResponse)
    assert response.answer == raw_answer
    assert response.cited_rules == []
    assert response.confidence == pytest.approx(0.0)
    assert response.needs_human is True
# ---------------------------------------------------------------------------
# HTTP error codes
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_http_error_raises_runtime_error(adapter: OpenRouterLLM):
    """Non-200 HTTP statuses from the API must raise RuntimeError.

    Upstream callers (the service layer) can then catch a predictable
    exception type and decide whether to retry or surface an error message.
    """
    adapter._http.post = AsyncMock(
        return_value=_mock_http_response(429, "Rate limit exceeded")
    )
    with pytest.raises(RuntimeError, match="429"):
        await adapter.generate_response("Any question", [])
@pytest.mark.asyncio
async def test_http_500_raises_runtime_error(adapter: OpenRouterLLM):
    """500 Internal Server Error from OpenRouter also raises RuntimeError."""
    adapter._http.post = AsyncMock(
        return_value=_mock_http_response(500, "Internal server error")
    )
    with pytest.raises(RuntimeError, match="500"):
        await adapter.generate_response("Any question", [])
# ---------------------------------------------------------------------------
# cited_rules regex fallback
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cited_rules_regex_fallback(adapter: OpenRouterLLM):
    """Valid JSON with an empty cited_rules list should trigger the regex
    fallback: rule IDs mentioned in the answer text are extracted and used
    as citations, preserving attribution when the model forgets the field."""
    payload = {
        "answer": "According to Rule 5.2.1(b) the runner must advance. See also Rule 7.4.",
        "cited_rules": [],
        "confidence": 0.75,
        "needs_human": False,
    }
    adapter._http.post = AsyncMock(
        return_value=_mock_http_response(200, _api_payload(json.dumps(payload)))
    )

    response = await adapter.generate_response(
        "Advance question?", _make_rules("5.2.1(b)", "7.4")
    )

    assert isinstance(response, LLMResponse)
    # Both rule IDs mentioned in the prose must have been recovered.
    for rid in ("5.2.1(b)", "7.4"):
        assert rid in response.cited_rules
@pytest.mark.asyncio
async def test_cited_rules_regex_not_triggered_when_rules_present(
    adapter: OpenRouterLLM,
):
    """A populated cited_rules list from the LLM must be kept verbatim — the
    regex fallback must NOT override it (no double-adds, no mangled IDs)."""
    payload = {
        "answer": "Rule 5.2.1(b) says the runner advances.",
        "cited_rules": ["5.2.1(b)"],
        "confidence": 0.8,
        "needs_human": False,
    }
    adapter._http.post = AsyncMock(
        return_value=_mock_http_response(200, _api_payload(json.dumps(payload)))
    )

    response = await adapter.generate_response(
        "Advance question?", _make_rules("5.2.1(b)")
    )

    assert response.cited_rules == ["5.2.1(b)"]
# ---------------------------------------------------------------------------
# Conversation history forwarded correctly
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_conversation_history_included_in_request(adapter: OpenRouterLLM):
    """When conversation_history is provided it must appear in the messages list
    sent to the API, interleaved between the system prompt and the new user turn.
    We inspect the captured POST body to assert ordering and content.
    """
    history = [
        {"role": "user", "content": "Who bats first?"},
        {"role": "assistant", "content": "The home team bats last."},
    ]
    llm_json = {
        "answer": "Yes, that is correct.",
        "cited_rules": [],
        "confidence": 0.8,
        "needs_human": False,
    }
    api_body = _api_payload(json.dumps(llm_json))
    adapter._http.post = AsyncMock(return_value=_mock_http_response(200, api_body))
    await adapter.generate_response(
        "Follow-up question?", [], conversation_history=history
    )
    call_kwargs = adapter._http.post.call_args
    # The adapter may pass the body as a `json=` kwarg or as the second
    # positional argument.  Check the kwarg explicitly, then fall back.
    # (The previous conditional-expression form depended on subtle operator
    # precedence and raised IndexError when only the URL was positional.)
    if "json" in call_kwargs.kwargs:
        sent_json = call_kwargs.kwargs["json"]
    else:
        sent_json = call_kwargs.args[1]
    messages = sent_json["messages"]
    roles = [m["role"] for m in messages]
    # system prompt first, history next, new user message last
    assert roles[0] == "system"
    assert {"role": "user", "content": "Who bats first?"} in messages
    assert {"role": "assistant", "content": "The home team bats last."} in messages
    # final message should be the new user turn
    assert messages[-1]["role"] == "user"
    assert "Follow-up question?" in messages[-1]["content"]
@pytest.mark.asyncio
async def test_no_conversation_history_omitted_from_request(adapter: OpenRouterLLM):
    """When conversation_history is None or empty the messages list must only
    contain the system prompt and the new user message — no history entries.
    """
    llm_json = {
        "answer": "Yes.",
        "cited_rules": [],
        "confidence": 0.9,
        "needs_human": False,
    }
    api_body = _api_payload(json.dumps(llm_json))
    adapter._http.post = AsyncMock(return_value=_mock_http_response(200, api_body))
    await adapter.generate_response("Simple question?", [], conversation_history=None)
    call_kwargs = adapter._http.post.call_args
    # `kwargs.get("json") or kwargs["json"]` was redundant: when the key is
    # missing the first lookup returns None and the second raises anyway.
    # Handle both kwarg and positional body passing explicitly instead.
    if "json" in call_kwargs.kwargs:
        sent_json = call_kwargs.kwargs["json"]
    else:
        sent_json = call_kwargs.args[1]
    messages = sent_json["messages"]
    assert len(messages) == 2
    assert messages[0]["role"] == "system"
    assert messages[1]["role"] == "user"
# ---------------------------------------------------------------------------
# No rules context
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_no_rules_uses_not_found_message(adapter: OpenRouterLLM):
    """When rules is an empty list the user message sent to the API should
    contain a clear indication that no relevant rules were found, rather than
    an empty or misleading context block.
    """
    llm_json = {
        "answer": "I don't have a rule for this.",
        "cited_rules": [],
        "confidence": 0.1,
        "needs_human": True,
    }
    api_body = _api_payload(json.dumps(llm_json))
    adapter._http.post = AsyncMock(return_value=_mock_http_response(200, api_body))
    await adapter.generate_response("Unknown rule question?", [])
    call_kwargs = adapter._http.post.call_args
    # `kwargs.get("json") or kwargs["json"]` was redundant (the fallback
    # lookup can only raise); branch explicitly over kwarg vs positional.
    if "json" in call_kwargs.kwargs:
        sent_json = call_kwargs.kwargs["json"]
    else:
        sent_json = call_kwargs.args[1]
    user_message = next(
        m["content"] for m in sent_json["messages"] if m["role"] == "user"
    )
    assert "No relevant rules" in user_message
# ---------------------------------------------------------------------------
# close()
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_close_shuts_down_http_client(adapter: OpenRouterLLM):
    """close() must await the underlying httpx.AsyncClient.aclose() so that
    connection pools are released cleanly without leaving open sockets.
    """
    fake_aclose = AsyncMock()
    adapter._http.aclose = fake_aclose
    await adapter.close()
    # Exactly one awaited call — close() must not skip or double-close.
    fake_aclose.assert_awaited_once()

View File

@ -0,0 +1,266 @@
"""Tests for the SQLiteConversationStore outbound adapter.
Uses an in-memory SQLite database (sqlite+aiosqlite://) so each test is fast
and hermetic — no file I/O, no shared state between tests.
What we verify:
- A fresh conversation can be created and its ID returned.
- Calling get_or_create_conversation with an existing ID returns the same ID
(and does NOT create a new row).
- Calling get_or_create_conversation with an unknown/missing ID creates a new
conversation (graceful fallback rather than a hard error).
- Messages can be appended to a conversation; each returns a unique ID.
- get_conversation_history returns messages in chronological order (oldest
first), not insertion-reverse order.
- The limit parameter is respected; when more messages exist than the limit,
only the most-recent `limit` messages come back (still chronological within
that window).
- The returned dicts have exactly the keys {"role", "content"}, matching the
OpenAI-compatible format expected by the LLM port.
"""
import pytest
from adapters.outbound.sqlite_convos import SQLiteConversationStore
IN_MEMORY_URL = "sqlite+aiosqlite://"
@pytest.fixture
async def store() -> SQLiteConversationStore:
    """Create an initialised in-memory store for a single test.
    The fixture is async because init_db() is a coroutine that runs the
    CREATE TABLE statements. Each test gets a completely fresh database
    because in-memory SQLite databases are private to the connection that
    created them.
    """
    instance = SQLiteConversationStore(db_url=IN_MEMORY_URL)
    await instance.init_db()
    return instance
# ---------------------------------------------------------------------------
# Conversation creation
# ---------------------------------------------------------------------------
async def test_create_new_conversation(store: SQLiteConversationStore):
    """get_or_create_conversation should return a non-empty string ID when no
    existing conversation_id is supplied."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    assert isinstance(cid, str)
    assert cid  # non-empty
async def test_create_conversation_returns_uuid_format(
    store: SQLiteConversationStore,
):
    """The generated conversation ID should look like a UUID (36-char with
    hyphens), since we use uuid.uuid4() internally."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    # UUID4 format: five hyphen-separated hex groups, 36 characters total.
    assert len(cid) == 36
    assert len(cid.split("-")) == 5
# ---------------------------------------------------------------------------
# Idempotency — fetching an existing conversation
# ---------------------------------------------------------------------------
async def test_get_existing_conversation_returns_same_id(
    store: SQLiteConversationStore,
):
    """Passing an existing conversation_id back into get_or_create_conversation
    must return exactly that same ID, not create a new one."""
    first = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    second = await store.get_or_create_conversation(
        user_id="u1", channel_id="ch1", conversation_id=first
    )
    assert second == first
async def test_get_unknown_conversation_id_creates_new(
    store: SQLiteConversationStore,
):
    """If conversation_id is provided but not found in the DB, the adapter
    should gracefully create a fresh conversation rather than raise."""
    bogus = "00000000-0000-0000-0000-000000000000"
    fresh = await store.get_or_create_conversation(
        user_id="u2", channel_id="ch2", conversation_id=bogus
    )
    assert isinstance(fresh, str)
    # A brand-new ID must be minted rather than echoing the unknown one.
    assert fresh != bogus
# ---------------------------------------------------------------------------
# Adding messages
# ---------------------------------------------------------------------------
async def test_add_message_returns_string_id(store: SQLiteConversationStore):
    """add_message should return a non-empty string ID for the new message."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    mid = await store.add_message(
        conversation_id=cid, content="Hello!", is_user=True
    )
    assert isinstance(mid, str)
    assert mid != ""
async def test_add_multiple_messages_returns_unique_ids(
    store: SQLiteConversationStore,
):
    """Every call to add_message must produce a distinct message ID."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    first = await store.add_message(cid, "Hi", is_user=True)
    second = await store.add_message(cid, "Hello back", is_user=False)
    # A set collapses duplicates, so size 2 proves the IDs differ.
    assert len({first, second}) == 2
async def test_add_message_with_parent_id(store: SQLiteConversationStore):
    """add_message should accept an optional parent_id without error. We
    cannot easily inspect the raw DB row here, but we verify that the call
    succeeds and returns an ID."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    root = await store.add_message(cid, "parent msg", is_user=True)
    reply = await store.add_message(
        cid, "child msg", is_user=False, parent_id=root
    )
    assert isinstance(reply, str)
    assert reply != root
# ---------------------------------------------------------------------------
# Conversation history — format
# ---------------------------------------------------------------------------
async def test_history_returns_list_of_dicts(store: SQLiteConversationStore):
    """get_conversation_history must return a list of dicts."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    await store.add_message(cid, "Hello", is_user=True)
    entries = await store.get_conversation_history(cid)
    assert isinstance(entries, list)
    assert len(entries) == 1
    assert isinstance(entries[0], dict)
async def test_history_dict_has_role_and_content_keys(
    store: SQLiteConversationStore,
):
    """Each dict in the history must have exactly the keys 'role' and
    'content', matching the OpenAI chat-completion message format."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    await store.add_message(cid, "A question", is_user=True)
    await store.add_message(cid, "An answer", is_user=False)
    expected_keys = {"role", "content"}
    for entry in await store.get_conversation_history(cid):
        actual = set(entry.keys())
        assert actual == expected_keys, (
            f"Expected keys {{'role','content'}}, got {actual}"
        )
async def test_history_role_mapping(store: SQLiteConversationStore):
    """is_user=True maps to role='user'; is_user=False maps to
    role='assistant'."""
    conv_id = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    await store.add_message(conv_id, "user msg", is_user=True)
    await store.add_message(conv_id, "assistant msg", is_user=False)
    history = await store.get_conversation_history(conv_id)
    # History is chronological (verified by test_history_is_chronological),
    # so assert the exact mapping rather than mere membership — the original
    # membership check would pass even if the two roles were swapped.
    assert [e["role"] for e in history] == ["user", "assistant"]
# ---------------------------------------------------------------------------
# Conversation history — ordering
# ---------------------------------------------------------------------------
async def test_history_is_chronological(store: SQLiteConversationStore):
    """Messages must come back oldest-first (chronological), NOT newest-first.
    The underlying query orders DESC then reverses, so the first item in the
    returned list must have the content of the first message we inserted.
    """
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    for position, text in enumerate(["first", "second", "third"]):
        await store.add_message(cid, text, is_user=(position % 2 == 0))
    observed = [
        entry["content"]
        for entry in await store.get_conversation_history(cid, limit=10)
    ]
    assert observed == ["first", "second", "third"], (
        f"Expected chronological order, got: {observed}"
    )
# ---------------------------------------------------------------------------
# Conversation history — limit
# ---------------------------------------------------------------------------
async def test_history_limit_respected(store: SQLiteConversationStore):
    """When there are more messages than the limit, only `limit` messages are
    returned."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    for n in range(5):
        await store.add_message(cid, f"msg {n}", is_user=(n % 2 == 0))
    trimmed = await store.get_conversation_history(cid, limit=3)
    assert len(trimmed) == 3
async def test_history_limit_returns_most_recent(
    store: SQLiteConversationStore,
):
    """When the limit truncates results, the MOST RECENT messages should be
    included, not the oldest ones. After inserting 5 messages (0-4) and
    requesting limit=2, we expect messages 3 and 4 (in chronological order)."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    for n in range(5):
        await store.add_message(cid, f"msg {n}", is_user=(n % 2 == 0))
    window = await store.get_conversation_history(cid, limit=2)
    texts = [entry["content"] for entry in window]
    assert texts == ["msg 3", "msg 4"], (
        f"Expected the 2 most-recent messages in order, got: {texts}"
    )
async def test_history_empty_conversation(store: SQLiteConversationStore):
    """A conversation with no messages returns an empty list, not an error."""
    cid = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    assert await store.get_conversation_history(cid) == []
# ---------------------------------------------------------------------------
# Isolation between conversations
# ---------------------------------------------------------------------------
async def test_history_isolated_between_conversations(
    store: SQLiteConversationStore,
):
    """Messages from one conversation must not appear in another conversation's
    history."""
    first_conv = await store.get_or_create_conversation(user_id="u1", channel_id="ch1")
    second_conv = await store.get_or_create_conversation(user_id="u2", channel_id="ch2")
    await store.add_message(first_conv, "from A", is_user=True)
    await store.add_message(second_conv, "from B", is_user=True)
    # Each conversation must see exactly its own single message.
    for conv, expected in ((first_conv, "from A"), (second_conv, "from B")):
        history = await store.get_conversation_history(conv)
        assert len(history) == 1
        assert history[0]["content"] == expected

0
tests/domain/__init__.py Normal file
View File

200
tests/domain/test_models.py Normal file
View File

@ -0,0 +1,200 @@
"""Tests for domain models — pure data structures with no framework dependencies."""
from datetime import datetime, timezone
from domain.models import (
RuleDocument,
RuleSearchResult,
Conversation,
ChatMessage,
LLMResponse,
ChatResult,
)
class TestRuleDocument:
    """RuleDocument holds rule content with metadata for the knowledge base."""

    def test_create_with_required_fields(self):
        rule = RuleDocument(
            rule_id="5.2.1(b)",
            title="Stolen Base Attempts",
            section="Baserunning",
            content="When a runner attempts to steal...",
            source_file="data/rules/baserunning.md",
        )
        assert rule.rule_id == "5.2.1(b)"
        assert rule.title == "Stolen Base Attempts"
        assert rule.section == "Baserunning"
        # Optional fields default to None when omitted.
        assert rule.parent_rule is None
        assert rule.page_ref is None

    def test_optional_fields(self):
        rule = RuleDocument(
            rule_id="5.2",
            title="Baserunning Overview",
            section="Baserunning",
            content="Overview content",
            source_file="rules.md",
            parent_rule="5",
            page_ref="32",
        )
        assert (rule.parent_rule, rule.page_ref) == ("5", "32")

    def test_metadata_dict_for_vector_store(self):
        """to_metadata() returns a flat dict suitable for ChromaDB/vector store metadata."""
        rule = RuleDocument(
            rule_id="5.2.1(b)",
            title="Stolen Base Attempts",
            section="Baserunning",
            content="content",
            source_file="rules.md",
            parent_rule="5.2",
            page_ref="32",
        )
        expected = {
            "rule_id": "5.2.1(b)",
            "title": "Stolen Base Attempts",
            "section": "Baserunning",
            "parent_rule": "5.2",
            "page_ref": "32",
            "source_file": "rules.md",
        }
        assert rule.to_metadata() == expected

    def test_metadata_dict_empty_optionals(self):
        """Optional fields should be empty strings in metadata (not None) for vector stores."""
        rule = RuleDocument(
            rule_id="1.0",
            title="General",
            section="General",
            content="c",
            source_file="f.md",
        )
        metadata = rule.to_metadata()
        assert metadata["parent_rule"] == ""
        assert metadata["page_ref"] == ""
class TestRuleSearchResult:
    """RuleSearchResult is what comes back from a semantic search."""

    def test_create(self):
        hit = RuleSearchResult(
            rule_id="5.2.1(b)",
            title="Stolen Base Attempts",
            content="When a runner attempts...",
            section="Baserunning",
            similarity=0.85,
        )
        assert hit.similarity == 0.85

    def test_similarity_bounds(self):
        """Similarity must be between 0.0 and 1.0."""
        import pytest

        # Both below-range and above-range scores must be rejected.
        for bad_score in (-0.1, 1.1):
            with pytest.raises(ValueError):
                RuleSearchResult(
                    rule_id="x",
                    title="t",
                    content="c",
                    section="s",
                    similarity=bad_score,
                )
class TestConversation:
    """Conversation tracks a chat session between a user and the bot."""

    def test_create_with_defaults(self):
        session = Conversation(
            id="conv-123",
            user_id="user-456",
            channel_id="chan-789",
        )
        assert session.id == "conv-123"
        # Timestamps are auto-populated when not supplied.
        assert isinstance(session.created_at, datetime)
        assert isinstance(session.last_activity, datetime)

    def test_explicit_timestamps(self):
        moment = datetime(2026, 1, 1, tzinfo=timezone.utc)
        session = Conversation(
            id="c",
            user_id="u",
            channel_id="ch",
            created_at=moment,
            last_activity=moment,
        )
        assert session.created_at == moment
class TestChatMessage:
    """ChatMessage is a single message in a conversation."""

    def test_user_message(self):
        question = ChatMessage(
            id="msg-1",
            conversation_id="conv-1",
            content="What is the steal rule?",
            is_user=True,
        )
        assert question.is_user is True
        assert question.parent_id is None

    def test_assistant_message_with_parent(self):
        reply = ChatMessage(
            id="msg-2",
            conversation_id="conv-1",
            content="According to Rule 5.2.1(b)...",
            is_user=False,
            parent_id="msg-1",
        )
        assert reply.parent_id == "msg-1"
class TestLLMResponse:
    """LLMResponse is the structured output from the LLM port."""

    def test_create(self):
        reply = LLMResponse(
            answer="Based on Rule 5.2.1(b), runners can steal...",
            cited_rules=["5.2.1(b)"],
            confidence=0.9,
            needs_human=False,
        )
        assert reply.answer.startswith("Based on")
        assert reply.confidence == 0.9

    def test_defaults(self):
        # Only `answer` is required; everything else has defaults.
        reply = LLMResponse(answer="text")
        assert reply.cited_rules == []
        assert reply.confidence == 0.5
        assert reply.needs_human is False
class TestChatResult:
    """ChatResult is the final result returned by ChatService to inbound adapters."""

    def test_create(self):
        outcome = ChatResult(
            response="answer text",
            conversation_id="conv-1",
            message_id="msg-2",
            parent_message_id="msg-1",
            cited_rules=["5.2.1(b)"],
            confidence=0.85,
            needs_human=False,
        )
        assert outcome.response == "answer text"
        assert outcome.parent_message_id == "msg-1"

    def test_optional_parent(self):
        outcome = ChatResult(
            response="r",
            conversation_id="c",
            message_id="m",
            cited_rules=[],
            confidence=0.5,
            needs_human=False,
        )
        assert outcome.parent_message_id is None

View File

@ -0,0 +1,256 @@
"""Tests for ChatService — the core use case, tested entirely with fakes."""
import pytest
from domain.models import RuleDocument
from domain.services import ChatService
from tests.fakes import (
FakeRuleRepository,
FakeLLM,
FakeConversationStore,
FakeIssueTracker,
)
@pytest.fixture
def rules_repo():
    """Fake rule repository pre-seeded with one baserunning and one pitching rule."""
    seed = [
        RuleDocument(
            rule_id="5.2.1(b)",
            title="Stolen Base Attempts",
            section="Baserunning",
            content="When a runner attempts to steal a base, roll 2 dice.",
            source_file="rules.md",
        ),
        RuleDocument(
            rule_id="3.1",
            title="Pitching Overview",
            section="Pitching",
            content="The pitcher rolls for each at-bat using the pitching card.",
            source_file="rules.md",
        ),
    ]
    repo = FakeRuleRepository()
    repo.add_documents(seed)
    return repo
@pytest.fixture
def llm():
    # Default FakeLLM: answers with high confidence when rules are found.
    return FakeLLM()
@pytest.fixture
def conversations():
    # Fresh in-memory conversation store per test (no shared state).
    return FakeConversationStore()
@pytest.fixture
def issues():
    # In-memory issue tracker; tests assert on issues.issues.
    return FakeIssueTracker()
@pytest.fixture
def service(rules_repo, llm, conversations, issues):
    # Fully-wired ChatService with all four ports replaced by fakes.
    return ChatService(
        rules=rules_repo,
        llm=llm,
        conversations=conversations,
        issues=issues,
    )
class TestChatServiceAnswerQuestion:
    """ChatService.answer_question orchestrates the full Q&A flow."""

    async def test_returns_answer_with_cited_rules(self, service):
        """When rules match the question, the LLM is called and rules are cited."""
        result = await service.answer_question(
            message="How do I steal a base?",
            user_id="user-1",
            channel_id="chan-1",
        )
        assert "5.2.1(b)" in result.cited_rules
        assert result.confidence == 0.9
        assert result.needs_human is False
        assert result.conversation_id  # should be a non-empty string
        assert result.message_id  # should be a non-empty string

    async def test_creates_conversation_and_messages(self, service, conversations):
        """The service should persist both user and assistant messages."""
        result = await service.answer_question(
            message="How do I steal?",
            user_id="user-1",
            channel_id="chan-1",
        )
        history = await conversations.get_conversation_history(result.conversation_id)
        assert len(history) == 2
        assert history[0]["role"] == "user"
        assert history[1]["role"] == "assistant"

    async def test_continues_existing_conversation(self, service, conversations):
        """Passing a conversation_id should reuse the existing conversation."""
        result1 = await service.answer_question(
            message="How do I steal?",
            user_id="user-1",
            channel_id="chan-1",
        )
        result2 = await service.answer_question(
            message="What about pickoffs?",
            user_id="user-1",
            channel_id="chan-1",
            conversation_id=result1.conversation_id,
            parent_message_id=result1.message_id,
        )
        assert result2.conversation_id == result1.conversation_id
        history = await conversations.get_conversation_history(result1.conversation_id)
        assert len(history) == 4  # 2 user + 2 assistant

    async def test_passes_conversation_history_to_llm(self, service, llm):
        """The LLM should receive conversation history for context."""
        result1 = await service.answer_question(
            message="How do I steal?",
            user_id="user-1",
            channel_id="chan-1",
        )
        await service.answer_question(
            message="Follow-up question",
            user_id="user-1",
            channel_id="chan-1",
            conversation_id=result1.conversation_id,
        )
        assert len(llm.calls) == 2
        second_call = llm.calls[1]
        assert second_call["history"] is not None
        assert len(second_call["history"]) >= 2

    async def test_searches_rules_with_user_question(self, service, rules_repo):
        """The service should search the rules repo with the user's question."""
        result = await service.answer_question(
            message="steal a base",
            user_id="u",
            channel_id="c",
        )
        # The seeded steal rule (5.2.1(b)) matches this query and FakeLLM
        # cites every rule it receives, so its presence in cited_rules proves
        # the repo was actually searched with the user's question.
        # (Previously this test made no assertion at all.)
        assert "5.2.1(b)" in result.cited_rules

    async def test_sets_parent_message_id(self, service):
        """The result should link the assistant message back to the user message."""
        result = await service.answer_question(
            message="question",
            user_id="u",
            channel_id="c",
        )
        assert result.parent_message_id is not None
class TestChatServiceIssueCreation:
    """When confidence is low or no rules match, a Gitea issue should be created."""

    async def test_creates_issue_on_low_confidence(
        self, rules_repo, conversations, issues
    ):
        """When the LLM returns low confidence, an issue is created."""
        weak_llm = FakeLLM(default_confidence=0.2)
        svc = ChatService(
            rules=rules_repo,
            llm=weak_llm,
            conversations=conversations,
            issues=issues,
        )
        await svc.answer_question(
            message="steal question",
            user_id="user-1",
            channel_id="chan-1",
        )
        assert len(issues.issues) == 1
        assert issues.issues[0]["question"] == "steal question"

    async def test_creates_issue_when_needs_human(
        self, rules_repo, conversations, issues
    ):
        """When LLM says needs_human, an issue is created regardless of confidence."""
        escalating_llm = FakeLLM(no_rules_confidence=0.1)
        svc = ChatService(
            rules=rules_repo,
            llm=escalating_llm,
            conversations=conversations,
            issues=issues,
        )
        # Use a question that won't match any rules
        await svc.answer_question(
            message="something completely unrelated xyz",
            user_id="user-1",
            channel_id="chan-1",
        )
        assert len(issues.issues) == 1

    async def test_no_issue_on_high_confidence(self, service, issues):
        """High confidence answers should not create issues."""
        await service.answer_question(
            message="steal a base",
            user_id="user-1",
            channel_id="chan-1",
        )
        assert issues.issues == []

    async def test_no_issue_tracker_configured(self, rules_repo, llm, conversations):
        """If no issue tracker is provided, low confidence should not crash."""
        svc = ChatService(
            rules=rules_repo,
            llm=llm,
            conversations=conversations,
            issues=None,
        )
        # Should not raise even with low confidence LLM
        result = await svc.answer_question(
            message="steal a base",
            user_id="user-1",
            channel_id="chan-1",
        )
        assert result.response
class TestChatServiceErrorHandling:
    """Service should handle adapter failures gracefully."""

    async def test_llm_error_propagates(self, rules_repo, conversations, issues):
        """If the LLM raises, the service should let it propagate."""
        broken_llm = FakeLLM(force_error=RuntimeError("LLM is down"))
        svc = ChatService(
            rules=rules_repo,
            llm=broken_llm,
            conversations=conversations,
            issues=issues,
        )
        with pytest.raises(RuntimeError, match="LLM is down"):
            await svc.answer_question(
                message="steal a base",
                user_id="user-1",
                channel_id="chan-1",
            )

    async def test_issue_creation_failure_does_not_crash(
        self, rules_repo, conversations
    ):
        """If the issue tracker fails, the answer should still be returned."""

        class FailingIssueTracker(FakeIssueTracker):
            async def create_unanswered_issue(self, **kwargs) -> str:
                raise RuntimeError("Gitea is down")

        weak_llm = FakeLLM(default_confidence=0.2)
        svc = ChatService(
            rules=rules_repo,
            llm=weak_llm,
            conversations=conversations,
            issues=FailingIssueTracker(),
        )
        # Should return the answer even though issue creation failed
        result = await svc.answer_question(
            message="steal a base",
            user_id="user-1",
            channel_id="chan-1",
        )
        assert result.response

13
tests/fakes/__init__.py Normal file
View File

@ -0,0 +1,13 @@
"""Test fakes — in-memory implementations of domain ports."""
from .fake_rules import FakeRuleRepository
from .fake_llm import FakeLLM
from .fake_conversations import FakeConversationStore
from .fake_issues import FakeIssueTracker
__all__ = [
"FakeRuleRepository",
"FakeLLM",
"FakeConversationStore",
"FakeIssueTracker",
]

View File

@ -0,0 +1,58 @@
"""In-memory ConversationStore for testing — no SQLite, no SQLAlchemy."""
from typing import Optional
import uuid
from domain.ports import ConversationStore
class FakeConversationStore(ConversationStore):
    """Stores conversations and messages in dicts.

    Mirrors the contract of the real SQLiteConversationStore: an unknown
    conversation_id results in a freshly minted conversation with a NEW ID
    (graceful fallback), never an echo of the unknown ID.
    """

    def __init__(self):
        # conversation_id -> {"user_id": ..., "channel_id": ...}
        self.conversations: dict[str, dict] = {}
        # conversation_id -> ordered list of message dicts
        self.messages: dict[str, list[dict]] = {}

    async def get_or_create_conversation(
        self, user_id: str, channel_id: str, conversation_id: Optional[str] = None
    ) -> str:
        """Return an existing conversation's ID, or create a new conversation.

        Previously an unknown conversation_id was reused as the new row's ID,
        which diverged from the real adapter (its tests assert a fresh UUID is
        minted for unknown IDs).
        """
        if conversation_id and conversation_id in self.conversations:
            return conversation_id
        new_id = str(uuid.uuid4())
        self.conversations[new_id] = {
            "user_id": user_id,
            "channel_id": channel_id,
        }
        self.messages[new_id] = []
        return new_id

    async def add_message(
        self,
        conversation_id: str,
        content: str,
        is_user: bool,
        parent_id: Optional[str] = None,
    ) -> str:
        """Append a message to a conversation and return its generated ID."""
        message_id = str(uuid.uuid4())
        # Tolerate messages added to a conversation we never saw created.
        if conversation_id not in self.messages:
            self.messages[conversation_id] = []
        self.messages[conversation_id].append(
            {
                "id": message_id,
                "content": content,
                "is_user": is_user,
                "parent_id": parent_id,
            }
        )
        return message_id

    async def get_conversation_history(
        self, conversation_id: str, limit: int = 10
    ) -> list[dict[str, str]]:
        """Return up to `limit` most-recent messages, oldest first, in
        OpenAI-compatible {"role", "content"} format."""
        msgs = self.messages.get(conversation_id, [])
        history = []
        for msg in msgs[-limit:]:
            role = "user" if msg["is_user"] else "assistant"
            history.append({"role": role, "content": msg["content"]})
        return history

View File

@ -0,0 +1,28 @@
"""In-memory IssueTracker for testing — no Gitea API calls."""
from domain.ports import IssueTracker
class FakeIssueTracker(IssueTracker):
    """Records created issues in a list for assertion."""

    def __init__(self):
        self.issues: list[dict] = []

    async def create_unanswered_issue(
        self,
        question: str,
        user_id: str,
        channel_id: str,
        attempted_rules: list[str],
        conversation_id: str,
    ) -> str:
        # Capture every argument so tests can assert on exactly what was filed.
        self.issues.append(
            {
                "question": question,
                "user_id": user_id,
                "channel_id": channel_id,
                "attempted_rules": attempted_rules,
                "conversation_id": conversation_id,
            }
        )
        # Fake issue URL keyed by creation order (1-based).
        return f"https://gitea.example.com/issues/{len(self.issues)}"

60
tests/fakes/fake_llm.py Normal file
View File

@ -0,0 +1,60 @@
"""In-memory LLM for testing — returns canned responses, no API calls."""
from typing import Optional
from domain.models import RuleSearchResult, LLMResponse
from domain.ports import LLMPort
class FakeLLM(LLMPort):
    """Returns predictable responses based on whether rules were provided.
    Configurable for testing specific scenarios (low confidence, errors, etc.).
    """

    def __init__(
        self,
        default_answer: str = "Based on the rules, here is the answer.",
        default_confidence: float = 0.9,
        no_rules_answer: str = "I don't have a rule that addresses this question.",
        no_rules_confidence: float = 0.1,
        force_error: Optional[Exception] = None,
    ):
        self.default_answer = default_answer
        self.default_confidence = default_confidence
        self.no_rules_answer = no_rules_answer
        self.no_rules_confidence = no_rules_confidence
        self.force_error = force_error
        # One dict per generate_response call, for test assertions.
        self.calls: list[dict] = []

    async def generate_response(
        self,
        question: str,
        rules: list[RuleSearchResult],
        conversation_history: Optional[list[dict[str, str]]] = None,
    ) -> LLMResponse:
        # Record the call before any forced error, so tests can still
        # inspect what the service sent.
        self.calls.append(
            {
                "question": question,
                "rules": rules,
                "history": conversation_history,
            }
        )
        if self.force_error:
            raise self.force_error
        if not rules:
            # No context -> canned "can't answer" response that escalates.
            return LLMResponse(
                answer=self.no_rules_answer,
                cited_rules=[],
                confidence=self.no_rules_confidence,
                needs_human=True,
            )
        return LLMResponse(
            answer=self.default_answer,
            cited_rules=[r.rule_id for r in rules],
            confidence=self.default_confidence,
            needs_human=False,
        )

52
tests/fakes/fake_rules.py Normal file
View File

@ -0,0 +1,52 @@
"""In-memory RuleRepository for testing — no ChromaDB, no embeddings."""
from typing import Optional
from domain.models import RuleDocument, RuleSearchResult
from domain.ports import RuleRepository
class FakeRuleRepository(RuleRepository):
    """Stores rules in a list; search returns all rules sorted by naive keyword overlap."""

    def __init__(self):
        self.documents: list[RuleDocument] = []

    def add_documents(self, docs: list[RuleDocument]) -> None:
        self.documents.extend(docs)

    def search(
        self, query: str, top_k: int = 10, section_filter: Optional[str] = None
    ) -> list[RuleSearchResult]:
        # Naive relevance: fraction of query tokens present in the content.
        tokens = set(query.lower().split())
        hits: list[RuleSearchResult] = []
        for document in self.documents:
            if section_filter and document.section != section_filter:
                continue
            shared = tokens & set(document.content.lower().split())
            if not shared:
                continue
            score = min(1.0, len(shared) / max(len(tokens), 1))
            hits.append(
                RuleSearchResult(
                    rule_id=document.rule_id,
                    title=document.title,
                    content=document.content,
                    section=document.section,
                    similarity=score,
                )
            )
        return sorted(hits, key=lambda h: h.similarity, reverse=True)[:top_k]

    def count(self) -> int:
        return len(self.documents)

    def clear_all(self) -> None:
        self.documents.clear()

    def get_stats(self) -> dict:
        sections: dict[str, int] = {}
        for document in self.documents:
            sections[document.section] = sections.get(document.section, 0) + 1
        return {"total_rules": len(self.documents), "sections": sections}

3661
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff