You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
130 lines
3.7 KiB
130 lines
3.7 KiB
"""clustering.py — UMAP dimensionality reduction on fused embeddings.
|
|
|
|
Reduces fused motion embeddings to 2D (or 3D) for visualisation,
|
|
and optionally labels clusters using KMeans.
|
|
|
|
Requires: umap-learn, scikit-learn (for KMeans)
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import duckdb
|
|
|
|
# Module-level logger, named after this module via __name__.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _load_fused_vectors(
    db_path: str, window_id: Optional[str] = None
) -> Tuple[List[int], List[str], np.ndarray]:
    """Load fused embeddings from the DB.

    Reads ``motion_id``, ``window_id`` and the JSON-encoded ``vector`` column
    from the ``fused_embeddings`` table and stacks the vectors into a matrix.

    Args:
        db_path: Path passed to ``duckdb.connect``.
        window_id: When truthy, only rows with this ``window_id`` are loaded;
            otherwise all rows are loaded.

    Returns:
        ``(motion_ids, window_ids, matrix)`` where ``matrix`` has one row per
        successfully parsed vector, in the same order as the two lists. Rows
        whose vector cannot be parsed into a flat numeric list are skipped
        with a warning. When nothing could be loaded the result is
        ``([], [], np.zeros((0, 0)))``.
    """
    conn = duckdb.connect(db_path)
    try:
        if window_id:
            # ORDER BY keeps the row order deterministic, matching the
            # unfiltered query below.
            rows = conn.execute(
                "SELECT motion_id, window_id, vector FROM fused_embeddings "
                "WHERE window_id = ? ORDER BY motion_id",
                (window_id,),
            ).fetchall()
        else:
            rows = conn.execute(
                "SELECT motion_id, window_id, vector FROM fused_embeddings ORDER BY window_id, motion_id"
            ).fetchall()
    finally:
        # Release the connection even when the query itself fails.
        conn.close()

    motion_ids: List[int] = []
    window_ids: List[str] = []
    vectors: List[np.ndarray] = []
    for motion_id, wid, vec_json in rows:
        try:
            vec = np.asarray(json.loads(vec_json), dtype=float)
            if vec.ndim != 1:
                # Valid JSON that is not a flat list (null, scalar, nested
                # list) cannot be used as a row of the matrix.
                raise ValueError("fused vector is not a flat list")
            parsed_id = int(motion_id)
        except (TypeError, ValueError, OverflowError):
            _logger.warning("Could not parse fused vector for motion %s", motion_id)
            continue
        # Append only after every conversion succeeded so the three
        # sequences always stay aligned.
        motion_ids.append(parsed_id)
        window_ids.append(wid)
        vectors.append(vec)

    if not vectors:
        return [], [], np.zeros((0, 0))

    # Pad to common length if needed (shouldn't happen if pipeline is consistent)
    max_len = max(v.size for v in vectors)
    mat = np.zeros((len(vectors), max_len), dtype=float)
    for i, v in enumerate(vectors):
        mat[i, : v.size] = v

    return motion_ids, window_ids, mat
|
|
|
|
|
|
def run_umap(
    db_path: str,
    window_id: Optional[str] = None,
    n_components: int = 2,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    random_state: int = 42,
) -> Dict:
    """Project fused embeddings to low-dimensional coordinates with UMAP.

    Args:
        db_path: Database path forwarded to ``_load_fused_vectors``.
        window_id: Optional window filter forwarded to ``_load_fused_vectors``.
        n_components: Output dimensionality passed to ``umap.UMAP``.
        n_neighbors: Neighbourhood size passed to ``umap.UMAP``; lowered
            automatically when there are too few samples.
        min_dist: ``min_dist`` passed to ``umap.UMAP``.
        random_state: Seed passed to ``umap.UMAP``.

    Returns:
        An empty dict when ``umap`` cannot be imported or no embeddings were
        loaded. Otherwise::

            {
                "motion_ids": [...],
                "window_ids": [...],
                "coords": [[x, y], ...],  # or [x, y, z] if n_components=3
                "n_components": int,
            }
    """
    # Imported lazily so the module can be loaded without umap-learn.
    try:
        import umap
    except ImportError:
        _logger.error("umap-learn is not installed; cannot run UMAP")
        return {}

    ids, windows, matrix = _load_fused_vectors(db_path, window_id)
    if matrix.size == 0:
        _logger.warning("No fused embeddings found for window_id=%s", window_id)
        return {}

    sample_count = matrix.shape[0]
    effective_neighbors = n_neighbors
    # UMAP requires at least n_neighbors+1 samples
    if sample_count < n_neighbors + 1:
        effective_neighbors = max(2, sample_count - 1)
        _logger.warning(
            "Reduced n_neighbors to %d due to small dataset (%d samples)",
            effective_neighbors,
            sample_count,
        )

    embedding = umap.UMAP(
        n_components=n_components,
        n_neighbors=effective_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    ).fit_transform(matrix)

    result = {
        "motion_ids": ids,
        "window_ids": windows,
        "coords": embedding.tolist(),
        "n_components": n_components,
    }
    return result
|
|
|
|
|
|
def cluster_kmeans(
    coords: np.ndarray, n_clusters: int = 8, random_state: int = 42
) -> np.ndarray:
    """Run KMeans on 2D/3D UMAP coordinates.

    Args:
        coords: Sequence of points to cluster (one point per entry).
        n_clusters: Requested number of clusters; capped at ``len(coords)``.
        random_state: Seed passed to ``KMeans``.

    Returns:
        Array of integer cluster labels (length = len(coords)). An empty
        array is returned for empty input, and an all-zero array when
        scikit-learn is not installed.
    """
    n_points = len(coords)
    if n_points == 0:
        # Capping n_clusters at len(coords) would give 0 clusters, which
        # KMeans rejects; there is nothing to label anyway. len() is used
        # rather than truthiness because array truthiness is ambiguous.
        return np.zeros(0, dtype=int)

    # Imported lazily so the module can be loaded without scikit-learn.
    try:
        from sklearn.cluster import KMeans
    except ImportError:
        _logger.error("scikit-learn is not installed; cannot run KMeans")
        return np.zeros(n_points, dtype=int)

    # KMeans cannot produce more clusters than there are samples.
    n_clusters = min(n_clusters, n_points)
    km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
    return km.fit_predict(coords)
|
|
|