"""EmbeddingsMixin for CognitiveMemoryClient. Provides embedding generation and semantic search capabilities. Extracted from client.py as part of the mixin refactor. Methods rely on shared state (memory_dir, _load_index, _load_embeddings_cached) provided by the base class via MRO. """ import json import sys from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Tuple from common import ( EMBEDDING_MODEL, OPENAI_MODEL_DEFAULT, THRESHOLD_DORMANT, _cosine_similarity, _load_memory_config, _ollama_embed, _openai_embed, ) class EmbeddingsMixin: """Mixin providing embedding generation and semantic recall for memory clients.""" def _get_embedding_provider(self) -> Dict[str, Any]: """Load embedding config from _config.json.""" return _load_memory_config(self.memory_dir / "_config.json") def _embed_texts_with_fallback( self, texts: List[str], timeout: int = 300, ) -> Tuple[Optional[List[List[float]]], str, str]: """Embed texts with fallback chain. Returns (vectors, provider_used, model_used).""" config = self._get_embedding_provider() provider = config.get("embedding_provider", "ollama") # Try configured provider first if provider == "openai": api_key = config.get("openai_api_key") model = config.get("openai_model", OPENAI_MODEL_DEFAULT) if api_key: vectors = _openai_embed(texts, api_key, model, timeout=timeout) if vectors is not None: return vectors, "openai", model # Fallback to ollama ollama_model = config.get("ollama_model", EMBEDDING_MODEL) vectors = _ollama_embed(texts, model=ollama_model, timeout=timeout) if vectors is not None: return vectors, "ollama", ollama_model else: # ollama first ollama_model = config.get("ollama_model", EMBEDDING_MODEL) vectors = _ollama_embed(texts, model=ollama_model, timeout=timeout) if vectors is not None: return vectors, "ollama", ollama_model # Fallback to openai api_key = config.get("openai_api_key") model = config.get("openai_model", OPENAI_MODEL_DEFAULT) if api_key: vectors = _openai_embed(texts, api_key, model, timeout=timeout) if vectors is not None: return vectors, "openai", model return None, "", "" def embed(self, if_changed: bool = False) -> Dict[str, Any]: """Generate embeddings for all memories using configured provider. Detects provider changes and re-embeds everything (dimension mismatch safety). Stores vectors in _embeddings.json (not git-tracked). Args: if_changed: If True, skip embedding if the set of memory IDs hasn't changed since last run (no new/deleted memories). """ index = self._load_index() entries = index.get("entries", {}) if not entries: return { "embedded": 0, "provider": "none", "model": "", "path": str(self.memory_dir / "_embeddings.json"), } # Check for provider change embeddings_path = self.memory_dir / "_embeddings.json" old_provider = "" if embeddings_path.exists(): try: old_data = json.loads(embeddings_path.read_text()) old_provider = old_data.get("provider", "ollama") except (json.JSONDecodeError, OSError): pass config = self._get_embedding_provider() new_provider = config.get("embedding_provider", "ollama") provider_changed = old_provider and old_provider != new_provider if provider_changed: print( f"Provider changed ({old_provider} -> {new_provider}), re-embedding all memories...", file=sys.stderr, ) # Skip if nothing changed (unless provider switched) if if_changed and not provider_changed and embeddings_path.exists(): try: old_data = json.loads(embeddings_path.read_text()) embedded_ids = set(old_data.get("entries", {}).keys()) index_ids = set(entries.keys()) if embedded_ids == index_ids: return { "embedded": 0, "skipped": True, "reason": "no new or deleted memories", "path": str(embeddings_path), } except (json.JSONDecodeError, OSError): pass # Can't read old data, re-embed # Build texts to embed memory_ids = list(entries.keys()) texts = [] for mid in memory_ids: entry = entries[mid] title = entry.get("title", "") preview = entry.get("content_preview", "") texts.append(f"{title}. {preview}") # Batch embed in groups of 50 all_embeddings: Dict[str, List[float]] = {} batch_size = 50 provider_used = "" model_used = "" for i in range(0, len(texts), batch_size): batch_texts = texts[i : i + batch_size] batch_ids = memory_ids[i : i + batch_size] vectors, provider_used, model_used = self._embed_texts_with_fallback( batch_texts, timeout=300, ) if vectors is None: return { "error": "All embedding providers unavailable", "embedded": len(all_embeddings), } for mid, vec in zip(batch_ids, vectors): all_embeddings[mid] = vec # Write embeddings file with provider info embeddings_data = { "provider": provider_used, "model": model_used, "updated": datetime.now(timezone.utc).isoformat(), "entries": all_embeddings, } embeddings_path.write_text(json.dumps(embeddings_data, default=str)) return { "embedded": len(all_embeddings), "provider": provider_used, "model": model_used, "path": str(embeddings_path), } def semantic_recall(self, query: str, limit: int = 10) -> List[Dict[str, Any]]: """Search memories by semantic similarity using embeddings. Uses the same provider that generated stored embeddings to embed the query. Skips vectors with dimension mismatch as safety guard. """ emb_data = self._load_embeddings_cached() if emb_data is None: return [] stored = emb_data.get("entries", {}) if not stored: return [] # Embed query with matching provider stored_provider = emb_data.get("provider", "ollama") config = self._get_embedding_provider() query_vec = None if stored_provider == "openai": api_key = config.get("openai_api_key") model = emb_data.get("model", OPENAI_MODEL_DEFAULT) if api_key: vecs = _openai_embed([query], api_key, model) if vecs: query_vec = vecs[0] if query_vec is None and stored_provider == "ollama": stored_model = emb_data.get("model", EMBEDDING_MODEL) vecs = _ollama_embed([query], model=stored_model) if vecs: query_vec = vecs[0] # Last resort: try any available provider if query_vec is None: vecs, _, _ = self._embed_texts_with_fallback([query], timeout=30) if vecs: query_vec = vecs[0] if query_vec is None: return [] query_dim = len(query_vec) # Score all memories by cosine similarity, skipping archived/dormant index = self._load_index() state = self._load_state() scored = [] for mid, vec in stored.items(): s = state.get("entries", {}).get(mid, {}) if s.get("decay_score", 0.5) < THRESHOLD_DORMANT: continue # Skip dimension mismatch if len(vec) != query_dim: continue sim = _cosine_similarity(query_vec, vec) entry = index.get("entries", {}).get(mid) if entry: scored.append( { "id": mid, "title": entry.get("title", ""), "type": entry.get("type", "general"), "tags": entry.get("tags", []), "similarity": round(sim, 4), "path": entry.get("path", ""), } ) scored.sort(key=lambda x: x["similarity"], reverse=True) return scored[:limit]