motief/similarity/compute.py

import json
import logging
from typing import List, Optional

import numpy as np

from database import MotionDatabase

logger = logging.getLogger(__name__)


def compute_similarities(
    vector_type: str = "fused",
    window_id: Optional[str] = None,
    top_k: int = 10,
    db_path: Optional[str] = None,
    db=None,
):
    """Compute pairwise cosine similarities for vectors of a given type and store top-k neighbors.

    Returns the number of inserted rows.
    """
    if db is None:
        db = (
            MotionDatabase(db_path=db_path) if db_path is not None else MotionDatabase()
        )
    # Build SQL query depending on vector type
    if vector_type == "fused":
        if window_id is not None:
            query = "SELECT motion_id AS id, vector FROM fused_embeddings WHERE window_id = ?"
            params = (window_id,)
        else:
            # fall back to all fused embeddings if no window specified
            query = "SELECT motion_id AS id, vector FROM fused_embeddings"
            params = ()
    elif vector_type == "text":
        query = "SELECT motion_id AS id, vector FROM embeddings"
        params = ()
    elif vector_type == "svd":
        if window_id is not None:
            query = "SELECT entity_id AS id, vector FROM svd_vectors WHERE entity_type = 'motion' AND window_id = ?"
            params = (window_id,)
        else:
            query = "SELECT entity_id AS id, vector FROM svd_vectors WHERE entity_type = 'motion'"
            params = ()
    else:
        logger.error(f"Unknown vector_type: {vector_type}")
        return 0

    # Load vectors in a single query
    try:
        try:
            import duckdb
        except Exception:
            logger.exception("duckdb import failed; cannot load vectors")
            return 0
        with duckdb.connect(db.db_path) as conn:
            rows = conn.execute(query, params).fetchall()
    except Exception:
        logger.exception("Error loading vectors for similarity compute")
        return 0

    if not rows:
        logger.info("No vectors found for %s window=%s", vector_type, window_id)
        return 0

    ids: List[int] = []
    vecs: List[List[float]] = []
    for r in rows:
        _id, vec_json = r
        # parse vector robustly: accept list/tuple, bytes/bytearray, or JSON string
        vec = None
        if isinstance(vec_json, (list, tuple)):
            vec = list(vec_json)
        elif isinstance(vec_json, (bytes, bytearray)):
            try:
                text = vec_json.decode("utf-8")
            except Exception:
                logger.warning(
                    "Skipping row with non-decodable bytes vector for id=%s", _id
                )
                continue
            try:
                vec = json.loads(text)
            except Exception:
                logger.warning(
                    "Skipping row with invalid JSON bytes vector for id=%s", _id
                )
                continue
        elif isinstance(vec_json, str):
            try:
                vec = json.loads(vec_json)
            except Exception:
                logger.warning(
                    "Skipping row with invalid JSON string vector for id=%s", _id
                )
                continue
        else:
            logger.warning(
                "Skipping row with unsupported vector type %s for id=%s",
                type(vec_json),
                _id,
            )
            continue
        # ensure numeric conversion
        try:
            vec_floats = [float(x) for x in vec]
        except Exception:
            logger.warning(
                "Skipping row with non-numeric vector entries for id=%s", _id
            )
            continue
        # cast id to int for consistency; skip if it cannot be cast
        try:
            ids.append(int(_id))
        except Exception:
            logger.warning("Skipping row with non-integer id=%s", _id)
            continue
        vecs.append(vec_floats)

    if not vecs:
        logger.info(
            "No valid vectors after parsing for %s window=%s", vector_type, window_id
        )
        return 0

    # Ensure consistent dimensionality: pad shorter vectors with zeros
    lengths = [len(v) for v in vecs]
    max_dim = max(lengths)
    if len(set(lengths)) != 1:
        logger.warning(
            "Inconsistent vector dimensions detected (max=%d). Padding shorter vectors with zeros.",
            max_dim,
        )
    matrix = np.zeros((len(vecs), max_dim), dtype=np.float32)
    for i, v in enumerate(vecs):
        matrix[i, : len(v)] = v

    # Normalize rows so the dot products below are cosine similarities
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # avoid division by zero for all-zero vectors
    norms[norms == 0] = 1.0
    normalized = matrix / norms

    # Compute the full similarity matrix
    sim = normalized @ normalized.T
    n = sim.shape[0]

    rows_to_insert: List[dict] = []
    for i in range(n):
        scores = sim[i].copy()
        # exclude self
        scores[i] = -np.inf
        # number of neighbors to take is min(top_k, n - 1)
        k = min(top_k, n - 1)
        if k <= 0:
            continue
        # get top-k indices
        if k == 1:
            idx = int(np.argmax(scores))
            top_idx = [idx]
        else:
            # argpartition for efficiency, then sort. Avoid negating scores because
            # we set self to -inf earlier, which would become +inf when negated and
            # incorrectly be picked as a top neighbor. Instead, partition at
            # n - k to obtain the k largest elements, then sort them descending.
            part = np.argpartition(scores, n - k)[-k:]
            top_idx = part[np.argsort(-scores[part])]
        src_id = ids[i]
        for j in top_idx:
            rows_to_insert.append(
                {
                    "source_motion_id": int(src_id),
                    "target_motion_id": int(ids[j]),
                    "score": float(scores[j]),
                    "vector_type": vector_type,
                    "window_id": window_id,
                }
            )

    # Filter trivial 1.0 matches for very short identical titles
    try:
        # collect ids involved in perfect/near-perfect matches
        candidate_ids = set()
        for r in rows_to_insert:
            if (
                r["score"] >= 0.999999
                and r["source_motion_id"] != r["target_motion_id"]
            ):
                candidate_ids.add(r["source_motion_id"])
                candidate_ids.add(r["target_motion_id"])
        if candidate_ids:
            titles_map = db.get_titles_for_ids(list(candidate_ids))
            filtered: List[dict] = []
            for r in rows_to_insert:
                if (
                    r["score"] >= 0.999999
                    and r["source_motion_id"] != r["target_motion_id"]
                ):
                    t1 = (titles_map.get(r["source_motion_id"]) or "").strip()
                    t2 = (titles_map.get(r["target_motion_id"]) or "").strip()
                    if t1 and t1 == t2 and len(t1) < 12:
                        logger.info(
                            "Filtered trivial 1.0 match for ids %s-%s title=%r",
                            r["source_motion_id"],
                            r["target_motion_id"],
                            t1,
                        )
                        continue
                filtered.append(r)
            rows_to_insert = filtered
    except Exception:
        logger.exception(
            "Error while filtering trivial matches; proceeding without filter"
        )

    # Clear existing cache for this vector_type/window and store new rows
    try:
        deleted = db.clear_similarity_cache(
            vector_type=vector_type, window_id=window_id
        )
        logger.info(
            "Cleared %d existing similarity rows for %s window=%s",
            deleted,
            vector_type,
            window_id,
        )
    except Exception:
        logger.exception("Error clearing similarity cache")

    try:
        inserted = db.store_similarity_batch(rows_to_insert)
        logger.info(
            "Inserted %d similarity rows for %s window=%s",
            inserted,
            vector_type,
            window_id,
        )
        return inserted
    except Exception:
        logger.exception("Error storing similarity rows")
        return 0
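

# Illustrative usage sketch: running the module directly (re)builds the
# similarity cache for fused vectors. The argument values below are
# placeholders, and a MotionDatabase at its default path is assumed.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    inserted_rows = compute_similarities(
        vector_type="fused",  # or "text" / "svd"
        window_id=None,       # e.g. a specific analysis window id, if any
        top_k=10,
        db_path=None,         # None -> MotionDatabase default location
    )
    logger.info("compute_similarities stored %d rows", inserted_rows)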