"""ChromaDB vector store for rule embeddings.""" from pathlib import Path from typing import Optional import chromadb from chromadb.config import Settings as ChromaSettings from sentence_transformers import SentenceTransformer import numpy as np from .config import settings from .models import RuleDocument, RuleSearchResult class VectorStore: """Wrapper around ChromaDB for rule retrieval.""" def __init__(self, persist_dir: Path, embedding_model: str): """Initialize vector store with embedding model.""" self.persist_dir = Path(persist_dir) self.persist_dir.mkdir(parents=True, exist_ok=True) chroma_settings = ChromaSettings( anonymized_telemetry=False, is_persist_directory_actually_writable=True ) self.client = chromadb.PersistentClient( path=str(self.persist_dir), settings=chroma_settings ) self.embedding_model = SentenceTransformer(embedding_model) def get_collection(self): """Get or create the rules collection.""" return self.client.get_or_create_collection( name="rules", metadata={"hnsw:space": "cosine"} ) def add_document(self, doc: RuleDocument) -> None: """Add a single rule document to the vector store.""" embedding = self.embedding_model.encode(doc.content).tolist() collection = self.get_collection() collection.add( ids=[doc.metadata.rule_id], embeddings=[embedding], documents=[doc.content], metadatas=[doc.to_chroma_metadata()], ) def add_documents(self, docs: list[RuleDocument]) -> None: """Add multiple documents in batch.""" if not docs: return ids = [doc.metadata.rule_id for doc in docs] contents = [doc.content for doc in docs] embeddings = self.embedding_model.encode(contents).tolist() metadatas = [doc.to_chroma_metadata() for doc in docs] collection = self.get_collection() collection.add( ids=ids, embeddings=embeddings, documents=contents, metadatas=metadatas ) def search( self, query: str, top_k: int = 10, section_filter: Optional[str] = None ) -> list[RuleSearchResult]: """Search for relevant rules using semantic similarity.""" query_embedding = self.embedding_model.encode(query).tolist() collection = self.get_collection() where = None if section_filter: where = {"section": section_filter} results = collection.query( query_embeddings=[query_embedding], n_results=top_k, where=where, include=["documents", "metadatas", "distances"], ) search_results = [] if results and results["documents"] and results["documents"][0]: for i in range(len(results["documents"][0])): metadata = results["metadatas"][0][i] distance = results["distances"][0][i] similarity = max( 0.0, min(1.0, 1 - distance) ) # Clamp to [0, 1]: cosine distance ranges 0–2 search_results.append( RuleSearchResult( rule_id=metadata["rule_id"], title=metadata["title"], content=results["documents"][0][i], section=metadata["section"], similarity=similarity, ) ) return search_results def delete_rule(self, rule_id: str) -> None: """Remove a rule by its ID.""" collection = self.get_collection() collection.delete(ids=[rule_id]) def clear_all(self) -> None: """Delete all rules from the collection.""" self.client.delete_collection("rules") self.get_collection() # Recreate empty collection def get_rule(self, rule_id: str) -> Optional[RuleSearchResult]: """Retrieve a specific rule by ID.""" collection = self.get_collection() result = collection.get(ids=[rule_id], include=["documents", "metadatas"]) if result and result["documents"] and result["documents"][0]: metadata = result["metadatas"][0][0] return RuleSearchResult( rule_id=metadata["rule_id"], title=metadata["title"], content=result["documents"][0][0], section=metadata["section"], similarity=1.0, ) return None def list_all_rules(self) -> list[RuleSearchResult]: """Return all rules in the store.""" collection = self.get_collection() result = collection.get(include=["documents", "metadatas"]) all_rules = [] if result and result["documents"]: for i in range(len(result["documents"])): metadata = result["metadatas"][i] all_rules.append( RuleSearchResult( rule_id=metadata["rule_id"], title=metadata["title"], content=result["documents"][i], section=metadata["section"], similarity=1.0, ) ) return all_rules def count(self) -> int: """Return the number of rules in the store.""" collection = self.get_collection() return collection.count() def get_stats(self) -> dict: """Get statistics about the vector store.""" collection = self.get_collection() all_rules = self.list_all_rules() sections = {} for rule in all_rules: sections[rule.section] = sections.get(rule.section, 0) + 1 return { "total_rules": len(all_rules), "sections": sections, "persist_directory": str(self.persist_dir), }