- Add vector store with sentence-transformers for semantic search
- FastAPI backend with /chat and /health endpoints
- Conversation state persistence via SQLite
- OpenRouter integration with structured JSON responses
- Discord bot with /ask slash command and reply-based follow-ups
- Automated Gitea issue creation for unanswered questions
- Docker support with docker-compose for easy deployment
- Example rule file and ingestion script
- Comprehensive documentation in README
167 lines · 5.8 KiB · Python
"""ChromaDB vector store for rule embeddings."""
|
|
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
import chromadb
|
|
from chromadb.config import Settings as ChromaSettings
|
|
from sentence_transformers import SentenceTransformer
|
|
import numpy as np
|
|
from .config import settings
|
|
from .models import RuleDocument, RuleSearchResult
|
|
|
|
|
|
class VectorStore:
    """Wrapper around ChromaDB for rule retrieval.

    Rule documents are embedded with a sentence-transformers model and
    stored in a persistent ChromaDB collection named "rules" configured
    for cosine distance. Exposes add, semantic search, exact lookup,
    delete, and statistics operations.
    """

    def __init__(self, persist_dir: Path, embedding_model: str):
        """Initialize vector store with embedding model.

        Args:
            persist_dir: Directory where ChromaDB persists data; created
                (including parents) if it does not exist.
            embedding_model: Name of the sentence-transformers model used
                to embed both documents and queries.
        """
        self.persist_dir = Path(persist_dir)
        self.persist_dir.mkdir(parents=True, exist_ok=True)

        # Bug fix: the original also passed a non-existent
        # `is_persist_directory_actually_writable` kwarg; chromadb's
        # pydantic-based Settings rejects unknown fields, so constructing
        # the store would fail. Pass only documented fields.
        chroma_settings = ChromaSettings(anonymized_telemetry=False)

        self.client = chromadb.PersistentClient(
            path=str(self.persist_dir), settings=chroma_settings
        )

        self.embedding_model = SentenceTransformer(embedding_model)

    def get_collection(self):
        """Get or create the "rules" collection (cosine HNSW space)."""
        return self.client.get_or_create_collection(
            name="rules", metadata={"hnsw:space": "cosine"}
        )

    def add_document(self, doc: RuleDocument) -> None:
        """Add a single rule document to the vector store.

        The document content is embedded and stored together with its
        Chroma-compatible metadata, keyed by the rule's ID.
        """
        embedding = self.embedding_model.encode(doc.content).tolist()

        collection = self.get_collection()
        collection.add(
            ids=[doc.metadata.rule_id],
            embeddings=[embedding],
            documents=[doc.content],
            metadatas=[doc.to_chroma_metadata()],
        )

    def add_documents(self, docs: list[RuleDocument]) -> None:
        """Add multiple documents in one batched encode + insert.

        No-op for an empty list. Encoding all contents in a single
        `encode` call is much faster than per-document calls.
        """
        if not docs:
            return

        ids = [doc.metadata.rule_id for doc in docs]
        contents = [doc.content for doc in docs]
        embeddings = self.embedding_model.encode(contents).tolist()
        metadatas = [doc.to_chroma_metadata() for doc in docs]

        collection = self.get_collection()
        collection.add(
            ids=ids, embeddings=embeddings, documents=contents, metadatas=metadatas
        )

    def search(
        self, query: str, top_k: int = 10, section_filter: Optional[str] = None
    ) -> list[RuleSearchResult]:
        """Search for relevant rules using semantic similarity.

        Args:
            query: Natural-language query text.
            top_k: Maximum number of results to return.
            section_filter: If given, restrict results to rules whose
                "section" metadata equals this value.

        Returns:
            Results with similarity = 1 - cosine distance, best first
            (Chroma returns query results ordered by distance).
        """
        query_embedding = self.embedding_model.encode(query).tolist()

        collection = self.get_collection()

        where = {"section": section_filter} if section_filter else None

        # query() returns NESTED lists: one inner list per query embedding.
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        search_results = []
        if results and results["documents"] and results["documents"][0]:
            for i in range(len(results["documents"][0])):
                metadata = results["metadatas"][0][i]
                distance = results["distances"][0][i]
                similarity = 1 - distance  # Convert cosine distance to similarity

                search_results.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=results["documents"][0][i],
                        section=metadata["section"],
                        similarity=similarity,
                    )
                )

        return search_results

    def delete_rule(self, rule_id: str) -> None:
        """Remove a rule by its ID (Chroma ignores unknown IDs)."""
        collection = self.get_collection()
        collection.delete(ids=[rule_id])

    def clear_all(self) -> None:
        """Delete all rules from the collection."""
        self.client.delete_collection("rules")
        self.get_collection()  # Recreate empty collection

    def get_rule(self, rule_id: str) -> Optional[RuleSearchResult]:
        """Retrieve a specific rule by ID.

        Returns None when the ID is not present. similarity is reported
        as 1.0 because this is an exact lookup, not a search.
        """
        collection = self.get_collection()
        result = collection.get(ids=[rule_id], include=["documents", "metadatas"])

        # Bug fix: unlike query(), Collection.get() returns FLAT lists
        # (list[str] / list[dict]), so the original's [0][0] indexing
        # raised on the metadata dict and sliced the first character of
        # the document. Index flat, matching list_all_rules().
        if result and result["documents"]:
            metadata = result["metadatas"][0]
            return RuleSearchResult(
                rule_id=metadata["rule_id"],
                title=metadata["title"],
                content=result["documents"][0],
                section=metadata["section"],
                similarity=1.0,
            )
        return None

    def list_all_rules(self) -> list[RuleSearchResult]:
        """Return all rules in the store (similarity fixed at 1.0)."""
        collection = self.get_collection()
        result = collection.get(include=["documents", "metadatas"])

        all_rules = []
        if result and result["documents"]:
            # get() returns flat, parallel lists of documents/metadatas.
            for i in range(len(result["documents"])):
                metadata = result["metadatas"][i]
                all_rules.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=result["documents"][i],
                        section=metadata["section"],
                        similarity=1.0,
                    )
                )

        return all_rules

    def count(self) -> int:
        """Return the number of rules in the store."""
        collection = self.get_collection()
        return collection.count()

    def get_stats(self) -> dict:
        """Get statistics about the vector store.

        Returns:
            Dict with "total_rules" (int), "sections" (mapping of section
            name to rule count), and "persist_directory" (str path).
        """
        all_rules = self.list_all_rules()
        sections: dict[str, int] = {}
        for rule in all_rules:
            sections[rule.section] = sections.get(rule.section, 0) + 1

        return {
            "total_rules": len(all_rules),
            "sections": sections,
            "persist_directory": str(self.persist_dir),
        }
|