- Discord bot: store full conversation UUID in footer instead of truncated 8-char prefix, fixing completely broken follow-up threading. Add footer to follow-up embeds so conversation chains work beyond depth 1. Edit loading message in-place instead of leaving ghost messages. Replace bare except with specific exception types. Fix channel_id attribute access. - GiteaClient: remove broken async context manager pattern that caused every create_unanswered_issue call to raise RuntimeError. Use per-request httpx.AsyncClient instead. - Database: return singleton ConversationManager from app.state instead of creating a new SQLAlchemy engine (and connection pool) on every request. - Vector store: clamp cosine similarity to [0, 1] to prevent Pydantic ValidationError crashes when ChromaDB returns distances > 1.0. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
169 lines
5.9 KiB
Python
169 lines
5.9 KiB
Python
"""ChromaDB vector store for rule embeddings."""
|
||
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
import chromadb
|
||
from chromadb.config import Settings as ChromaSettings
|
||
from sentence_transformers import SentenceTransformer
|
||
import numpy as np
|
||
from .config import settings
|
||
from .models import RuleDocument, RuleSearchResult
|
||
|
||
|
||
class VectorStore:
    """Wrapper around ChromaDB for rule retrieval.

    Stores rule documents with sentence-transformer embeddings in a
    persistent ChromaDB collection and exposes semantic search over them.
    """

    def __init__(self, persist_dir: Path, embedding_model: str):
        """Initialize vector store with embedding model.

        Args:
            persist_dir: Directory where ChromaDB persists data
                (created if missing).
            embedding_model: Name of the sentence-transformers model used
                to embed rule content and queries.
        """
        self.persist_dir = Path(persist_dir)
        self.persist_dir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): `is_persist_directory_actually_writable` does not look
        # like a documented chromadb Settings field — confirm the pinned
        # chromadb version actually accepts it.
        chroma_settings = ChromaSettings(
            anonymized_telemetry=False, is_persist_directory_actually_writable=True
        )

        self.client = chromadb.PersistentClient(
            path=str(self.persist_dir), settings=chroma_settings
        )

        self.embedding_model = SentenceTransformer(embedding_model)

    def get_collection(self):
        """Get or create the rules collection (cosine-distance HNSW index)."""
        return self.client.get_or_create_collection(
            name="rules", metadata={"hnsw:space": "cosine"}
        )

    def add_document(self, doc: RuleDocument) -> None:
        """Add a single rule document to the vector store."""
        embedding = self.embedding_model.encode(doc.content).tolist()

        collection = self.get_collection()
        collection.add(
            ids=[doc.metadata.rule_id],
            embeddings=[embedding],
            documents=[doc.content],
            metadatas=[doc.to_chroma_metadata()],
        )

    def add_documents(self, docs: list[RuleDocument]) -> None:
        """Add multiple documents in one batch-encoded insert.

        No-op on an empty list.
        """
        if not docs:
            return

        ids = [doc.metadata.rule_id for doc in docs]
        contents = [doc.content for doc in docs]
        # Batch-encode all contents in one call rather than per-document.
        embeddings = self.embedding_model.encode(contents).tolist()
        metadatas = [doc.to_chroma_metadata() for doc in docs]

        collection = self.get_collection()
        collection.add(
            ids=ids, embeddings=embeddings, documents=contents, metadatas=metadatas
        )

    def search(
        self, query: str, top_k: int = 10, section_filter: Optional[str] = None
    ) -> list[RuleSearchResult]:
        """Search for relevant rules using semantic similarity.

        Args:
            query: Free-text query to embed and match against rule content.
            top_k: Maximum number of results to return.
            section_filter: If given, restrict results to rules whose
                ``section`` metadata equals this value.

        Returns:
            Results in ChromaDB relevance order, each with a similarity
            score clamped to [0, 1].
        """
        query_embedding = self.embedding_model.encode(query).tolist()

        collection = self.get_collection()

        where = {"section": section_filter} if section_filter else None

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        search_results = []
        # query() nests results per query embedding; we sent one, so use [0].
        if results and results["documents"] and results["documents"][0]:
            rows = zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
            for content, metadata, distance in rows:
                # Clamp to [0, 1]: cosine distance ranges 0–2, so 1 - distance
                # can go negative, and distances > 1.0 would otherwise produce
                # out-of-range similarities that break downstream validation.
                similarity = max(0.0, min(1.0, 1 - distance))

                search_results.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=content,
                        section=metadata["section"],
                        similarity=similarity,
                    )
                )

        return search_results

    def delete_rule(self, rule_id: str) -> None:
        """Remove a rule by its ID."""
        collection = self.get_collection()
        collection.delete(ids=[rule_id])

    def clear_all(self) -> None:
        """Delete all rules from the collection."""
        self.client.delete_collection("rules")
        self.get_collection()  # Recreate empty collection

    def get_rule(self, rule_id: str) -> Optional[RuleSearchResult]:
        """Retrieve a specific rule by ID.

        Returns:
            The matching rule with similarity pinned to 1.0 (exact-ID fetch),
            or None if no rule has this ID.
        """
        collection = self.get_collection()
        result = collection.get(ids=[rule_id], include=["documents", "metadatas"])

        if result and result["documents"]:
            # collection.get() returns flat lists (unlike query(), which nests
            # per-query), so index once — the previous [0][0] double index
            # keyed into the metadata dict and sliced the document string.
            metadata = result["metadatas"][0]
            return RuleSearchResult(
                rule_id=metadata["rule_id"],
                title=metadata["title"],
                content=result["documents"][0],
                section=metadata["section"],
                similarity=1.0,
            )
        return None

    def list_all_rules(self) -> list[RuleSearchResult]:
        """Return all rules in the store (similarity pinned to 1.0)."""
        collection = self.get_collection()
        result = collection.get(include=["documents", "metadatas"])

        all_rules = []
        if result and result["documents"]:
            # get() returns flat, parallel lists of documents and metadatas.
            for content, metadata in zip(result["documents"], result["metadatas"]):
                all_rules.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=content,
                        section=metadata["section"],
                        similarity=1.0,
                    )
                )

        return all_rules

    def count(self) -> int:
        """Return the number of rules in the store."""
        collection = self.get_collection()
        return collection.count()

    def get_stats(self) -> dict:
        """Get statistics about the vector store.

        Returns:
            Dict with total rule count, per-section rule counts, and the
            persistence directory path.
        """
        # list_all_rules() already ensures the collection exists; the former
        # unused `collection = self.get_collection()` binding is dropped.
        all_rules = self.list_all_rules()
        sections: dict[str, int] = {}
        for rule in all_rules:
            sections[rule.section] = sections.get(rule.section, 0) + 1

        return {
            "total_rules": len(all_rules),
            "sections": sections,
            "persist_directory": str(self.persist_dir),
        }
|