from typing import Optional, List, Dict
import logging

from database import MotionDatabase

_logger = logging.getLogger(__name__)

# Keys probed, in priority order, when a cached similarity row is a dict;
# the accessor may name its columns in any of these ways.
_ID_KEYS = ("target_motion_id", "motion_id", "target")
_SCORE_KEYS = ("score", "similarity", "score_float")


def _first_not_none(row: dict, keys) -> object:
    """Return the value of the first key in *keys* whose value is not None.

    Unlike chaining ``row.get(a) or row.get(b)``, a legitimate falsy value
    such as motion id ``0`` or score ``0.0`` is returned instead of skipped.
    """
    for key in keys:
        value = row.get(key)
        if value is not None:
            return value
    return None


def _to_entry(row: object) -> Optional[Dict]:
    """Convert one result row to ``{"motion_id": int, "score": float}``.

    Accepts a dict (looked up via ``_ID_KEYS`` / ``_SCORE_KEYS``) or an
    indexable row whose first two items are ``(target_motion_id, score)``.
    Returns None for rows that cannot be converted.
    """
    if isinstance(row, dict):
        mid = _first_not_none(row, _ID_KEYS)
        score = _first_not_none(row, _SCORE_KEYS)
    else:
        try:
            mid, score = row[0], row[1]
        except (TypeError, IndexError, KeyError):
            return None
    try:
        return {"motion_id": int(mid), "score": float(score)}
    except (TypeError, ValueError, OverflowError):
        # Malformed row (missing/None/non-numeric id or score): skip it.
        return None


def _normalize_rows(rows) -> List[Dict]:
    """Convert *rows* to result dicts, drop malformed ones, sort by score desc."""
    if rows is None:
        return []
    entries = [entry for entry in map(_to_entry, rows) if entry is not None]
    entries.sort(key=lambda e: e["score"], reverse=True)
    return entries


def get_similar_motions(
    motion_id: int,
    vector_type: str = "fused",
    window_id: Optional[str] = None,
    top_k: int = 10,
    db_path: Optional[str] = None,
) -> List[Dict]:
    """Return up to *top_k* motions similar to *motion_id*.

    Each result is a dict ``{"motion_id": int, "score": float}``; results are
    sorted by score, highest first, and rows that cannot be converted are
    skipped.

    Prefers ``MotionDatabase.get_cached_similarities`` if the database object
    provides it; otherwise falls back to a direct query of the
    ``similarity_cache`` table using duckdb, which is imported lazily.

    Args:
        motion_id: Source motion to look up neighbours for.
        vector_type: Vector type to match.
        window_id: Window to match. In the direct-SQL fallback, ``None``
            matches rows whose ``window_id`` is NULL.
        top_k: Maximum number of results to return.
        db_path: Optional path passed to ``MotionDatabase``; when falsy the
            default ``MotionDatabase()`` is used.

    Returns:
        List of result dicts. Empty if the fallback is needed but duckdb
        cannot be imported (an error is logged).
    """
    db = MotionDatabase(db_path=db_path) if db_path else MotionDatabase()

    # Prefer the cached accessor if available.
    if hasattr(db, "get_cached_similarities"):
        try:
            rows = db.get_cached_similarities(
                source_motion_id=motion_id,
                vector_type=vector_type,
                window_id=window_id,
                top_k=top_k,
            )
        except TypeError:
            # Fallback if the accessor's keyword names differ: pass positionally.
            rows = db.get_cached_similarities(motion_id, vector_type, window_id, top_k)
        # The accessor's ordering and limit are not relied on; enforce both.
        return _normalize_rows(rows)[:top_k]

    # Fallback: query duckdb directly (imported lazily so it stays optional).
    try:
        import duckdb
    except ImportError:
        _logger.error(
            "duckdb not available and MotionDatabase lacks get_cached_similarities"
        )
        return []

    if window_id is None:
        # `window_id = NULL` never matches in SQL, so NULL needs IS NULL.
        query = (
            "SELECT target_motion_id, score FROM similarity_cache "
            "WHERE source_motion_id = ? AND vector_type = ? AND window_id IS NULL "
            "ORDER BY score DESC LIMIT ?"
        )
        params = (motion_id, vector_type, top_k)
    else:
        query = (
            "SELECT target_motion_id, score FROM similarity_cache "
            "WHERE source_motion_id = ? AND vector_type = ? AND window_id = ? "
            "ORDER BY score DESC LIMIT ?"
        )
        params = (motion_id, vector_type, window_id, top_k)

    conn = duckdb.connect(db.db_path)
    try:
        rows = conn.execute(query, params).fetchall()
    finally:
        try:
            conn.close()
        except Exception:
            # Best-effort close: never mask the query's own result or error.
            _logger.debug("failed to close duckdb connection", exc_info=True)

    return _normalize_rows(rows)