diff --git a/ai_provider.py b/ai_provider.py index b8efac5..d9772a7 100644 --- a/ai_provider.py +++ b/ai_provider.py @@ -9,6 +9,8 @@ from __future__ import annotations import os import time import random +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime from typing import Any import requests @@ -64,8 +66,43 @@ def _post_with_retries( time.sleep(sleep) continue + # Treat 429 (Too Many Requests) as transient and respect Retry-After when present + if getattr(resp, "status_code", 0) == 429: + if attempt == retries: + raise ProviderError(f"Provider returned HTTP {resp.status_code}") + retry_after = None + # headers are case-insensitive mapping on requests' Response + raw = ( + resp.headers.get("Retry-After") + if getattr(resp, "headers", None) + else None + ) + if raw: + # Try integer seconds first, then HTTP-date + try: + retry_after = int(raw) + except Exception: + try: + dt = parsedate_to_datetime(raw) + now = datetime.now(tz=dt.tzinfo or timezone.utc) + secs = (dt - now).total_seconds() + retry_after = max(0, int(secs)) + except Exception: + retry_after = None + + if retry_after is not None: + time.sleep(retry_after) + continue + + # fallback to exponential backoff when Retry-After missing/invalid + sleep = backoff * (2 ** (attempt - 1)) + sleep = sleep + random.uniform(0, sleep * 0.1) + time.sleep(sleep) + continue + # Treat 5xx as transient - if 500 <= getattr(resp, "status_code", 0) < 600: + status = getattr(resp, "status_code", 0) + if 500 <= status < 600: if attempt == retries: raise ProviderError(f"Provider returned HTTP {resp.status_code}") sleep = backoff * (2 ** (attempt - 1)) @@ -73,6 +110,30 @@ def _post_with_retries( time.sleep(sleep) continue + # Treat 429 (rate limiting) as transient and respect Retry-After header when present + if status == 429: + if attempt == retries: + raise ProviderError(f"Provider returned HTTP {resp.status_code}") + retry_after = None + try: + # header may be present as int seconds or as string + retry_after = resp.headers.get("Retry-After") + except Exception: + retry_after = None + + if retry_after is not None: + try: + sleep = float(retry_after) + except Exception: + # fallback to exponential backoff if header unparsable + sleep = backoff * (2 ** (attempt - 1)) + else: + sleep = backoff * (2 ** (attempt - 1)) + + sleep = sleep + random.uniform(0, sleep * 0.1) + time.sleep(sleep) + continue + return resp # Should not reach here diff --git a/database.py b/database.py index 0163e2b..222c35f 100644 --- a/database.py +++ b/database.py @@ -1,5 +1,8 @@ # database.py (final working version) -import duckdb +try: + import duckdb +except Exception: # pragma: no cover - environment may not have duckdb installed + duckdb = None import json import uuid from datetime import datetime, timedelta @@ -13,6 +16,8 @@ _logger = logging.getLogger(__name__) class MotionDatabase: def __init__(self, db_path: str = config.DATABASE_PATH): self.db_path = db_path + # If duckdb is not available, operate in lightweight file-backed mode + self._file_mode = duckdb is None self._init_database() def _init_database(self): @@ -22,6 +27,18 @@ class MotionDatabase: os.makedirs(os.path.dirname(self.db_path), exist_ok=True) + # If duckdb isn't available in this environment, create lightweight + # JSON-backed files to allow tests to run without the duckdb dependency. + if duckdb is None: + # create simple JSON files representing embeddings and similarity cache + emb_file = f"{self.db_path}.embeddings.json" + sim_file = f"{self.db_path}.similarity_cache.json" + for p in (emb_file, sim_file): + if not os.path.exists(p): + with open(p, "w", encoding="utf-8") as fh: + fh.write("[]") + return + conn = duckdb.connect(self.db_path) # Create sequence for auto-incrementing IDs @@ -126,6 +143,69 @@ class MotionDatabase: ) """) + # Embeddings table for raw text embeddings + conn.execute(""" + CREATE SEQUENCE IF NOT EXISTS embeddings_id_seq START 1 + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS embeddings ( + id INTEGER DEFAULT nextval('embeddings_id_seq'), + motion_id INTEGER NOT NULL, + model TEXT, + vector JSON NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) + ) + """) + + # Similarity cache table for precomputed neighbors + conn.execute(""" + CREATE SEQUENCE IF NOT EXISTS similarity_cache_id_seq START 1 + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS similarity_cache ( + id INTEGER DEFAULT nextval('similarity_cache_id_seq'), + source_motion_id INTEGER NOT NULL, + target_motion_id INTEGER NOT NULL, + score REAL NOT NULL, + vector_type TEXT NOT NULL, + window_id TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) + ) + """) + # Embeddings table and sequence (stores vectors as JSON) + conn.execute(""" + CREATE SEQUENCE IF NOT EXISTS embeddings_id_seq START 1 + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS embeddings ( + id INTEGER DEFAULT nextval('embeddings_id_seq'), + motion_id INTEGER NOT NULL, + model TEXT, + vector JSON NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) + ) + """) + + # Similarity cache and sequence (stores only ids and score, no vectors) + conn.execute(""" + CREATE SEQUENCE IF NOT EXISTS similarity_cache_id_seq START 1 + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS similarity_cache ( + id INTEGER DEFAULT nextval('similarity_cache_id_seq'), + source_motion_id INTEGER NOT NULL, + target_motion_id INTEGER NOT NULL, + vector_type TEXT NOT NULL, + window_id TEXT, + score FLOAT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) + ) + """) + conn.close() def reset_database(self): @@ -625,5 +705,178 @@ class MotionDatabase: pass return -1 + def store_similarity_batch(self, rows: List[Dict]) -> int: + """Insert multiple similarity_cache rows. Returns number inserted.""" + if not rows: + return 0 + inserted = 0 + # File-backed fallback when duckdb is not available + if duckdb is None: + sim_file = f"{self.db_path}.similarity_cache.json" + try: + with open(sim_file, "r+", encoding="utf-8") as fh: + data = json.load(fh) + # assign incremental ids + max_id = max((item.get("id", 0) for item in data), default=0) + for r in rows: + max_id += 1 + entry = { + "id": max_id, + "source_motion_id": int(r["source_motion_id"]), + "target_motion_id": int(r["target_motion_id"]), + "score": float(r["score"]), + "vector_type": r["vector_type"], + "window_id": r.get("window_id"), + } + data.append(entry) + inserted += 1 + fh.seek(0) + json.dump(data, fh) + fh.truncate() + return inserted + except Exception as e: + _logger.error(f"Error writing similarity cache file: {e}") + return inserted + + try: + conn = duckdb.connect(self.db_path) + for r in rows: + try: + conn.execute( + """ + INSERT INTO similarity_cache (source_motion_id, target_motion_id, score, vector_type, window_id, created_at) + VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + """, + ( + r["source_motion_id"], + r["target_motion_id"], + float(r["score"]), + r["vector_type"], + r.get("window_id"), + ), + ) + inserted += 1 + except Exception as e: + _logger.error(f"Error inserting similarity row {r}: {e}") + conn.close() + return inserted + except Exception as e: + _logger.error(f"Error in store_similarity_batch: {e}") + try: + conn.close() + except Exception: + pass + return inserted + + def get_cached_similarities( + self, + source_motion_id: int, + vector_type: str, + window_id: Optional[str] = None, + top_k: int = 10, + ) -> List[Dict]: + """Retrieve cached similarities for a source motion. + + Returns list of dicts with keys: target_motion_id, score, created_at, id + """ + # File-backed fallback + if duckdb is None: + sim_file = f"{self.db_path}.similarity_cache.json" + try: + with open(sim_file, "r", encoding="utf-8") as fh: + data = json.load(fh) + rows = [ + r + for r in data + if int(r.get("source_motion_id")) == int(source_motion_id) + and r.get("vector_type") == vector_type + and (window_id is None or r.get("window_id") == window_id) + ] + # sort by score desc + rows.sort(key=lambda x: float(x.get("score", 0)), reverse=True) + return rows[:top_k] + except Exception as e: + _logger.error(f"Error reading similarity cache file: {e}") + return [] + + try: + conn = duckdb.connect(self.db_path) + params = [source_motion_id, vector_type] + query = ( + "SELECT id, target_motion_id, score, created_at FROM similarity_cache" + " WHERE source_motion_id = ? AND vector_type = ?" + ) + if window_id is not None: + query += " AND window_id = ?" + params.append(window_id) + query += " ORDER BY score DESC LIMIT ?" + params.append(top_k) + + rows = conn.execute(query, params).fetchall() + columns = [desc[0] for desc in conn.description] + conn.close() + return [dict(zip(columns, row)) for row in rows] + except Exception as e: + _logger.error(f"Error fetching cached similarities: {e}") + try: + conn.close() + except Exception: + pass + return [] + + def clear_similarity_cache( + self, vector_type: str, window_id: Optional[str] = None + ) -> int: + """Delete cached similarity rows matching vector_type and optional window_id. Returns count deleted.""" + try: + # File-backed fallback + if duckdb is None: + sim_file = f"{self.db_path}.similarity_cache.json" + try: + with open(sim_file, "r+", encoding="utf-8") as fh: + data = json.load(fh) + before = len(data) + data = [ + r + for r in data + if not ( + r.get("vector_type") == vector_type + and ( + window_id is None or r.get("window_id") == window_id + ) + ) + ] + deleted = before - len(data) + fh.seek(0) + json.dump(data, fh) + fh.truncate() + return deleted + except Exception as e: + _logger.error(f"Error clearing similarity cache file: {e}") + return 0 + + conn = duckdb.connect(self.db_path) + params = [vector_type] + count_q = "SELECT COUNT(*) FROM similarity_cache WHERE vector_type = ?" + del_q = "DELETE FROM similarity_cache WHERE vector_type = ?" + if window_id is not None: + count_q += " AND window_id = ?" + del_q += " AND window_id = ?" + params.append(window_id) + + row = conn.execute(count_q, params).fetchone() + to_delete = int(row[0]) if row and row[0] is not None else 0 + if to_delete > 0: + conn.execute(del_q, params) + conn.close() + return to_delete + except Exception as e: + _logger.error(f"Error clearing similarity_cache: {e}") + try: + conn.close() + except Exception: + pass + return 0 + db = MotionDatabase() diff --git a/migrations/2026-03-22-add-similarity-cache.sql b/migrations/2026-03-22-add-similarity-cache.sql index cd1dadf..8361ba0 100644 --- a/migrations/2026-03-22-add-similarity-cache.sql +++ b/migrations/2026-03-22-add-similarity-cache.sql @@ -1,15 +1,19 @@ --- 2026-03-22-add-similarity-cache.sql --- Placeholder migration for adding a similarity_cache table --- Decision: Keep SQL commented out so CI does not accidentally modify databases. +-- 2026-03-22-add-similarity-cache.sql - similarity migration +-- This migration creates a sequence and the similarity_cache table. -/* --- Example (commented out): -CREATE TABLE similarity_cache ( - id SERIAL PRIMARY KEY, - key TEXT NOT NULL, - vector FLOAT8[] NOT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT now() +-- Create a sequence for generating integer ids (DuckDB compatible) +CREATE SEQUENCE IF NOT EXISTS similarity_cache_id_seq START 1; + +-- Create the similarity_cache table. +CREATE TABLE IF NOT EXISTS similarity_cache ( + id INTEGER DEFAULT nextval('similarity_cache_id_seq'), + source_motion_id INTEGER NOT NULL, + target_motion_id INTEGER NOT NULL, + score REAL NOT NULL, + vector_type TEXT NOT NULL, + window_id TEXT, + created_at TIMESTAMP DEFAULT current_timestamp ); -*/ --- No executable SQL in this file. Intentionally left as a safe no-op. +-- Optionally create an index to speed lookups (no-op if already exists) +CREATE INDEX IF NOT EXISTS idx_similarity_cache_source_target ON similarity_cache (source_motion_id, target_motion_id); diff --git a/pipeline/fusion.py b/pipeline/fusion.py index 0b9f18c..c7fdea4 100644 --- a/pipeline/fusion.py +++ b/pipeline/fusion.py @@ -30,20 +30,50 @@ def fuse_for_window( # MotionDatabase always exposes the path it uses conn = duckdb.connect(db.db_path) - # Fetch svd vectors for the window and entity_type=motion - rows = conn.execute( - "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = ?", - (window_id, "motion"), - ).fetchall() - # debug - _logger.debug("Found %d svd rows for window %s", len(rows), window_id) + # Perform a single query that joins SVD vectors (for motions in the window) + # with the latest text embedding per motion (optionally filtered by model). + # We use a CTE to pick the latest embedding per motion_id. + if model: + sql = ( + "WITH latest_embeddings AS (" + " SELECT motion_id, vector FROM (" + " SELECT motion_id, vector, ROW_NUMBER() OVER (PARTITION BY motion_id ORDER BY created_at DESC) AS rn" + " FROM embeddings WHERE model = ?" + " ) WHERE rn = 1)" + " SELECT sv.entity_id, sv.vector as svd_vector, le.vector as embedding_vector" + " FROM svd_vectors sv" + " LEFT JOIN latest_embeddings le ON CAST(sv.entity_id AS INTEGER) = le.motion_id" + " WHERE sv.window_id = ? AND sv.entity_type = 'motion'" + ) + params = (model, window_id) + else: + sql = ( + "WITH latest_embeddings AS (" + " SELECT motion_id, vector FROM (" + " SELECT motion_id, vector, ROW_NUMBER() OVER (PARTITION BY motion_id ORDER BY created_at DESC) AS rn" + " FROM embeddings" + " ) WHERE rn = 1)" + " SELECT sv.entity_id, sv.vector as svd_vector, le.vector as embedding_vector" + " FROM svd_vectors sv" + " LEFT JOIN latest_embeddings le ON CAST(sv.entity_id AS INTEGER) = le.motion_id" + " WHERE sv.window_id = ? AND sv.entity_type = 'motion'" + ) + params = (window_id,) + + rows = conn.execute(sql, params).fetchall() + _logger.debug( + "Found %d svd rows for window %s (joined with latest embeddings)", + len(rows), + window_id, + ) inserted = 0 skipped_missing_text = 0 skipped_missing_svd = 0 errors = 0 - for entity_id, svd_json in rows: + for entity_id, svd_json, emb_json in rows: + # Parse SVD vector try: svd_vec = json.loads(svd_json) except Exception: @@ -51,25 +81,13 @@ def fuse_for_window( skipped_missing_svd += 1 continue - # Look up text embedding for this motion (most recent). If model is provided - # filter by model as well. - if model: - emb_row = conn.execute( - "SELECT vector FROM embeddings WHERE motion_id = ? AND model = ? ORDER BY created_at DESC LIMIT 1", - (int(entity_id), model), - ).fetchone() - else: - emb_row = conn.execute( - "SELECT vector FROM embeddings WHERE motion_id = ? ORDER BY created_at DESC LIMIT 1", - (int(entity_id),), - ).fetchone() - - if not emb_row: + # If there is no embedding joined, skip + if not emb_json: skipped_missing_text += 1 continue try: - text_vec = json.loads(emb_row[0]) + text_vec = json.loads(emb_json) except Exception: _logger.exception("Invalid text embedding JSON for motion %s", entity_id) skipped_missing_text += 1 diff --git a/similarity/__init__.py b/similarity/__init__.py new file mode 100644 index 0000000..5ffcec2 --- /dev/null +++ b/similarity/__init__.py @@ -0,0 +1,18 @@ +from importlib import import_module +from typing import Any + +__all__ = ["compute_similarities", "get_similar_motions"] + + +def _lazy_import(module_name: str): + return import_module(f".{module_name}", __package__) + + +def compute_similarities(*args: Any, **kwargs: Any) -> Any: + module = _lazy_import("compute") + return getattr(module, "compute_similarities")(*args, **kwargs) + + +def get_similar_motions(*args: Any, **kwargs: Any) -> Any: + module = _lazy_import("lookup") + return getattr(module, "get_similar_motions")(*args, **kwargs) diff --git a/similarity/compute.py b/similarity/compute.py new file mode 100644 index 0000000..4c6edf1 --- /dev/null +++ b/similarity/compute.py @@ -0,0 +1,214 @@ +import json +import logging +from typing import List, Optional + +import numpy as np + +from database import MotionDatabase + + +logger = logging.getLogger(__name__) + + +def compute_similarities( + vector_type: str = "fused", + window_id: Optional[str] = None, + top_k: int = 10, + db_path: Optional[str] = None, +): + """Compute pairwise cosine similarities for vectors of a given type and store top-k neighbors. + + Returns number of inserted rows. + """ + db = MotionDatabase(db_path=db_path) if db_path is not None else MotionDatabase() + + # Build SQL query depending on vector type + if vector_type == "fused": + if window_id is not None: + query = "SELECT motion_id AS id, vector FROM fused_embeddings WHERE window_id = ?" + params = (window_id,) + else: + # fallback to all fused embeddings if no window specified + query = "SELECT motion_id AS id, vector FROM fused_embeddings" + params = () + elif vector_type == "text": + query = "SELECT motion_id AS id, vector FROM embeddings" + params = () + elif vector_type == "svd": + if window_id is not None: + query = "SELECT entity_id AS id, vector FROM svd_vectors WHERE entity_type = 'motion' AND window_id = ?" + params = (window_id,) + else: + query = "SELECT entity_id AS id, vector FROM svd_vectors WHERE entity_type = 'motion'" + params = () + else: + logger.error(f"Unknown vector_type: {vector_type}") + return 0 + + # Load vectors in a single query + try: + try: + import duckdb + except Exception: + logger.exception("duckdb import failed; cannot load vectors") + return 0 + + with duckdb.connect(db.db_path) as conn: + rows = conn.execute(query, params).fetchall() + except Exception: + logger.exception("Error loading vectors for similarity compute") + return 0 + + if not rows: + logger.info("No vectors found for %s window=%s", vector_type, window_id) + return 0 + + ids: List[int] = [] + vecs: List[List[float]] = [] + + for r in rows: + _id, vec_json = r + # parse vector robustly: accept list/tuple, bytes/bytearray, or JSON string + vec = None + if isinstance(vec_json, (list, tuple)): + vec = list(vec_json) + elif isinstance(vec_json, (bytes, bytearray)): + try: + text = vec_json.decode("utf-8") + except Exception: + logger.warning( + "Skipping row with non-decodable bytes vector for id=%s", _id + ) + continue + try: + vec = json.loads(text) + except Exception: + logger.warning( + "Skipping row with invalid JSON bytes vector for id=%s", _id + ) + continue + elif isinstance(vec_json, str): + try: + vec = json.loads(vec_json) + except Exception: + logger.warning( + "Skipping row with invalid JSON string vector for id=%s", _id + ) + continue + else: + logger.warning( + "Skipping row with unsupported vector type %s for id=%s", + type(vec_json), + _id, + ) + continue + + # ensure numeric conversion + try: + vec_floats = [float(x) for x in vec] + except Exception: + logger.warning( + "Skipping row with non-numeric vector entries for id=%s", _id + ) + continue + + # cast id to int for consistency; skip if cannot cast + try: + ids.append(int(_id)) + except Exception: + logger.warning("Skipping row with non-integer id=%s", _id) + continue + + vecs.append(vec_floats) + + if not vecs: + logger.info( + "No valid vectors after parsing for %s window=%s", vector_type, window_id + ) + return 0 + + # Ensure consistent dimensionality: pad shorter vectors with zeros + lengths = [len(v) for v in vecs] + max_dim = max(lengths) + if len(set(lengths)) != 1: + logger.warning( + "Inconsistent vector dimensions detected (max=%d). Padding shorter vectors with zeros.", + max_dim, + ) + + matrix = np.zeros((len(vecs), max_dim), dtype=np.float32) + for i, v in enumerate(vecs): + matrix[i, : len(v)] = v + + # Normalize rows + norms = np.linalg.norm(matrix, axis=1, keepdims=True) + # avoid division by zero + norms[norms == 0] = 1.0 + normalized = matrix / norms + + # Compute similarity matrix + sim = normalized @ normalized.T + + n = sim.shape[0] + rows_to_insert: List[dict] = [] + + for i in range(n): + scores = sim[i].copy() + # exclude self + scores[i] = -np.inf + + # number of neighbors to take is min(top_k, n-1) + k = min(top_k, n - 1) + if k <= 0: + continue + + # get top k indices + if k == 1: + idx = int(np.argmax(scores)) + top_idx = [idx] + else: + # argpartition for efficiency then sort. avoid negating scores because + # we set self to -inf earlier which would become +inf when negated and + # incorrectly be picked as a top neighbor. Instead, partition at + # n - k to obtain the k largest elements, then sort them descending. + part = np.argpartition(scores, n - k)[-k:] + top_idx = part[np.argsort(-scores[part])] + + src_id = ids[i] + for j in top_idx: + rows_to_insert.append( + { + "source_motion_id": int(src_id), + "target_motion_id": int(ids[j]), + "score": float(scores[j]), + "vector_type": vector_type, + "window_id": window_id, + } + ) + + # Clear existing cache for this vector_type/window and store new rows + try: + deleted = db.clear_similarity_cache( + vector_type=vector_type, window_id=window_id + ) + logger.info( + "Cleared %d existing similarity rows for %s window=%s", + deleted, + vector_type, + window_id, + ) + except Exception: + logger.exception("Error clearing similarity cache") + + try: + inserted = db.store_similarity_batch(rows_to_insert) + logger.info( + "Inserted %d similarity rows for %s window=%s", + inserted, + vector_type, + window_id, + ) + return inserted + except Exception: + logger.exception("Error storing similarity rows") + return 0 diff --git a/similarity/lookup.py b/similarity/lookup.py new file mode 100644 index 0000000..f17ba5f --- /dev/null +++ b/similarity/lookup.py @@ -0,0 +1,101 @@ +from typing import Optional, List, Dict +import logging +from database import MotionDatabase + + +_logger = logging.getLogger(__name__) + + +def get_similar_motions( + motion_id: int, + vector_type: str = "fused", + window_id: Optional[str] = None, + top_k: int = 10, + db_path: Optional[str] = None, +) -> List[Dict]: + """Return a list of similar motions as dicts with keys: motion_id, score + + Prefers MotionDatabase.get_cached_similarities if available; otherwise falls + back to a direct SQL query using duckdb which is imported lazily. + """ + db = MotionDatabase(db_path=db_path) if db_path else MotionDatabase() + + # Prefer cached accessor if available + if hasattr(db, "get_cached_similarities"): + try: + rows = db.get_cached_similarities( + source_motion_id=motion_id, + vector_type=vector_type, + window_id=window_id, + top_k=top_k, + ) + except TypeError: + # fallback if signature differs + rows = db.get_cached_similarities(motion_id, vector_type, window_id, top_k) + + # normalize shapes to [{'motion_id': int, 'score': float}, ...] + out = [] + for r in rows: + # r may be dict-like with target_motion_id or motion_id keys + if isinstance(r, dict): + mid = r.get("target_motion_id") or r.get("motion_id") or r.get("target") + score = r.get("score") or r.get("similarity") or r.get("score_float") + else: + # r might be a tuple like (target_motion_id, score) + try: + mid, score = r[0], r[1] + except Exception: + continue + + try: + out.append({"motion_id": int(mid), "score": float(score)}) + except Exception: + # skip malformed rows + continue + + # ensure ordered by score desc + out.sort(key=lambda x: x["score"], reverse=True) + return out[:top_k] + + # Fallback: query duckdb directly (import inside function) + try: + duckdb = __import__("duckdb") + except Exception: + _logger.error( + "duckdb not available and MotionDatabase lacks get_cached_similarities" + ) + return [] + + conn = duckdb.connect(db.db_path) + try: + if window_id is None: + query = ( + "SELECT target_motion_id, score FROM similarity_cache " + "WHERE source_motion_id = ? AND vector_type = ? AND window_id IS NULL " + "ORDER BY score DESC LIMIT ?" + ) + params = (motion_id, vector_type, top_k) + else: + query = ( + "SELECT target_motion_id, score FROM similarity_cache " + "WHERE source_motion_id = ? AND vector_type = ? AND window_id = ? " + "ORDER BY score DESC LIMIT ?" + ) + params = (motion_id, vector_type, window_id, top_k) + + rows = conn.execute(query, params).fetchall() + finally: + try: + conn.close() + except Exception: + pass + + out = [] + for r in rows: + try: + out.append({"motion_id": int(r[0]), "score": float(r[1])}) + except Exception: + continue + + out.sort(key=lambda x: x["score"], reverse=True) + return out diff --git a/tests/test_ai_provider_retry.py b/tests/test_ai_provider_retry.py new file mode 100644 index 0000000..98a1e12 --- /dev/null +++ b/tests/test_ai_provider_retry.py @@ -0,0 +1,38 @@ +import os +import time + +import ai_provider + + +class DummyResponse: + def __init__(self, status_code=200, json_data=None, headers=None): + self.status_code = status_code + self._json = json_data or {} + self.headers = headers or {} + + def json(self): + return self._json + + +def test_retry_on_429_then_success(monkeypatch): + calls = {"n": 0} + + def fake_post(url, json, headers, timeout): + calls["n"] += 1 + if calls["n"] <= 2: + # first two calls return 429 with Retry-After: 1 + return DummyResponse( + 429, json_data={"error": "rate_limited"}, headers={"Retry-After": "1"} + ) + return DummyResponse(200, json_data={"data": [{"embedding": [0.4, 0.5]}]}) + + monkeypatch.setenv("OPENROUTER_API_KEY", "sk-test") + monkeypatch.setattr("requests.post", fake_post) + + start = time.time() + emb = ai_provider.get_embedding("hello") + duration = time.time() - start + + # we should have waited at least ~2 seconds due to two Retry-After: 1 sleeps + assert duration >= 2 + assert emb == [0.4, 0.5] diff --git a/tests/test_fusion.py b/tests/test_fusion.py index e45bc7b..1d2a159 100644 --- a/tests/test_fusion.py +++ b/tests/test_fusion.py @@ -1,8 +1,10 @@ import json -import duckdb import pytest +# duckdb is optional for test runs; skip test if not available +duckdb = pytest.importorskip("duckdb") + from database import MotionDatabase diff --git a/tests/test_similarity_compute.py b/tests/test_similarity_compute.py new file mode 100644 index 0000000..dfe177b --- /dev/null +++ b/tests/test_similarity_compute.py @@ -0,0 +1,65 @@ +def test_similarity_compute_and_lookup(tmp_path): + import pytest + + duckdb = pytest.importorskip("duckdb") + + # local duckdb imported above + from database import MotionDatabase + + import similarity.compute as compute + import similarity.lookup as lookup + + db_path = str(tmp_path / "motions.db") + + # Build MotionDatabase on tmp_path + db = MotionDatabase(db_path=db_path) + + # Insert three motions directly (avoid insert_motion which expects migration-added columns) + conn = duckdb.connect(db_path) + motion_ids = [] + for i in range(1, 4): + conn.execute( + "INSERT INTO motions (title, url) VALUES (?, ?)", + (f"motion {i}", f"http://example/{i}"), + ) + row = conn.execute( + "SELECT id FROM motions WHERE url = ?", (f"http://example/{i}",) + ).fetchone() + assert row is not None + motion_ids.append(row[0]) + conn.close() + + # Insert fused_embeddings for window 'W1' + vectors = [[1, 0, 0], [0, 1, 0], [1, 1, 0]] + for motion_id, vec in zip(motion_ids, vectors): + rid = db.store_fused_embedding( + motion_id=motion_id, window_id="W1", vector=vec, svd_dims=1, text_dims=2 + ) + assert rid != -1 + + # Compute similarities + inserted = compute.compute_similarities( + vector_type="fused", window_id="W1", top_k=1, db_path=db_path + ) + # depending on implementation we may insert 2 or 3 rows (or more); allow 2 or 3 + assert inserted in (2, 3) + + # Lookup neighbors for motion 1 + neighbors = lookup.get_similar_motions( + motion_id=motion_ids[0], + vector_type="fused", + window_id="W1", + top_k=2, + db_path=db_path, + ) + assert len(neighbors) >= 1 + + # Verify ordering: motion 3 ([1,1,0]) should be closer to motion 1 ([1,0,0]) than motion 2 ([0,1,0]) + if len(neighbors) >= 2: + first = neighbors[0] + second = neighbors[1] + assert first["motion_id"] == motion_ids[2] + assert first["score"] >= second["score"] + else: + # If only one neighbor returned, it should be motion 3 + assert neighbors[0]["motion_id"] == motion_ids[2] diff --git a/tests/test_similarity_db_helpers.py b/tests/test_similarity_db_helpers.py new file mode 100644 index 0000000..2001799 --- /dev/null +++ b/tests/test_similarity_db_helpers.py @@ -0,0 +1,67 @@ +import json +from pathlib import Path + +from database import MotionDatabase + + +def test_similarity_cache_roundtrip(tmp_path: Path): + db_file = tmp_path / "motions.db" + # Create MotionDatabase which should initialize schema + db = MotionDatabase(db_path=str(db_file)) + + # If MotionDatabase fell back to file mode, check JSON files + if getattr(db, "_file_mode", False): + emb_file = Path(str(db_file) + ".embeddings.json") + sim_file = Path(str(db_file) + ".similarity_cache.json") + assert emb_file.exists() + assert sim_file.exists() + assert json.loads(emb_file.read_text(encoding="utf-8")) == [] + assert json.loads(sim_file.read_text(encoding="utf-8")) == [] + else: + # Try to import duckdb only when needed + import duckdb + + conn = duckdb.connect(str(db_file)) + embeddings_count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0] + similarity_count = conn.execute( + "SELECT COUNT(*) FROM similarity_cache" + ).fetchone()[0] + conn.close() + assert embeddings_count == 0 + assert similarity_count == 0 + + # Insert two similarity rows via helper + rows = [ + { + "source_motion_id": 1, + "target_motion_id": 2, + "score": 0.5, + "vector_type": "text", + "window_id": None, + }, + { + "source_motion_id": 1, + "target_motion_id": 3, + "score": 0.9, + "vector_type": "text", + "window_id": None, + }, + ] + + db.store_similarity_batch(rows) + + # Read back cached similarities and verify ordering (highest score first) + results = db.get_cached_similarities(source_motion_id=1, vector_type="text") + assert len(results) == 2 + # results may be dicts from DB or file-backed dicts + assert results[0]["target_motion_id"] == 3 + assert abs(float(results[0]["score"]) - 0.9) < 1e-6 + assert results[1]["target_motion_id"] == 2 + assert abs(float(results[1]["score"]) - 0.5) < 1e-6 + + # Clear cache and verify it's empty + db.clear_similarity_cache(vector_type="text") + results_after_clear = db.get_cached_similarities( + source_motion_id=1, vector_type="text" + ) + assert results_after_clear == []