- Add similarity/ package (compute.py, lookup.py) with numpy-based pairwise cosine similarity and cached lookup - database.py: create embeddings + similarity_cache tables in _init_database(), add store_similarity_batch/get_cached_similarities/clear_similarity_cache helpers - pipeline/fusion.py: replace N+1 per-motion embedding SELECT with single bulk JOIN using DuckDB QUALIFY window function - ai_provider.py: retry HTTP 429 with Retry-After header support - migrations/2026-03-22-add-similarity-cache.sql: make executable - Add tests for similarity compute, db helpers, and 429 retry (34 pass, 2 skip)

main
parent
a248807e03
commit
a78bee9b0a
@ -1,15 +1,19 @@ |
||||
-- 2026-03-22-add-similarity-cache.sql - similarity migration
-- This migration creates a sequence and the similarity_cache table.
-- NOTE(review): reconstructed from a garbled diff view — the table DDL was
-- shown inside a /* */ block while the index below was executable, which
-- would fail against a missing table. Per the commit intent ("make
-- executable"), the sequence/table/index are all executable here.

-- Create a sequence for generating integer ids (DuckDB compatible)
CREATE SEQUENCE IF NOT EXISTS similarity_cache_id_seq START 1;

-- Create the similarity_cache table.
CREATE TABLE IF NOT EXISTS similarity_cache (
    id INTEGER DEFAULT nextval('similarity_cache_id_seq'),
    source_motion_id INTEGER NOT NULL,
    target_motion_id INTEGER NOT NULL,
    score REAL NOT NULL,
    vector_type TEXT NOT NULL,
    window_id TEXT,
    created_at TIMESTAMP DEFAULT current_timestamp
);

-- Optionally create an index to speed lookups (no-op if already exists)
CREATE INDEX IF NOT EXISTS idx_similarity_cache_source_target ON similarity_cache (source_motion_id, target_motion_id);
||||
|
||||
@ -0,0 +1,18 @@ |
||||
from importlib import import_module |
||||
from typing import Any |
||||
|
||||
__all__ = ["compute_similarities", "get_similar_motions"] |
||||
|
||||
|
||||
def _lazy_import(module_name: str):
    """Import and return a sibling submodule of this package on demand."""
    relative_name = "." + module_name
    return import_module(relative_name, __package__)
||||
|
||||
|
||||
def compute_similarities(*args: Any, **kwargs: Any) -> Any:
    """Proxy to ``compute.compute_similarities``; the submodule is imported lazily."""
    impl = _lazy_import("compute").compute_similarities
    return impl(*args, **kwargs)
||||
|
||||
|
||||
def get_similar_motions(*args: Any, **kwargs: Any) -> Any:
    """Proxy to ``lookup.get_similar_motions``; the submodule is imported lazily."""
    impl = _lazy_import("lookup").get_similar_motions
    return impl(*args, **kwargs)
||||
@ -0,0 +1,214 @@ |
||||
import json |
||||
import logging |
||||
from typing import List, Optional |
||||
|
||||
import numpy as np |
||||
|
||||
from database import MotionDatabase |
||||
|
||||
|
||||
logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
def _build_query(vector_type: str, window_id: Optional[str]):
    """Return (sql, params) selecting (id, vector) rows for the given vector
    type, or None when vector_type is unknown."""
    if vector_type == "fused":
        if window_id is not None:
            return (
                "SELECT motion_id AS id, vector FROM fused_embeddings WHERE window_id = ?",
                (window_id,),
            )
        # fallback to all fused embeddings if no window specified
        return ("SELECT motion_id AS id, vector FROM fused_embeddings", ())
    if vector_type == "text":
        return ("SELECT motion_id AS id, vector FROM embeddings", ())
    if vector_type == "svd":
        if window_id is not None:
            return (
                "SELECT entity_id AS id, vector FROM svd_vectors WHERE entity_type = 'motion' AND window_id = ?",
                (window_id,),
            )
        return (
            "SELECT entity_id AS id, vector FROM svd_vectors WHERE entity_type = 'motion'",
            (),
        )
    return None


def _parse_vector(raw, row_id) -> Optional[List[float]]:
    """Decode one stored vector into a list of floats.

    Accepts list/tuple, UTF-8 JSON bytes, or a JSON string; returns None
    (after logging a warning) for anything that cannot be decoded.
    """
    if isinstance(raw, (list, tuple)):
        candidate = list(raw)
    elif isinstance(raw, (bytes, bytearray)):
        try:
            text = raw.decode("utf-8")
        except Exception:
            logger.warning(
                "Skipping row with non-decodable bytes vector for id=%s", row_id
            )
            return None
        try:
            candidate = json.loads(text)
        except Exception:
            logger.warning(
                "Skipping row with invalid JSON bytes vector for id=%s", row_id
            )
            return None
    elif isinstance(raw, str):
        try:
            candidate = json.loads(raw)
        except Exception:
            logger.warning(
                "Skipping row with invalid JSON string vector for id=%s", row_id
            )
            return None
    else:
        logger.warning(
            "Skipping row with unsupported vector type %s for id=%s",
            type(raw),
            row_id,
        )
        return None

    # ensure numeric conversion
    try:
        return [float(x) for x in candidate]
    except Exception:
        logger.warning("Skipping row with non-numeric vector entries for id=%s", row_id)
        return None


def _top_indices(scores, k: int) -> List[int]:
    """Indices of the k largest entries of ``scores``, sorted descending.

    Avoids negating ``scores``: the self-similarity entry was set to -inf by
    the caller and would become +inf under negation, incorrectly winning.
    Instead we partition at n - k to isolate the k largest, then sort them.
    """
    if k == 1:
        return [int(np.argmax(scores))]
    n = scores.shape[0]
    part = np.argpartition(scores, n - k)[-k:]
    return [int(j) for j in part[np.argsort(-scores[part])]]


def compute_similarities(
    vector_type: str = "fused",
    window_id: Optional[str] = None,
    top_k: int = 10,
    db_path: Optional[str] = None,
):
    """Compute pairwise cosine similarities for vectors of a given type and
    store the top-k neighbors of each motion in the similarity cache.

    Args:
        vector_type: one of "fused", "text", or "svd".
        window_id: optional window filter ("fused"/"svd" only).
        top_k: number of nearest neighbors to keep per source motion.
        db_path: optional database path override.

    Returns:
        Number of inserted rows; 0 on any failure.
    """
    db = MotionDatabase(db_path=db_path) if db_path is not None else MotionDatabase()

    built = _build_query(vector_type, window_id)
    if built is None:
        logger.error(f"Unknown vector_type: {vector_type}")
        return 0
    query, params = built

    # Load all vectors in a single query; duckdb is imported lazily so the
    # module remains importable without it.
    try:
        try:
            import duckdb
        except Exception:
            logger.exception("duckdb import failed; cannot load vectors")
            return 0

        with duckdb.connect(db.db_path) as conn:
            rows = conn.execute(query, params).fetchall()
    except Exception:
        logger.exception("Error loading vectors for similarity compute")
        return 0

    if not rows:
        logger.info("No vectors found for %s window=%s", vector_type, window_id)
        return 0

    ids: List[int] = []
    vecs: List[List[float]] = []
    for _id, vec_json in rows:
        parsed = _parse_vector(vec_json, _id)
        if parsed is None:
            continue
        # cast id to int for consistency; skip the row if it cannot be cast
        try:
            motion_id = int(_id)
        except Exception:
            logger.warning("Skipping row with non-integer id=%s", _id)
            continue
        ids.append(motion_id)
        vecs.append(parsed)

    if not vecs:
        logger.info(
            "No valid vectors after parsing for %s window=%s", vector_type, window_id
        )
        return 0

    # Ensure consistent dimensionality: pad shorter vectors with zeros.
    lengths = [len(v) for v in vecs]
    max_dim = max(lengths)
    if len(set(lengths)) != 1:
        logger.warning(
            "Inconsistent vector dimensions detected (max=%d). Padding shorter vectors with zeros.",
            max_dim,
        )

    matrix = np.zeros((len(vecs), max_dim), dtype=np.float32)
    for i, v in enumerate(vecs):
        matrix[i, : len(v)] = v

    # Cosine similarity = dot product of L2-normalized rows.
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid division by zero for all-zero vectors
    normalized = matrix / norms
    sim = normalized @ normalized.T

    n = sim.shape[0]
    rows_to_insert: List[dict] = []
    for i in range(n):
        scores = sim[i].copy()
        scores[i] = -np.inf  # exclude self-similarity

        # number of neighbors to take is min(top_k, n-1)
        k = min(top_k, n - 1)
        if k <= 0:
            continue

        for j in _top_indices(scores, k):
            rows_to_insert.append(
                {
                    "source_motion_id": int(ids[i]),
                    "target_motion_id": int(ids[j]),
                    "score": float(scores[j]),
                    "vector_type": vector_type,
                    "window_id": window_id,
                }
            )

    # Clear existing cache for this vector_type/window and store new rows.
    try:
        deleted = db.clear_similarity_cache(
            vector_type=vector_type, window_id=window_id
        )
        logger.info(
            "Cleared %d existing similarity rows for %s window=%s",
            deleted,
            vector_type,
            window_id,
        )
    except Exception:
        logger.exception("Error clearing similarity cache")

    try:
        inserted = db.store_similarity_batch(rows_to_insert)
        logger.info(
            "Inserted %d similarity rows for %s window=%s",
            inserted,
            vector_type,
            window_id,
        )
        return inserted
    except Exception:
        logger.exception("Error storing similarity rows")
        return 0
||||
@ -0,0 +1,101 @@ |
||||
from typing import Optional, List, Dict |
||||
import logging |
||||
from database import MotionDatabase |
||||
|
||||
|
||||
_logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
def _first_non_null(mapping: Dict, keys) -> object:
    """Return the first mapping value among ``keys`` that is not None.

    Uses explicit None checks rather than ``or``-chaining so that legitimate
    falsy values (a score of 0.0, an id of 0) are not treated as missing.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


def get_similar_motions(
    motion_id: int,
    vector_type: str = "fused",
    window_id: Optional[str] = None,
    top_k: int = 10,
    db_path: Optional[str] = None,
) -> List[Dict]:
    """Return a list of similar motions as dicts with keys: motion_id, score

    Prefers MotionDatabase.get_cached_similarities if available; otherwise falls
    back to a direct SQL query using duckdb which is imported lazily.
    """
    db = MotionDatabase(db_path=db_path) if db_path else MotionDatabase()

    # Prefer cached accessor if available
    if hasattr(db, "get_cached_similarities"):
        try:
            rows = db.get_cached_similarities(
                source_motion_id=motion_id,
                vector_type=vector_type,
                window_id=window_id,
                top_k=top_k,
            )
        except TypeError:
            # fallback if signature differs
            rows = db.get_cached_similarities(motion_id, vector_type, window_id, top_k)

        # normalize shapes to [{'motion_id': int, 'score': float}, ...]
        out = []
        for r in rows:
            # r may be dict-like with target_motion_id or motion_id keys
            if isinstance(r, dict):
                # BUG FIX: previously `r.get("score") or r.get("similarity")`,
                # which dropped rows whose score was exactly 0.0 (falsy).
                mid = _first_non_null(r, ("target_motion_id", "motion_id", "target"))
                score = _first_non_null(r, ("score", "similarity", "score_float"))
            else:
                # r might be a tuple like (target_motion_id, score)
                try:
                    mid, score = r[0], r[1]
                except Exception:
                    continue

            try:
                out.append({"motion_id": int(mid), "score": float(score)})
            except Exception:
                # skip malformed rows
                continue

        # ensure ordered by score desc
        out.sort(key=lambda x: x["score"], reverse=True)
        return out[:top_k]

    # Fallback: query duckdb directly (import inside function)
    try:
        duckdb = __import__("duckdb")
    except Exception:
        _logger.error(
            "duckdb not available and MotionDatabase lacks get_cached_similarities"
        )
        return []

    conn = duckdb.connect(db.db_path)
    try:
        if window_id is None:
            query = (
                "SELECT target_motion_id, score FROM similarity_cache "
                "WHERE source_motion_id = ? AND vector_type = ? AND window_id IS NULL "
                "ORDER BY score DESC LIMIT ?"
            )
            params = (motion_id, vector_type, top_k)
        else:
            query = (
                "SELECT target_motion_id, score FROM similarity_cache "
                "WHERE source_motion_id = ? AND vector_type = ? AND window_id = ? "
                "ORDER BY score DESC LIMIT ?"
            )
            params = (motion_id, vector_type, window_id, top_k)

        rows = conn.execute(query, params).fetchall()
    finally:
        try:
            conn.close()
        except Exception:
            pass

    out = []
    for r in rows:
        try:
            out.append({"motion_id": int(r[0]), "score": float(r[1])})
        except Exception:
            continue

    out.sort(key=lambda x: x["score"], reverse=True)
    return out
||||
@ -0,0 +1,38 @@ |
||||
import os |
||||
import time |
||||
|
||||
import ai_provider |
||||
|
||||
|
||||
class DummyResponse:
    """Minimal stand-in for a ``requests`` response used by the retry tests.

    Exposes exactly what the code under test reads: ``status_code``,
    ``headers``, and a ``json()`` method returning the canned payload.
    """

    def __init__(self, status_code=200, json_data=None, headers=None):
        self.status_code = status_code
        self._json = json_data if json_data else {}
        self.headers = headers if headers else {}

    def json(self):
        return self._json
||||
|
||||
|
||||
def test_retry_on_429_then_success(monkeypatch):
    """get_embedding retries on HTTP 429, honoring Retry-After, then succeeds.

    NOTE(review): this test really sleeps ~2s because the two 429 responses
    each carry ``Retry-After: 1``.
    """
    attempts = {"n": 0}

    def fake_post(url, json, headers, timeout):
        # Two rate-limited responses, then a successful embedding payload.
        attempts["n"] += 1
        if attempts["n"] > 2:
            return DummyResponse(200, json_data={"data": [{"embedding": [0.4, 0.5]}]})
        return DummyResponse(
            429, json_data={"error": "rate_limited"}, headers={"Retry-After": "1"}
        )

    monkeypatch.setenv("OPENROUTER_API_KEY", "sk-test")
    monkeypatch.setattr("requests.post", fake_post)

    started = time.time()
    result = ai_provider.get_embedding("hello")
    elapsed = time.time() - started

    # Two Retry-After: 1 backoffs should have taken at least ~2 seconds.
    assert elapsed >= 2
    assert result == [0.4, 0.5]
||||
@ -0,0 +1,65 @@ |
||||
def test_similarity_compute_and_lookup(tmp_path):
    """End-to-end: store fused embeddings, compute top-k similarities, look them up."""
    import pytest

    duckdb = pytest.importorskip("duckdb")

    from database import MotionDatabase

    import similarity.compute as compute
    import similarity.lookup as lookup

    db_path = str(tmp_path / "motions.db")

    # Building MotionDatabase on tmp_path initializes the schema.
    db = MotionDatabase(db_path=db_path)

    # Insert three motions directly (avoid insert_motion which expects migration-added columns)
    conn = duckdb.connect(db_path)
    motion_ids = []
    for n in (1, 2, 3):
        conn.execute(
            "INSERT INTO motions (title, url) VALUES (?, ?)",
            (f"motion {n}", f"http://example/{n}"),
        )
        fetched = conn.execute(
            "SELECT id FROM motions WHERE url = ?", (f"http://example/{n}",)
        ).fetchone()
        assert fetched is not None
        motion_ids.append(fetched[0])
    conn.close()

    # Store fused embeddings for window 'W1'; vectors chosen so that
    # [1,1,0] is nearer [1,0,0] than [0,1,0] is.
    embedding_plan = [
        (motion_ids[0], [1, 0, 0]),
        (motion_ids[1], [0, 1, 0]),
        (motion_ids[2], [1, 1, 0]),
    ]
    for mid, vec in embedding_plan:
        rid = db.store_fused_embedding(
            motion_id=mid, window_id="W1", vector=vec, svd_dims=1, text_dims=2
        )
        assert rid != -1

    # Compute similarities.
    inserted = compute.compute_similarities(
        vector_type="fused", window_id="W1", top_k=1, db_path=db_path
    )
    # depending on implementation we may insert 2 or 3 rows (or more); allow 2 or 3
    assert inserted in (2, 3)

    # Lookup neighbors for motion 1.
    neighbors = lookup.get_similar_motions(
        motion_id=motion_ids[0],
        vector_type="fused",
        window_id="W1",
        top_k=2,
        db_path=db_path,
    )
    assert len(neighbors) >= 1

    # Motion 3 ([1,1,0]) must rank above motion 2 ([0,1,0]) for motion 1 ([1,0,0]).
    if len(neighbors) < 2:
        # Single neighbor returned: it must be motion 3.
        assert neighbors[0]["motion_id"] == motion_ids[2]
    else:
        assert neighbors[0]["motion_id"] == motion_ids[2]
        assert neighbors[0]["score"] >= neighbors[1]["score"]
||||
@ -0,0 +1,67 @@ |
||||
import json |
||||
from pathlib import Path |
||||
|
||||
from database import MotionDatabase |
||||
|
||||
|
||||
def test_similarity_cache_roundtrip(tmp_path: Path):
    """Round-trip the similarity cache helpers: init, store, ordered read, clear."""
    db_file = tmp_path / "motions.db"
    # Constructing MotionDatabase should initialize the schema.
    db = MotionDatabase(db_path=str(db_file))

    if getattr(db, "_file_mode", False):
        # File-backed fallback: empty JSON stores should exist next to db_file.
        emb_file = Path(str(db_file) + ".embeddings.json")
        sim_file = Path(str(db_file) + ".similarity_cache.json")
        assert emb_file.exists()
        assert sim_file.exists()
        assert json.loads(emb_file.read_text(encoding="utf-8")) == []
        assert json.loads(sim_file.read_text(encoding="utf-8")) == []
    else:
        # DuckDB mode: both tables exist and start empty.
        import duckdb

        conn = duckdb.connect(str(db_file))
        embeddings_count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
        similarity_count = conn.execute(
            "SELECT COUNT(*) FROM similarity_cache"
        ).fetchone()[0]
        conn.close()
        assert embeddings_count == 0
        assert similarity_count == 0

    # Store two similarity rows for source motion 1 with different scores.
    batch = [
        {
            "source_motion_id": 1,
            "target_motion_id": 2,
            "score": 0.5,
            "vector_type": "text",
            "window_id": None,
        },
        {
            "source_motion_id": 1,
            "target_motion_id": 3,
            "score": 0.9,
            "vector_type": "text",
            "window_id": None,
        },
    ]
    db.store_similarity_batch(batch)

    # Read back cached similarities: highest score must come first.
    results = db.get_cached_similarities(source_motion_id=1, vector_type="text")
    assert len(results) == 2
    assert results[0]["target_motion_id"] == 3
    assert abs(float(results[0]["score"]) - 0.9) < 1e-6
    assert results[1]["target_motion_id"] == 2
    assert abs(float(results[1]["score"]) - 0.5) < 1e-6

    # Clearing the cache for this vector_type leaves it empty.
    db.clear_similarity_cache(vector_type="text")
    remaining = db.get_cached_similarities(source_motion_id=1, vector_type="text")
    assert remaining == []
||||
Loading…
Reference in new issue