You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/pipeline/fusion.py

116 lines
3.6 KiB

import json
import logging
from typing import Dict, Optional

import duckdb

from database import MotionDatabase
# Module-level logger for this pipeline step, named after this module.
_logger = logging.getLogger(__name__)
def _decode_vector(raw) -> list:
    """Return a stored vector as a plain Python list.

    Accepts a JSON-encoded array (str/bytes) or a value that is already a
    list/tuple. NOTE(review): the list/tuple case covers a LIST-typed column;
    confirm against the actual svd_vectors/embeddings schema.

    Raises:
    - TypeError: raw is None or another type json.loads cannot decode
    - ValueError: raw is not valid JSON, or does not decode to a JSON array
    """
    if isinstance(raw, (list, tuple)):
        return list(raw)
    decoded = json.loads(raw)
    if not isinstance(decoded, list):
        # list() on a dict/str would silently yield keys/characters and
        # produce a bogus fused vector, so reject anything but an array.
        raise ValueError(f"expected a JSON array, got {type(decoded).__name__}")
    return decoded


def _latest_text_embedding_row(conn, motion_id: int, model: Optional[str]):
    """Return the newest ``(vector,)`` row in ``embeddings`` for a motion, or None.

    When *model* is truthy the lookup is restricted to that model; otherwise
    the newest embedding of any model is used.
    """
    sql = "SELECT vector FROM embeddings WHERE motion_id = ?"
    params = [motion_id]
    if model:
        sql += " AND model = ?"
        params.append(model)
    sql += " ORDER BY created_at DESC LIMIT 1"
    return conn.execute(sql, params).fetchone()


def fuse_for_window(
    window_id: str, db_path: Optional[str] = None, model: Optional[str] = None
) -> Dict[str, int]:
    """Fuse SVD vectors with text embeddings for motions in a window.

    For every ``svd_vectors`` row of the window with ``entity_type = 'motion'``,
    the SVD vector is concatenated (SVD components first) with the most recent
    text embedding of the same motion and stored through
    ``MotionDatabase.store_fused_embedding``.

    Parameters:
    - window_id: id of the window to process
    - db_path: optional path to the duckdb database (if falsy, the
      MotionDatabase default path is used)
    - model: optional model name to filter text embeddings; if falsy the
      newest embedding of any model is used

    Returns a dict with counts:
    - inserted: fused embeddings the database reported as stored (result > 0)
    - skipped_missing_text: no text embedding found, or it could not be decoded
    - skipped_missing_svd: the SVD vector could not be decoded
    - errors: unusable entity ids and failed stores (falsy/non-positive
      result or an exception from store_fused_embedding)

    Errors opening the database or running the initial svd_vectors query
    propagate to the caller; the duckdb connection is closed in every case.
    """
    db = MotionDatabase(db_path=db_path) if db_path else MotionDatabase()
    # MotionDatabase always exposes the path it uses; an explicit path wins.
    conn = duckdb.connect(db_path or db.db_path)

    inserted = 0
    skipped_missing_text = 0
    skipped_missing_svd = 0
    errors = 0
    try:
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = ?",
            (window_id, "motion"),
        ).fetchall()
        _logger.debug("Found %d svd rows for window %s", len(rows), window_id)

        for entity_id, svd_raw in rows:
            try:
                svd_vec = _decode_vector(svd_raw)
            except (TypeError, ValueError):
                _logger.exception("Invalid SVD vector JSON for entity %s", entity_id)
                skipped_missing_svd += 1
                continue

            # Embeddings are keyed by an integer motion id. Count an unusable
            # id as an error instead of letting it abort the whole window.
            try:
                motion_id = int(entity_id)
            except (TypeError, ValueError):
                _logger.exception("Invalid motion id %r in svd_vectors", entity_id)
                errors += 1
                continue

            emb_row = _latest_text_embedding_row(conn, motion_id, model)
            if emb_row is None:
                skipped_missing_text += 1
                continue
            try:
                text_vec = _decode_vector(emb_row[0])
            except (TypeError, ValueError):
                _logger.exception("Invalid text embedding JSON for motion %s", entity_id)
                skipped_missing_text += 1
                continue

            fused = svd_vec + text_vec

            # Best-effort per motion: a single failed write is counted and
            # logged but must not abort the rest of the window.
            try:
                res = db.store_fused_embedding(
                    motion_id,
                    window_id,
                    fused,
                    svd_dims=len(svd_vec),
                    text_dims=len(text_vec),
                )
                if res and res > 0:
                    inserted += 1
                else:
                    errors += 1
                    _logger.error(
                        "Failed to store fused embedding for motion %s (db returned %s)",
                        entity_id,
                        res,
                    )
            except Exception:
                _logger.exception(
                    "Exception while storing fused embedding for motion %s", entity_id
                )
                errors += 1
    finally:
        conn.close()

    return {
        "inserted": inserted,
        "skipped_missing_text": skipped_missing_text,
        "skipped_missing_svd": skipped_missing_svd,
        "errors": errors,
    }