You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
4.5 KiB
134 lines
4.5 KiB
import json
import logging
from typing import Dict, Optional, Tuple

import duckdb

from database import MotionDatabase

_logger = logging.getLogger(__name__)


def _build_fusion_query(
    window_id: str, model: Optional[str]
) -> Tuple[str, Tuple[str, ...]]:
    """Return ``(sql, params)`` pairing a window's motion SVD vectors with text embeddings.

    Every ``entity_type = 'motion'`` row of ``svd_vectors`` for ``window_id``
    is LEFT JOINed to the most recent (by ``created_at``) row of ``embeddings``
    for the same motion, so motions without an embedding still appear with a
    NULL ``embedding_vector``. When ``model`` is truthy, only embeddings whose
    ``model`` column equals it are considered.
    """
    # Only this fixed fragment is interpolated into the SQL text; every
    # caller-supplied value is bound as a query parameter.
    model_filter = " WHERE model = ?" if model else ""
    sql = (
        "WITH latest_embeddings AS ("
        " SELECT motion_id, vector FROM ("
        " SELECT motion_id, vector, ROW_NUMBER() OVER"
        " (PARTITION BY motion_id ORDER BY created_at DESC) AS rn"
        f" FROM embeddings{model_filter}"
        " ) WHERE rn = 1)"
        " SELECT sv.entity_id, sv.vector AS svd_vector, le.vector AS embedding_vector"
        " FROM svd_vectors sv"
        # entity_id is cast for the join; presumably it is stored as text in
        # svd_vectors while motion_id is an integer (confirm against schema).
        " LEFT JOIN latest_embeddings le ON CAST(sv.entity_id AS INTEGER) = le.motion_id"
        " WHERE sv.window_id = ? AND sv.entity_type = 'motion'"
    )
    params = (model, window_id) if model else (window_id,)
    return sql, params


def _decode_vector(raw) -> list:
    """Decode a JSON-encoded vector and return it as a list.

    Raises:
        TypeError: if ``raw`` is not str/bytes/bytearray (e.g. a SQL NULL
            fetched as ``None``).
        ValueError: if ``raw`` is not valid JSON, or decodes to something other
            than a JSON array.
    """
    vec = json.loads(raw)
    if not isinstance(vec, list):
        raise ValueError(f"expected a JSON array, got {type(vec).__name__}")
    return vec


def fuse_for_window(
    window_id: str, db_path: Optional[str] = None, model: Optional[str] = None
) -> Dict[str, int]:
    """Fuse SVD vectors with text embeddings for motions in a window.

    For each motion row of ``svd_vectors`` in the window, the SVD vector is
    concatenated with the motion's latest text embedding (SVD components
    first) and stored via ``MotionDatabase.store_fused_embedding``.

    Parameters:
    - window_id: id of the window to process
    - db_path: optional path to duckdb database (if None MotionDatabase default is used)
    - model: optional model name to filter text embeddings; a falsy value
      means embeddings from any model are considered

    Returns a dict with counts:
    - inserted: fused embeddings stored (the store call returned a positive value)
    - skipped_missing_text: no embedding joined for the motion (NULL/empty), or
      its vector is not valid JSON / not a JSON array
    - skipped_missing_svd: the SVD vector is NULL, not valid JSON, or not a
      JSON array
    - errors: the store call raised, or returned a falsy / non-positive result

    Errors from connecting to the database or running the query propagate to
    the caller; once opened, the query connection is always closed.
    """
    # Create MotionDatabase using provided path if given, otherwise use default
    db = MotionDatabase(db_path=db_path) if db_path else MotionDatabase()
    sql, params = _build_fusion_query(window_id, model)

    # MotionDatabase always exposes the path it uses. All rows are materialized
    # with fetchall(), so the read connection can be closed before any writes
    # happen, and it is closed even if the query fails.
    conn = duckdb.connect(db_path if db_path else db.db_path)
    try:
        rows = conn.execute(sql, params).fetchall()
    finally:
        conn.close()

    _logger.debug(
        "Found %d svd rows for window %s (joined with latest embeddings)",
        len(rows),
        window_id,
    )

    inserted = 0
    skipped_missing_text = 0
    skipped_missing_svd = 0
    errors = 0

    for entity_id, svd_json, emb_json in rows:
        try:
            svd_vec = _decode_vector(svd_json)
        except (TypeError, ValueError):
            _logger.exception("Invalid SVD vector JSON for entity %s", entity_id)
            skipped_missing_svd += 1
            continue

        # LEFT JOIN found no embedding for this motion (NULL or empty value).
        if not emb_json:
            skipped_missing_text += 1
            continue

        try:
            text_vec = _decode_vector(emb_json)
        except (TypeError, ValueError):
            _logger.exception("Invalid text embedding JSON for motion %s", entity_id)
            skipped_missing_text += 1
            continue

        # Both decoded values are lists, so concatenation cannot fail.
        try:
            res = db.store_fused_embedding(
                int(entity_id),
                window_id,
                svd_vec + text_vec,
                svd_dims=len(svd_vec),
                text_dims=len(text_vec),
            )
            if res and res > 0:
                inserted += 1
            else:
                errors += 1
                _logger.error(
                    "Failed to store fused embedding for motion %s (db returned %s)",
                    entity_id,
                    res,
                )
        except Exception:
            # Best-effort per row: log with traceback and continue with the
            # remaining motions rather than aborting the whole window.
            _logger.exception(
                "Exception while storing fused embedding for motion %s", entity_id
            )
            errors += 1

    return {
        "inserted": inserted,
        "skipped_missing_text": skipped_missing_text,
        "skipped_missing_svd": skipped_missing_svd,
        "errors": errors,
    }