cognitive-memory/embeddings.py
Cal Corum 48df2a89ce Initial commit: extract cognitive-memory app from skill directory
Moved application code from ~/.claude/skills/cognitive-memory/ to its own
project directory. The skill layer (SKILL.md, SCHEMA.md) remains in the
skill directory for Claude Code to read.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-28 16:02:28 -06:00

234 lines
8.6 KiB
Python

"""EmbeddingsMixin for CognitiveMemoryClient.
Provides embedding generation and semantic search capabilities. Extracted from
client.py as part of the mixin refactor. Methods rely on shared state (memory_dir,
_load_index, _load_embeddings_cached) provided by the base class via MRO.
"""
import json
import sys
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
from common import (
EMBEDDING_MODEL,
EMBEDDINGS_PATH,
OPENAI_MODEL_DEFAULT,
_cosine_similarity,
_load_memory_config,
_ollama_embed,
_openai_embed,
)
class EmbeddingsMixin:
    """Mixin providing embedding generation and semantic recall for memory clients.

    Relies on shared state supplied by the composed base class via MRO:
    ``self.memory_dir`` (a pathlib.Path), ``self._load_index()`` and
    ``self._load_embeddings_cached()``.
    """

    def _get_embedding_provider(self) -> Dict[str, Any]:
        """Load embedding config from ``_config.json`` in the memory directory."""
        return _load_memory_config(self.memory_dir / "_config.json")

    def _embed_texts_with_fallback(
        self,
        texts: List[str],
        timeout: int = 300,
    ) -> Tuple[Optional[List[List[float]]], str, str]:
        """Embed texts, trying the configured provider first, then the other.

        OpenAI is only attempted when an API key is present in the config.

        Args:
            texts: Texts to embed, one vector per text.
            timeout: Per-request timeout in seconds.

        Returns:
            Tuple of (vectors, provider_used, model_used). ``vectors`` is
            None (with empty provider/model strings) when every provider
            in the chain failed or was unavailable.
        """
        config = self._get_embedding_provider()
        preferred = config.get("embedding_provider", "ollama")

        ollama_model = config.get("ollama_model", EMBEDDING_MODEL)
        openai_model = config.get("openai_model", OPENAI_MODEL_DEFAULT)
        api_key = config.get("openai_api_key")

        def _try_openai() -> Optional[List[List[float]]]:
            # No API key means OpenAI is silently skipped, not an error.
            if not api_key:
                return None
            return _openai_embed(texts, api_key, openai_model, timeout=timeout)

        def _try_ollama() -> Optional[List[List[float]]]:
            return _ollama_embed(texts, model=ollama_model, timeout=timeout)

        # Single fallback chain ordered so the configured provider runs
        # first; this replaces two duplicated if/else branches.
        if preferred == "openai":
            chain = [
                ("openai", openai_model, _try_openai),
                ("ollama", ollama_model, _try_ollama),
            ]
        else:
            chain = [
                ("ollama", ollama_model, _try_ollama),
                ("openai", openai_model, _try_openai),
            ]

        for provider_name, model_name, attempt in chain:
            vectors = attempt()
            if vectors is not None:
                return vectors, provider_name, model_name
        return None, "", ""

    def embed(self, if_changed: bool = False) -> Dict[str, Any]:
        """Generate embeddings for all memories using the configured provider.

        Detects provider changes and re-embeds everything (dimension
        mismatch safety). Stores vectors in ``_embeddings.json`` (not
        git-tracked) inside the memory directory.

        Args:
            if_changed: If True, skip embedding when the set of memory IDs
                hasn't changed since the last run (no new/deleted
                memories). Note: content edits to existing memories are
                NOT detected by this check.

        Returns:
            Summary dict with ``embedded`` count, ``provider``/``model``
            used and the embeddings file ``path``; may instead carry
            ``error`` or ``skipped`` keys.
        """
        index = self._load_index()
        entries = index.get("entries", {})
        embeddings_path = self.memory_dir / "_embeddings.json"
        if not entries:
            # BUGFIX: report the actual per-instance path instead of the
            # module-level EMBEDDINGS_PATH default, which diverges when
            # memory_dir is customized (all other returns already do this).
            return {
                "embedded": 0,
                "provider": "none",
                "model": "",
                "path": str(embeddings_path),
            }

        # Read the previous embeddings file ONCE; it serves both the
        # provider-change check and the if_changed short-circuit (the
        # original parsed the file twice).
        old_data: Dict[str, Any] = {}
        old_provider = ""
        if embeddings_path.exists():
            try:
                old_data = json.loads(embeddings_path.read_text())
                old_provider = old_data.get("provider", "ollama")
            except (json.JSONDecodeError, OSError):
                old_data = {}  # unreadable -> treat as absent, re-embed

        config = self._get_embedding_provider()
        new_provider = config.get("embedding_provider", "ollama")
        provider_changed = bool(old_provider) and old_provider != new_provider
        if provider_changed:
            print(
                f"Provider changed ({old_provider} -> {new_provider}), re-embedding all memories...",
                file=sys.stderr,
            )

        # Skip if nothing changed (unless provider switched).
        if if_changed and not provider_changed:
            embedded_ids = set(old_data.get("entries", {}).keys())
            if embedded_ids == set(entries.keys()):
                return {
                    "embedded": 0,
                    "skipped": True,
                    "reason": "no new or deleted memories",
                    "path": str(embeddings_path),
                }

        # Embed "title. preview" per memory, batched to bound request size.
        memory_ids = list(entries.keys())
        texts = [
            f"{entries[mid].get('title', '')}. {entries[mid].get('content_preview', '')}"
            for mid in memory_ids
        ]

        all_embeddings: Dict[str, List[float]] = {}
        batch_size = 50
        provider_used = ""
        model_used = ""
        for start in range(0, len(texts), batch_size):
            batch_ids = memory_ids[start : start + batch_size]
            vectors, provider_used, model_used = self._embed_texts_with_fallback(
                texts[start : start + batch_size],
                timeout=300,
            )
            if vectors is None:
                # Partial progress is reported but NOT persisted.
                return {
                    "error": "All embedding providers unavailable",
                    "embedded": len(all_embeddings),
                }
            all_embeddings.update(zip(batch_ids, vectors))

        embeddings_data = {
            "provider": provider_used,
            "model": model_used,
            "updated": datetime.now(timezone.utc).isoformat(),
            "entries": all_embeddings,
        }
        # default=str kept as a safety net for any non-JSON-native values
        # a provider implementation might place in vectors.
        embeddings_path.write_text(json.dumps(embeddings_data, default=str))
        return {
            "embedded": len(all_embeddings),
            "provider": provider_used,
            "model": model_used,
            "path": str(embeddings_path),
        }

    def _embed_query(
        self, query: str, emb_data: Dict[str, Any]
    ) -> Optional[List[float]]:
        """Embed a single query, preferring the stored vectors' provider/model.

        Falls back to any available provider as a last resort; a caller
        should filter out dimension mismatches from that case.
        """
        stored_provider = emb_data.get("provider", "ollama")
        config = self._get_embedding_provider()
        query_vec: Optional[List[float]] = None

        if stored_provider == "openai":
            api_key = config.get("openai_api_key")
            if api_key:
                vecs = _openai_embed(
                    [query], api_key, emb_data.get("model", OPENAI_MODEL_DEFAULT)
                )
                if vecs:
                    query_vec = vecs[0]
        if query_vec is None and stored_provider == "ollama":
            vecs = _ollama_embed([query], model=emb_data.get("model", EMBEDDING_MODEL))
            if vecs:
                query_vec = vecs[0]
        # Last resort: try any available provider.
        if query_vec is None:
            vecs, _, _ = self._embed_texts_with_fallback([query], timeout=30)
            if vecs:
                query_vec = vecs[0]
        return query_vec

    def semantic_recall(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Search memories by semantic similarity using stored embeddings.

        Embeds the query with the same provider that generated the stored
        vectors where possible (see ``_embed_query``). Vectors whose
        dimension differs from the query vector are skipped as a safety
        guard against mixed-provider leftovers.

        Args:
            query: Free-text search query.
            limit: Maximum number of results to return.

        Returns:
            Up to ``limit`` result dicts sorted by descending cosine
            similarity; empty list when no embeddings exist or the query
            could not be embedded.
        """
        emb_data = self._load_embeddings_cached()
        if emb_data is None:
            return []
        stored = emb_data.get("entries", {})
        if not stored:
            return []

        query_vec = self._embed_query(query, emb_data)
        if query_vec is None:
            return []
        query_dim = len(query_vec)

        # Hoist the index-entries lookup out of the loop (the original
        # re-fetched index["entries"] on every iteration).
        index_entries = self._load_index().get("entries", {})
        scored = []
        for mid, vec in stored.items():
            if len(vec) != query_dim:
                continue  # dimension mismatch: stale vector from another provider
            entry = index_entries.get(mid)
            if entry:
                scored.append(
                    {
                        "id": mid,
                        "title": entry.get("title", ""),
                        "type": entry.get("type", "general"),
                        "tags": entry.get("tags", []),
                        "similarity": round(_cosine_similarity(query_vec, vec), 4),
                        "path": entry.get("path", ""),
                    }
                )
        scored.sort(key=lambda item: item["similarity"], reverse=True)
        return scored[:limit]