You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/similarity/lookup.py

101 lines
3.3 KiB

from typing import Optional, List, Dict
import logging
from database import MotionDatabase
_logger = logging.getLogger(__name__)
# Keys under which a dict row may carry the target motion id / its score,
# listed in order of preference.
_TARGET_ID_KEYS = ("target_motion_id", "motion_id", "target")
_SCORE_KEYS = ("score", "similarity", "score_float")


def _first_present(row: Dict, keys):
    """Return the first non-None value stored in *row* under any of *keys*.

    Presence is tested with ``is not None`` rather than truthiness so that
    legitimate falsy values (motion id ``0``, score ``0.0``) are kept.
    """
    for key in keys:
        value = row.get(key)
        if value is not None:
            return value
    return None


def _normalize_rows(rows, top_k: int) -> List[Dict]:
    """Coerce raw rows to ``[{'motion_id': int, 'score': float}, ...]``.

    Rows may be dicts (see ``_TARGET_ID_KEYS`` / ``_SCORE_KEYS``) or
    sequences shaped like ``(target_motion_id, score, ...)``. Malformed rows
    are skipped. The result is sorted by score, highest first, and truncated
    to ``top_k`` entries.
    """
    normalized = []
    for row in rows or ():  # tolerate an accessor that returns None
        if isinstance(row, dict):
            mid = _first_present(row, _TARGET_ID_KEYS)
            score = _first_present(row, _SCORE_KEYS)
        else:
            try:
                mid, score = row[0], row[1]
            except (TypeError, IndexError, KeyError):
                # not subscriptable or too short: skip malformed row
                continue
        try:
            normalized.append({"motion_id": int(mid), "score": float(score)})
        except (TypeError, ValueError, OverflowError):
            # missing (None) or non-numeric values: skip malformed row
            continue
    normalized.sort(key=lambda item: item["score"], reverse=True)
    return normalized[:top_k]


def get_similar_motions(
    motion_id: int,
    vector_type: str = "fused",
    window_id: Optional[str] = None,
    top_k: int = 10,
    db_path: Optional[str] = None,
) -> List[Dict]:
    """Return motions similar to *motion_id*, best match first.

    Prefers ``MotionDatabase.get_cached_similarities`` when the database
    object provides it; otherwise falls back to a direct SQL query against
    the ``similarity_cache`` table using duckdb, which is imported lazily.

    Args:
        motion_id: Id of the source motion to look up.
        vector_type: Value matched against the ``vector_type`` column /
            passed to the cached accessor.
        window_id: Window to restrict to. ``None`` selects rows whose
            ``window_id`` is NULL in the SQL fallback.
        top_k: Maximum number of results to return.
        db_path: Optional database path; when falsy, ``MotionDatabase`` is
            constructed with its own default.

    Returns:
        A list of ``{"motion_id": int, "score": float}`` dicts sorted by
        descending score. Empty if duckdb is needed but cannot be imported.
    """
    db = MotionDatabase(db_path=db_path) if db_path else MotionDatabase()

    # Prefer the cached accessor if the database object offers one.
    cached_lookup = getattr(db, "get_cached_similarities", None)
    if cached_lookup is not None:
        try:
            rows = cached_lookup(
                source_motion_id=motion_id,
                vector_type=vector_type,
                window_id=window_id,
                top_k=top_k,
            )
        except TypeError:
            # Fallback if the accessor's signature differs (e.g. other
            # keyword names): retry positionally.
            # NOTE(review): a TypeError raised *inside* the accessor also
            # lands here and triggers a second call.
            rows = cached_lookup(motion_id, vector_type, window_id, top_k)
        return _normalize_rows(rows, top_k)

    # Fallback: query duckdb directly (imported here so the module loads
    # without duckdb installed).
    try:
        import duckdb
    except Exception:
        _logger.error(
            "duckdb not available and MotionDatabase lacks get_cached_similarities"
        )
        return []

    # Only the window predicate differs between the two query shapes; it is
    # chosen from fixed strings, all user values are bound as parameters.
    if window_id is None:
        window_clause = "window_id IS NULL"
        params = (motion_id, vector_type, top_k)
    else:
        window_clause = "window_id = ?"
        params = (motion_id, vector_type, window_id, top_k)
    query = (
        "SELECT target_motion_id, score FROM similarity_cache "
        f"WHERE source_motion_id = ? AND vector_type = ? AND {window_clause} "
        "ORDER BY score DESC LIMIT ?"
    )

    conn = duckdb.connect(db.db_path)
    try:
        rows = conn.execute(query, params).fetchall()
    finally:
        try:
            conn.close()
        except Exception:
            # Closing is best-effort; never mask the query result or error.
            _logger.debug("failed to close duckdb connection", exc_info=True)

    return _normalize_rows(rows, top_k)