You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
3.3 KiB
101 lines
3.3 KiB
from typing import Optional, List, Dict
|
|
import logging
|
|
from database import MotionDatabase
|
|
|
|
|
|
# Module-scoped logger, named after this module so its records can be
# filtered/configured independently of other modules.
_logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def get_similar_motions(
    motion_id: int,
    vector_type: str = "fused",
    window_id: Optional[str] = None,
    top_k: int = 10,
    db_path: Optional[str] = None,
) -> List[Dict]:
    """Return motions similar to ``motion_id`` as ``{"motion_id", "score"}`` dicts.

    Prefers ``MotionDatabase.get_cached_similarities`` when the database object
    provides it; otherwise falls back to a direct SQL query against the
    ``similarity_cache`` table using duckdb, which is imported lazily so it is
    only required on the fallback path.

    Args:
        motion_id: Source motion whose neighbours are requested.
        vector_type: Vector variant to compare against (default ``"fused"``).
        window_id: Optional window filter. In the SQL fallback, ``None``
            matches rows whose ``window_id`` IS NULL.
        top_k: Maximum number of results to return.
        db_path: Forwarded to ``MotionDatabase(db_path=...)`` when truthy;
            otherwise the default ``MotionDatabase()`` is constructed.

    Returns:
        At most ``top_k`` dicts ``{"motion_id": int, "score": float}``, sorted
        by descending score. Rows whose id/score cannot be converted to
        ``int``/``float`` are skipped. Returns ``[]`` when the database lacks
        the cached accessor and duckdb cannot be imported.
    """
    db = MotionDatabase(db_path=db_path) if db_path else MotionDatabase()

    # Prefer the cached accessor if available.
    if hasattr(db, "get_cached_similarities"):
        rows = _fetch_cached_rows(db, motion_id, vector_type, window_id, top_k)
        out = _normalize_rows(rows)
        # The accessor's ordering is not guaranteed, so sort here.
        out.sort(key=lambda x: x["score"], reverse=True)
        return out[:top_k]

    return _query_similarity_cache(db, motion_id, vector_type, window_id, top_k)


def _fetch_cached_rows(db, motion_id, vector_type, window_id, top_k):
    """Call ``db.get_cached_similarities``, tolerating keyword or positional signatures."""
    try:
        rows = db.get_cached_similarities(
            source_motion_id=motion_id,
            vector_type=vector_type,
            window_id=window_id,
            top_k=top_k,
        )
    except TypeError:
        # The accessor's parameter names differ from the ones above, so retry
        # positionally. NOTE(review): a TypeError raised *inside* the accessor
        # also lands here and triggers a second (positional) call.
        _logger.debug(
            "get_cached_similarities rejected keyword arguments; retrying positionally",
            exc_info=True,
        )
        rows = db.get_cached_similarities(motion_id, vector_type, window_id, top_k)
    # Treat a None result as "no cached neighbours" rather than failing to iterate.
    return [] if rows is None else rows


def _first_not_none(mapping: Dict, keys: tuple):
    """Return the value of the first key in ``keys`` whose value is not None, else None.

    Unlike chaining ``a or b or c``, this keeps legitimate falsy values such as
    a score of ``0.0`` or an id of ``0``.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


def _to_result(mid, score) -> Optional[Dict]:
    """Convert one (id, score) pair to a result dict, or None if it is malformed."""
    try:
        return {"motion_id": int(mid), "score": float(score)}
    except (TypeError, ValueError, OverflowError):
        return None


def _normalize_rows(rows) -> List[Dict]:
    """Normalize dict rows or ``(target_id, score, ...)`` sequences into result dicts.

    Malformed rows are skipped.
    """
    out = []
    for r in rows:
        if isinstance(r, dict):
            # Dict rows may name the id/score columns differently.
            mid = _first_not_none(r, ("target_motion_id", "motion_id", "target"))
            score = _first_not_none(r, ("score", "similarity", "score_float"))
        else:
            # Otherwise expect a sequence like (target_motion_id, score).
            try:
                mid, score = r[0], r[1]
            except (IndexError, KeyError, TypeError):
                continue
        entry = _to_result(mid, score)
        if entry is not None:
            out.append(entry)
    return out


def _query_similarity_cache(db, motion_id, vector_type, window_id, top_k) -> List[Dict]:
    """Fallback: read ``similarity_cache`` directly with duckdb at ``db.db_path``."""
    try:
        import duckdb  # imported lazily: only this fallback path needs it
    except ImportError:
        _logger.error(
            "duckdb not available and MotionDatabase lacks get_cached_similarities"
        )
        return []

    # SQL `= NULL` never matches, so a missing window needs IS NULL. Only one
    # of these two fixed clauses is interpolated; all values are bound as params.
    if window_id is None:
        window_clause = "window_id IS NULL"
        params = (motion_id, vector_type, top_k)
    else:
        window_clause = "window_id = ?"
        params = (motion_id, vector_type, window_id, top_k)
    query = (
        "SELECT target_motion_id, score FROM similarity_cache "
        f"WHERE source_motion_id = ? AND vector_type = ? AND {window_clause} "
        "ORDER BY score DESC LIMIT ?"
    )

    conn = duckdb.connect(db.db_path)
    try:
        rows = conn.execute(query, params).fetchall()
    finally:
        try:
            conn.close()
        except Exception:
            # Best-effort cleanup: never let a close failure mask the query result/error.
            _logger.debug("failed to close duckdb connection", exc_info=True)

    out = _normalize_rows(rows)
    out.sort(key=lambda x: x["score"], reverse=True)
    return out
|
|
|