import json
import logging
from typing import List, Optional

import numpy as np

from database import MotionDatabase

logger = logging.getLogger(__name__)

def _build_similarity_query(vector_type: str, window_id: Optional[str]):
    """Return a ``(query, params)`` pair selecting ``(id, vector)`` rows for *vector_type*.

    Returns ``None`` for an unknown ``vector_type``. Window filtering is only
    applied for the "fused" and "svd" tables; "text" embeddings have no
    window column, so *window_id* is ignored for that type.
    """
    if vector_type == "fused":
        if window_id is not None:
            return (
                "SELECT motion_id AS id, vector FROM fused_embeddings WHERE window_id = ?",
                (window_id,),
            )
        # fallback to all fused embeddings if no window specified
        return ("SELECT motion_id AS id, vector FROM fused_embeddings", ())
    if vector_type == "text":
        return ("SELECT motion_id AS id, vector FROM embeddings", ())
    if vector_type == "svd":
        if window_id is not None:
            return (
                "SELECT entity_id AS id, vector FROM svd_vectors WHERE entity_type = 'motion' AND window_id = ?",
                (window_id,),
            )
        return (
            "SELECT entity_id AS id, vector FROM svd_vectors WHERE entity_type = 'motion'",
            (),
        )
    return None


def _parse_vector(vec_json, _id) -> Optional[List[float]]:
    """Parse one stored vector value into a list of floats.

    Accepts list/tuple, UTF-8 JSON bytes/bytearray, or a JSON string.
    Returns ``None`` (after logging a warning) when the value cannot be
    decoded, parsed, or coerced to floats; callers skip such rows.
    """
    if isinstance(vec_json, (list, tuple)):
        vec = list(vec_json)
    elif isinstance(vec_json, (bytes, bytearray)):
        try:
            text = vec_json.decode("utf-8")
        except Exception:
            logger.warning(
                "Skipping row with non-decodable bytes vector for id=%s", _id
            )
            return None
        try:
            vec = json.loads(text)
        except Exception:
            logger.warning(
                "Skipping row with invalid JSON bytes vector for id=%s", _id
            )
            return None
    elif isinstance(vec_json, str):
        try:
            vec = json.loads(vec_json)
        except Exception:
            logger.warning(
                "Skipping row with invalid JSON string vector for id=%s", _id
            )
            return None
    else:
        logger.warning(
            "Skipping row with unsupported vector type %s for id=%s",
            type(vec_json),
            _id,
        )
        return None

    # ensure numeric conversion
    try:
        return [float(x) for x in vec]
    except Exception:
        logger.warning(
            "Skipping row with non-numeric vector entries for id=%s", _id
        )
        return None


def _cosine_similarity_matrix(vecs: List[List[float]]) -> np.ndarray:
    """Return the dense pairwise cosine-similarity matrix for *vecs*.

    Vectors of inconsistent length are right-padded with zeros to the
    maximum dimension (logged as a warning). Zero-norm rows are left as
    zeros rather than producing NaNs.
    """
    # Ensure consistent dimensionality: pad shorter vectors with zeros
    lengths = [len(v) for v in vecs]
    max_dim = max(lengths)
    if len(set(lengths)) != 1:
        logger.warning(
            "Inconsistent vector dimensions detected (max=%d). Padding shorter vectors with zeros.",
            max_dim,
        )

    matrix = np.zeros((len(vecs), max_dim), dtype=np.float32)
    for i, v in enumerate(vecs):
        matrix[i, : len(v)] = v

    # Normalize rows; avoid division by zero for all-zero vectors
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    normalized = matrix / norms

    return normalized @ normalized.T


def _topk_neighbor_rows(
    sim: np.ndarray,
    ids: List[int],
    top_k: int,
    vector_type: str,
    window_id: Optional[str],
) -> List[dict]:
    """Build similarity-cache row dicts for each vector's top-k neighbors.

    Self-similarity is excluded. Returns an empty list when there are
    fewer than two vectors or ``top_k <= 0``.
    """
    n = sim.shape[0]
    rows_to_insert: List[dict] = []

    # number of neighbors to take is min(top_k, n-1); loop-invariant, so
    # compute it once (the original recomputed it every iteration)
    k = min(top_k, n - 1)
    if k <= 0:
        return rows_to_insert

    for i in range(n):
        scores = sim[i].copy()
        # exclude self
        scores[i] = -np.inf

        # get top k indices
        if k == 1:
            top_idx = [int(np.argmax(scores))]
        else:
            # argpartition for efficiency then sort. avoid negating scores because
            # we set self to -inf earlier which would become +inf when negated and
            # incorrectly be picked as a top neighbor. Instead, partition at
            # n - k to obtain the k largest elements, then sort them descending.
            part = np.argpartition(scores, n - k)[-k:]
            top_idx = part[np.argsort(-scores[part])]

        src_id = ids[i]
        for j in top_idx:
            rows_to_insert.append(
                {
                    "source_motion_id": int(src_id),
                    "target_motion_id": int(ids[j]),
                    "score": float(scores[j]),
                    "vector_type": vector_type,
                    "window_id": window_id,
                }
            )

    return rows_to_insert


def compute_similarities(
    vector_type: str = "fused",
    window_id: Optional[str] = None,
    top_k: int = 10,
    db_path: Optional[str] = None,
) -> int:
    """Compute pairwise cosine similarities for vectors of a given type and store top-k neighbors.

    Loads all vectors of *vector_type* (optionally filtered by *window_id*
    for "fused"/"svd") from the DuckDB database, computes the full cosine
    similarity matrix, clears the existing similarity cache for this
    type/window, and stores each vector's top-k neighbors.

    Returns number of inserted rows (0 on any error or when no vectors are found).
    """
    db = MotionDatabase(db_path=db_path) if db_path is not None else MotionDatabase()

    # Build SQL query depending on vector type
    built = _build_similarity_query(vector_type, window_id)
    if built is None:
        logger.error(f"Unknown vector_type: {vector_type}")
        return 0
    query, params = built

    # Load vectors in a single query
    try:
        try:
            import duckdb
        except Exception:
            logger.exception("duckdb import failed; cannot load vectors")
            return 0

        with duckdb.connect(db.db_path) as conn:
            rows = conn.execute(query, params).fetchall()
    except Exception:
        logger.exception("Error loading vectors for similarity compute")
        return 0

    if not rows:
        logger.info("No vectors found for %s window=%s", vector_type, window_id)
        return 0

    # Parse rows, skipping anything with a bad vector or a non-integer id
    ids: List[int] = []
    vecs: List[List[float]] = []
    for _id, vec_json in rows:
        # parse vector robustly: accept list/tuple, bytes/bytearray, or JSON string
        vec_floats = _parse_vector(vec_json, _id)
        if vec_floats is None:
            continue

        # cast id to int for consistency; skip if cannot cast
        try:
            int_id = int(_id)
        except Exception:
            logger.warning("Skipping row with non-integer id=%s", _id)
            continue

        ids.append(int_id)
        vecs.append(vec_floats)

    if not vecs:
        logger.info(
            "No valid vectors after parsing for %s window=%s", vector_type, window_id
        )
        return 0

    # Compute similarity matrix and assemble top-k neighbor rows
    sim = _cosine_similarity_matrix(vecs)
    rows_to_insert = _topk_neighbor_rows(sim, ids, top_k, vector_type, window_id)

    # Clear existing cache for this vector_type/window and store new rows
    try:
        deleted = db.clear_similarity_cache(
            vector_type=vector_type, window_id=window_id
        )
        logger.info(
            "Cleared %d existing similarity rows for %s window=%s",
            deleted,
            vector_type,
            window_id,
        )
    except Exception:
        logger.exception("Error clearing similarity cache")

    try:
        inserted = db.store_similarity_batch(rows_to_insert)
        logger.info(
            "Inserted %d similarity rows for %s window=%s",
            inserted,
            vector_type,
            window_id,
        )
        return inserted
    except Exception:
        logger.exception("Error storing similarity rows")
        return 0