strat-chatbot/app/vector_store.py
Cal Corum c2c7f7d3c2 fix: resolve 4 critical bugs found in code review
- Discord bot: store full conversation UUID in footer instead of truncated
  8-char prefix, fixing completely broken follow-up threading. Add footer
  to follow-up embeds so conversation chains work beyond depth 1. Edit
  loading message in-place instead of leaving ghost messages. Replace bare
  except with specific exception types. Fix channel_id attribute access.
- GiteaClient: remove broken async context manager pattern that caused
  every create_unanswered_issue call to raise RuntimeError. Use per-request
  httpx.AsyncClient instead.
- Database: return singleton ConversationManager from app.state instead of
  creating a new SQLAlchemy engine (and connection pool) on every request.
- Vector store: clamp cosine similarity to [0, 1] to prevent Pydantic
  ValidationError crashes when ChromaDB returns distances > 1.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 15:31:11 -05:00

169 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ChromaDB vector store for rule embeddings."""
from pathlib import Path
from typing import Optional
import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer
import numpy as np
from .config import settings
from .models import RuleDocument, RuleSearchResult
class VectorStore:
    """Wrapper around ChromaDB for rule retrieval.

    Rule documents are embedded with a sentence-transformers model and stored
    in a persistent ChromaDB collection named "rules", configured for cosine
    distance.
    """

    def __init__(self, persist_dir: Path, embedding_model: str) -> None:
        """Initialize vector store with embedding model.

        Args:
            persist_dir: Directory for ChromaDB persistence; created
                (with parents) if it does not exist.
            embedding_model: Name of the sentence-transformers model to load.
        """
        self.persist_dir = Path(persist_dir)
        self.persist_dir.mkdir(parents=True, exist_ok=True)
        # The original passed is_persist_directory_actually_writable=True,
        # which is not a documented chromadb Settings field — removed so the
        # pydantic-backed Settings model cannot reject it.
        chroma_settings = ChromaSettings(anonymized_telemetry=False)
        self.client = chromadb.PersistentClient(
            path=str(self.persist_dir), settings=chroma_settings
        )
        self.embedding_model = SentenceTransformer(embedding_model)

    def get_collection(self):
        """Get or create the "rules" collection (cosine HNSW space)."""
        return self.client.get_or_create_collection(
            name="rules", metadata={"hnsw:space": "cosine"}
        )

    def add_document(self, doc: RuleDocument) -> None:
        """Add a single rule document to the vector store."""
        embedding = self.embedding_model.encode(doc.content).tolist()
        collection = self.get_collection()
        collection.add(
            ids=[doc.metadata.rule_id],
            embeddings=[embedding],
            documents=[doc.content],
            metadatas=[doc.to_chroma_metadata()],
        )

    def add_documents(self, docs: list[RuleDocument]) -> None:
        """Add multiple documents in batch (no-op for an empty list)."""
        if not docs:
            return
        ids = [doc.metadata.rule_id for doc in docs]
        contents = [doc.content for doc in docs]
        # One batched model call instead of one encode() per document.
        embeddings = self.embedding_model.encode(contents).tolist()
        metadatas = [doc.to_chroma_metadata() for doc in docs]
        collection = self.get_collection()
        collection.add(
            ids=ids, embeddings=embeddings, documents=contents, metadatas=metadatas
        )

    def search(
        self, query: str, top_k: int = 10, section_filter: Optional[str] = None
    ) -> list[RuleSearchResult]:
        """Search for relevant rules using semantic similarity.

        Args:
            query: Free-text query to embed and match.
            top_k: Maximum number of results to return.
            section_filter: If given, restrict matches to this section.

        Returns:
            Results in ChromaDB relevance order, with similarity clamped
            to [0, 1].
        """
        query_embedding = self.embedding_model.encode(query).tolist()
        collection = self.get_collection()
        where = {"section": section_filter} if section_filter else None
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )
        search_results: list[RuleSearchResult] = []
        # query() returns nested lists: one inner list per query embedding.
        if results and results["documents"] and results["documents"][0]:
            for i in range(len(results["documents"][0])):
                metadata = results["metadatas"][0][i]
                distance = results["distances"][0][i]
                # Cosine distance ranges 0-2, so 1 - distance can be negative;
                # clamp to [0, 1] so RuleSearchResult validation cannot fail.
                similarity = max(0.0, min(1.0, 1 - distance))
                search_results.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=results["documents"][0][i],
                        section=metadata["section"],
                        similarity=similarity,
                    )
                )
        return search_results

    def delete_rule(self, rule_id: str) -> None:
        """Remove a rule by its ID."""
        collection = self.get_collection()
        collection.delete(ids=[rule_id])

    def clear_all(self) -> None:
        """Delete all rules and recreate an empty collection."""
        self.client.delete_collection("rules")
        self.get_collection()  # Recreate empty collection

    def get_rule(self, rule_id: str) -> Optional[RuleSearchResult]:
        """Retrieve a specific rule by ID, or None if it does not exist.

        Unlike query(), collection.get() returns FLAT lists (one entry per
        matched id), as list_all_rules below already assumes. The previous
        nested indexing (result["metadatas"][0][0]) indexed a metadata dict
        with an int and took the first *character* of the document.
        """
        collection = self.get_collection()
        result = collection.get(ids=[rule_id], include=["documents", "metadatas"])
        if result and result["documents"]:
            metadata = result["metadatas"][0]
            return RuleSearchResult(
                rule_id=metadata["rule_id"],
                title=metadata["title"],
                content=result["documents"][0],
                section=metadata["section"],
                similarity=1.0,  # Direct lookup: report a perfect match.
            )
        return None

    def list_all_rules(self) -> list[RuleSearchResult]:
        """Return all rules in the store (similarity fixed at 1.0)."""
        collection = self.get_collection()
        result = collection.get(include=["documents", "metadatas"])
        all_rules: list[RuleSearchResult] = []
        if result and result["documents"]:
            # get() returns flat, position-aligned lists.
            for content, metadata in zip(result["documents"], result["metadatas"]):
                all_rules.append(
                    RuleSearchResult(
                        rule_id=metadata["rule_id"],
                        title=metadata["title"],
                        content=content,
                        section=metadata["section"],
                        similarity=1.0,
                    )
                )
        return all_rules

    def count(self) -> int:
        """Return the number of rules in the store."""
        collection = self.get_collection()
        return collection.count()

    def get_stats(self) -> dict:
        """Get statistics: total rules, per-section counts, persist dir."""
        all_rules = self.list_all_rules()
        sections: dict[str, int] = {}
        for rule in all_rules:
            sections[rule.section] = sections.get(rule.section, 0) + 1
        return {
            "total_rules": len(all_rules),
            "sections": sections,
            "persist_directory": str(self.persist_dir),
        }