"""ChromaDB outbound adapter implementing the RuleRepository port.""" from __future__ import annotations import logging from pathlib import Path from typing import Optional import chromadb from chromadb.config import Settings as ChromaSettings from sentence_transformers import SentenceTransformer from domain.models import RuleDocument, RuleSearchResult from domain.ports import RuleRepository logger = logging.getLogger(__name__) _COLLECTION_NAME = "rules" class ChromaRuleRepository(RuleRepository): """Persist and search rules in a ChromaDB vector store. Parameters ---------- persist_dir: Directory that ChromaDB uses for on-disk persistence. Created automatically if it does not exist. embedding_model: HuggingFace / sentence-transformers model name used to encode documents and queries (e.g. ``"all-MiniLM-L6-v2"``). """ def __init__(self, persist_dir: Path, embedding_model: str) -> None: self.persist_dir = Path(persist_dir) self.persist_dir.mkdir(parents=True, exist_ok=True) chroma_settings = ChromaSettings( anonymized_telemetry=False, is_persist_directory_actually_writable=True, ) self._client = chromadb.PersistentClient( path=str(self.persist_dir), settings=chroma_settings, ) logger.info("Loading embedding model '%s'", embedding_model) self._encoder = SentenceTransformer(embedding_model) logger.info("ChromaRuleRepository ready (persist_dir=%s)", self.persist_dir) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _get_collection(self): """Return the rules collection, creating it if absent.""" return self._client.get_or_create_collection( name=_COLLECTION_NAME, metadata={"hnsw:space": "cosine"}, ) @staticmethod def _distance_to_similarity(distance: float) -> float: """Convert a cosine distance in [0, 2] to a similarity in [0.0, 1.0]. ChromaDB stores cosine *distance* (0 = identical, 2 = opposite). The conversion is ``similarity = 1 - distance``, but floating-point noise can push the result slightly outside [0, 1], so we clamp. """ return max(0.0, min(1.0, 1.0 - distance)) # ------------------------------------------------------------------ # RuleRepository port implementation # ------------------------------------------------------------------ def add_documents(self, docs: list[RuleDocument]) -> None: """Embed and store a batch of RuleDocuments. Calling with an empty list is a no-op. """ if not docs: return logger.debug("Encoding %d document(s)", len(docs)) ids = [doc.rule_id for doc in docs] contents = [doc.content for doc in docs] metadatas = [doc.to_metadata() for doc in docs] # SentenceTransformer.encode returns a numpy array; .tolist() gives # a plain Python list which ChromaDB accepts. embeddings = self._encoder.encode(contents).tolist() collection = self._get_collection() collection.add( ids=ids, embeddings=embeddings, documents=contents, metadatas=metadatas, ) logger.info("Stored %d rule(s) in ChromaDB", len(docs)) def search( self, query: str, top_k: int = 10, section_filter: Optional[str] = None, ) -> list[RuleSearchResult]: """Return the *top_k* most semantically similar rules for *query*. Parameters ---------- query: Natural-language question or keyword string. top_k: Maximum number of results to return. section_filter: When provided, only documents whose ``section`` metadata field equals this value are considered. Returns ------- list[RuleSearchResult] Sorted by descending similarity (best match first). Returns an empty list if the collection is empty. """ collection = self._get_collection() doc_count = collection.count() if doc_count == 0: return [] # Clamp top_k so we never ask ChromaDB for more results than exist. effective_k = min(top_k, doc_count) query_embedding = self._encoder.encode(query).tolist() where = {"section": section_filter} if section_filter else None logger.debug( "Querying ChromaDB: top_k=%d, section_filter=%r", effective_k, section_filter, ) raw = collection.query( query_embeddings=[query_embedding], n_results=effective_k, where=where, include=["documents", "metadatas", "distances"], ) results: list[RuleSearchResult] = [] if raw and raw["documents"] and raw["documents"][0]: for i, doc_content in enumerate(raw["documents"][0]): metadata = raw["metadatas"][0][i] distance = raw["distances"][0][i] similarity = self._distance_to_similarity(distance) results.append( RuleSearchResult( rule_id=metadata["rule_id"], title=metadata["title"], content=doc_content, section=metadata["section"], similarity=similarity, ) ) logger.debug("Search returned %d result(s)", len(results)) return results def count(self) -> int: """Return the total number of rule documents in the collection.""" return self._get_collection().count() def clear_all(self) -> None: """Delete all documents by dropping and recreating the collection.""" logger.info( "Clearing all rules from ChromaDB collection '%s'", _COLLECTION_NAME ) self._client.delete_collection(_COLLECTION_NAME) self._get_collection() # Recreate so subsequent calls do not fail. def get_stats(self) -> dict: """Return a summary dict with total rule count, per-section counts, and path. Returns ------- dict with keys: ``total_rules`` (int), ``sections`` (dict[str, int]), ``persist_directory`` (str) """ collection = self._get_collection() raw = collection.get(include=["metadatas"]) sections: dict[str, int] = {} for metadata in raw.get("metadatas") or []: section = metadata.get("section", "") sections[section] = sections.get(section, 0) + 1 return { "total_rules": collection.count(), "sections": sections, "persist_directory": str(self.persist_dir), }