3,348-line monolith → 6 modules with mixin classes resolving via MRO. client.py retains __init__, internal helpers, and core CRUD (1,091 lines). All backward-compat imports preserved for mcp_server.py and dev/migrate.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
234 lines
8.6 KiB
Python
"""EmbeddingsMixin for CognitiveMemoryClient.
|
|
|
|
Provides embedding generation and semantic search capabilities. Extracted from
|
|
client.py as part of the mixin refactor. Methods rely on shared state (memory_dir,
|
|
_load_index, _load_embeddings_cached) provided by the base class via MRO.
|
|
"""
|
|
|
|
import json
|
|
import sys
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from common import (
|
|
EMBEDDING_MODEL,
|
|
EMBEDDINGS_PATH,
|
|
OPENAI_MODEL_DEFAULT,
|
|
_cosine_similarity,
|
|
_load_memory_config,
|
|
_ollama_embed,
|
|
_openai_embed,
|
|
)
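
# How this mixin composes (illustrative sketch; ``MemoryClientBase`` is a
# hypothetical name, not taken from this repo): the concrete client lists the
# mixin before the base class that actually provides memory_dir, _load_index,
# and _load_embeddings_cached, so lookups on self resolve against whichever
# class in the MRO defines them:
#
#     class CognitiveMemoryClient(EmbeddingsMixin, MemoryClientBase):
#         ...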


class EmbeddingsMixin:
    """Mixin providing embedding generation and semantic recall for memory clients."""

    def _get_embedding_provider(self) -> Dict[str, Any]:
        """Load embedding config from _config.json."""
        return _load_memory_config(self.memory_dir / "_config.json")
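
    # Shape of _config.json as consumed in this module (keys inferred from the
    # .get() calls below; values illustrative):
    #
    #     {
    #       "embedding_provider": "openai",  # or "ollama" (the default)
    #       "openai_api_key": "sk-...",
    #       "openai_model": "...",           # default: OPENAI_MODEL_DEFAULT
    #       "ollama_model": "..."            # default: EMBEDDING_MODEL
    #     }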

    def _embed_texts_with_fallback(
        self,
        texts: List[str],
        timeout: int = 300,
    ) -> Tuple[Optional[List[List[float]]], str, str]:
        """Embed texts with fallback chain. Returns (vectors, provider_used, model_used)."""
        config = self._get_embedding_provider()
        provider = config.get("embedding_provider", "ollama")

        # Try configured provider first
        if provider == "openai":
            api_key = config.get("openai_api_key")
            model = config.get("openai_model", OPENAI_MODEL_DEFAULT)
            if api_key:
                vectors = _openai_embed(texts, api_key, model, timeout=timeout)
                if vectors is not None:
                    return vectors, "openai", model
            # Fallback to ollama
            ollama_model = config.get("ollama_model", EMBEDDING_MODEL)
            vectors = _ollama_embed(texts, model=ollama_model, timeout=timeout)
            if vectors is not None:
                return vectors, "ollama", ollama_model
        else:
            # ollama first
            ollama_model = config.get("ollama_model", EMBEDDING_MODEL)
            vectors = _ollama_embed(texts, model=ollama_model, timeout=timeout)
            if vectors is not None:
                return vectors, "ollama", ollama_model
            # Fallback to openai
            api_key = config.get("openai_api_key")
            model = config.get("openai_model", OPENAI_MODEL_DEFAULT)
            if api_key:
                vectors = _openai_embed(texts, api_key, model, timeout=timeout)
                if vectors is not None:
                    return vectors, "openai", model

        return None, "", ""
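
    # Illustrative outcome (vector values hypothetical): with "openai"
    # configured but no API key present, the chain falls through to Ollama:
    #
    #     vectors, provider, model = self._embed_texts_with_fallback(["hello"])
    #     # -> ([[0.012, -0.34, ...]], "ollama", <configured ollama model>)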

    def embed(self, if_changed: bool = False) -> Dict[str, Any]:
        """Generate embeddings for all memories using the configured provider.

        Detects provider changes and re-embeds everything (dimension mismatch safety).
        Stores vectors in _embeddings.json (not git-tracked).

        Args:
            if_changed: If True, skip embedding if the set of memory IDs hasn't
                changed since the last run (no new or deleted memories).
        """
        index = self._load_index()
        entries = index.get("entries", {})
        if not entries:
            return {
                "embedded": 0,
                "provider": "none",
                "model": "",
                "path": str(EMBEDDINGS_PATH),
            }

        # Check for provider change
        embeddings_path = self.memory_dir / "_embeddings.json"
        old_provider = ""
        if embeddings_path.exists():
            try:
                old_data = json.loads(embeddings_path.read_text())
                old_provider = old_data.get("provider", "ollama")
            except (json.JSONDecodeError, OSError):
                pass

        config = self._get_embedding_provider()
        new_provider = config.get("embedding_provider", "ollama")
        provider_changed = old_provider and old_provider != new_provider
        if provider_changed:
            print(
                f"Provider changed ({old_provider} -> {new_provider}), re-embedding all memories...",
                file=sys.stderr,
            )

        # Skip if nothing changed (unless the provider switched)
        if if_changed and not provider_changed and embeddings_path.exists():
            try:
                old_data = json.loads(embeddings_path.read_text())
                embedded_ids = set(old_data.get("entries", {}).keys())
                index_ids = set(entries.keys())
                if embedded_ids == index_ids:
                    return {
                        "embedded": 0,
                        "skipped": True,
                        "reason": "no new or deleted memories",
                        "path": str(embeddings_path),
                    }
            except (json.JSONDecodeError, OSError):
                pass  # Can't read old data, re-embed

        # Build texts to embed
        memory_ids = list(entries.keys())
        texts = []
        for mid in memory_ids:
            entry = entries[mid]
            title = entry.get("title", "")
            preview = entry.get("content_preview", "")
            texts.append(f"{title}. {preview}")

        # Batch embed in groups of 50
        all_embeddings: Dict[str, List[float]] = {}
        batch_size = 50
        provider_used = ""
        model_used = ""
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i : i + batch_size]
            batch_ids = memory_ids[i : i + batch_size]
            vectors, provider_used, model_used = self._embed_texts_with_fallback(
                batch_texts,
                timeout=300,
            )
            if vectors is None:
                return {
                    "error": "All embedding providers unavailable",
                    "embedded": len(all_embeddings),
                }
            for mid, vec in zip(batch_ids, vectors):
                all_embeddings[mid] = vec

        # Write embeddings file with provider info
        embeddings_data = {
            "provider": provider_used,
            "model": model_used,
            "updated": datetime.now(timezone.utc).isoformat(),
            "entries": all_embeddings,
        }
        embeddings_path.write_text(json.dumps(embeddings_data, default=str))

        return {
            "embedded": len(all_embeddings),
            "provider": provider_used,
            "model": model_used,
            "path": str(embeddings_path),
        }
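
    # Resulting _embeddings.json layout, matching the write above (vector
    # values illustrative):
    #
    #     {
    #       "provider": "ollama",
    #       "model": "<model name>",
    #       "updated": "2025-01-01T00:00:00+00:00",
    #       "entries": {"<memory-id>": [0.012, -0.34, ...]}
    #     }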

    def semantic_recall(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Search memories by semantic similarity using embeddings.

        Embeds the query with the same provider that generated the stored
        embeddings. Skips vectors with a dimension mismatch as a safety guard.
        """
        emb_data = self._load_embeddings_cached()
        if emb_data is None:
            return []

        stored = emb_data.get("entries", {})
        if not stored:
            return []

        # Embed the query with the matching provider
        stored_provider = emb_data.get("provider", "ollama")
        config = self._get_embedding_provider()
        query_vec = None

        if stored_provider == "openai":
            api_key = config.get("openai_api_key")
            model = emb_data.get("model", OPENAI_MODEL_DEFAULT)
            if api_key:
                vecs = _openai_embed([query], api_key, model)
                if vecs:
                    query_vec = vecs[0]
        if query_vec is None and stored_provider == "ollama":
            stored_model = emb_data.get("model", EMBEDDING_MODEL)
            vecs = _ollama_embed([query], model=stored_model)
            if vecs:
                query_vec = vecs[0]
        # Last resort: try any available provider
        if query_vec is None:
            vecs, _, _ = self._embed_texts_with_fallback([query], timeout=30)
            if vecs:
                query_vec = vecs[0]

        if query_vec is None:
            return []

        query_dim = len(query_vec)

        # Score all memories by cosine similarity
        index = self._load_index()
        scored = []
        for mid, vec in stored.items():
            # Skip dimension mismatch
            if len(vec) != query_dim:
                continue
            sim = _cosine_similarity(query_vec, vec)
            entry = index.get("entries", {}).get(mid)
            if entry:
                scored.append(
                    {
                        "id": mid,
                        "title": entry.get("title", ""),
                        "type": entry.get("type", "general"),
                        "tags": entry.get("tags", []),
                        "similarity": round(sim, 4),
                        "path": entry.get("path", ""),
                    }
                )

        scored.sort(key=lambda x: x["similarity"], reverse=True)
        return scored[:limit]
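
# Minimal usage sketch (the constructor call below is an assumption; this
# module only shows the mixin side): embed once, then query by meaning.
#
#     client = CognitiveMemoryClient(memory_dir=Path.home() / ".memory")
#     client.embed(if_changed=True)  # no-op when the memory set is unchanged
#     for hit in client.semantic_recall("notes about database migrations", limit=5):
#         print(hit["similarity"], hit["title"])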