import json
import logging
from typing import Dict, Optional, Tuple

import duckdb

from database import MotionDatabase

_logger = logging.getLogger(__name__)


def _build_fusion_query(model: Optional[str]) -> Tuple[str, tuple]:
    """Return the (sql, params-prefix) pair for the SVD/embedding join.

    The query left-joins the motion rows of ``svd_vectors`` with the most
    recent row per ``motion_id`` from ``embeddings`` (ordered by
    ``created_at DESC``). When *model* is truthy the embeddings are first
    restricted to that model via a bound parameter.

    The returned params contain only the optional model value; the caller
    appends the window id, which binds to the final ``?`` placeholder.
    """
    # Only a fixed SQL fragment is interpolated here; the model value itself
    # is always passed as a bound parameter.
    model_filter = " WHERE model = ?" if model else ""
    sql = (
        "WITH latest_embeddings AS ("
        " SELECT motion_id, vector FROM ("
        " SELECT motion_id, vector, ROW_NUMBER() OVER"
        " (PARTITION BY motion_id ORDER BY created_at DESC) AS rn"
        f" FROM embeddings{model_filter}"
        " ) WHERE rn = 1)"
        " SELECT sv.entity_id, sv.vector as svd_vector, le.vector as embedding_vector"
        " FROM svd_vectors sv"
        " LEFT JOIN latest_embeddings le"
        " ON CAST(sv.entity_id AS INTEGER) = le.motion_id"
        " WHERE sv.window_id = ? AND sv.entity_type = 'motion'"
    )
    params = (model,) if model else ()
    return sql, params


def fuse_for_window(
    window_id: str, db_path: Optional[str] = None, model: Optional[str] = None
) -> Dict[str, int]:
    """Fuse SVD vectors with text embeddings for motions in a window.

    For every motion row of ``svd_vectors`` in the window, the SVD vector is
    concatenated with the latest text embedding for that motion (SVD values
    first, then the text values) and stored via
    ``MotionDatabase.store_fused_embedding``.

    Parameters:
    - window_id: id of the window to process
    - db_path: optional path to the duckdb database (if falsy, the
      MotionDatabase default is used and its ``db_path`` attribute is read)
    - model: optional model name to filter text embeddings

    Returns a dict with counts:
    - inserted: rows for which the store call returned a positive value
    - skipped_missing_text: rows with no joined embedding, or whose
      embedding JSON could not be parsed
    - skipped_missing_svd: rows whose SVD vector JSON could not be parsed
    - errors: rows that failed during concatenation or storage

    Exceptions raised while opening the database or running the query
    propagate to the caller; the query connection is closed in all cases.
    """
    db = MotionDatabase(db_path=db_path) if db_path else MotionDatabase()
    # Without an explicit path, fall back to the path reported by
    # MotionDatabase (assumes it exposes ``db_path`` -- as the original did).
    conn = duckdb.connect(db_path if db_path else db.db_path)

    counts = {
        "inserted": 0,
        "skipped_missing_text": 0,
        "skipped_missing_svd": 0,
        "errors": 0,
    }

    # try/finally guarantees the connection is released even if the query
    # (or anything in the loop) raises unexpectedly.
    try:
        sql, params = _build_fusion_query(model)
        rows = conn.execute(sql, params + (window_id,)).fetchall()
        _logger.debug(
            "Found %d svd rows for window %s (joined with latest embeddings)",
            len(rows),
            window_id,
        )

        for entity_id, svd_json, emb_json in rows:
            # Parse SVD vector (TypeError covers a NULL/non-string column).
            try:
                svd_vec = json.loads(svd_json)
            except (TypeError, ValueError):
                _logger.exception("Invalid SVD vector JSON for entity %s", entity_id)
                counts["skipped_missing_svd"] += 1
                continue

            # The LEFT JOIN yields NULL when no embedding exists: skip.
            if not emb_json:
                counts["skipped_missing_text"] += 1
                continue

            try:
                text_vec = json.loads(emb_json)
            except (TypeError, ValueError):
                _logger.exception(
                    "Invalid text embedding JSON for motion %s", entity_id
                )
                counts["skipped_missing_text"] += 1
                continue

            # Valid JSON need not be an array; list() rejects scalars.
            try:
                fused = list(svd_vec) + list(text_vec)
            except TypeError:
                _logger.exception(
                    "Error concatenating vectors for motion %s", entity_id
                )
                counts["errors"] += 1
                continue

            # Store the fused embedding; one bad row must not abort the rest,
            # so this stays a broad per-row boundary.
            try:
                res = db.store_fused_embedding(
                    int(entity_id),
                    window_id,
                    fused,
                    svd_dims=len(svd_vec),
                    text_dims=len(text_vec),
                )
            except Exception:
                _logger.exception(
                    "Exception while storing fused embedding for motion %s", entity_id
                )
                counts["errors"] += 1
                continue

            if res and res > 0:
                counts["inserted"] += 1
            else:
                counts["errors"] += 1
                _logger.error(
                    "Failed to store fused embedding for motion %s (db returned %s)",
                    entity_id,
                    res,
                )
    finally:
        conn.close()

    return counts