You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/analysis/clustering.py

136 lines
3.9 KiB

"""clustering.py — UMAP dimensionality reduction on fused embeddings.
Reduces fused motion embeddings to 2D (or 3D) for visualisation,
and optionally labels clusters using KMeans.
Requires: umap-learn, scikit-learn (for KMeans)
"""
import json
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
try:
import duckdb
except (
Exception
): # pragma: no cover - import-time guard for environments without duckdb
duckdb = None # type: ignore
_logger = logging.getLogger(__name__)
def _load_fused_vectors(
    db_path: str, window_id: Optional[str] = None
) -> Tuple[List[int], List[str], np.ndarray]:
    """Load fused embeddings from the DuckDB database at ``db_path``.

    Each row's ``vector`` column is decoded with ``json.loads`` and must
    yield a flat list of numbers. Rows that cannot be decoded are skipped
    with a warning instead of aborting the whole load.

    Args:
        db_path: Path passed to ``duckdb.connect``.
        window_id: When truthy, only rows with this ``window_id`` are
            loaded; otherwise every row of ``fused_embeddings`` is loaded.

    Returns:
        ``(motion_ids, window_ids, matrix)`` where ``matrix`` has one row per
        successfully parsed embedding. Shorter vectors are zero-padded to the
        longest vector's length. When nothing could be loaded (no rows, no
        parseable vectors, or duckdb is not installed) the result is
        ``([], [], np.zeros((0, 0)))``.
    """
    if duckdb is None:
        # Mirror how the other optional dependencies in this module are
        # handled: log and return an empty result rather than failing with
        # an opaque AttributeError on ``None.connect``.
        _logger.error("duckdb is not installed; cannot load fused embeddings")
        return [], [], np.zeros((0, 0))

    conn = duckdb.connect(db_path)
    try:
        if window_id:
            # ORDER BY keeps the row order (and therefore the downstream
            # embedding input order) deterministic, as in the unfiltered query.
            rows = conn.execute(
                "SELECT motion_id, window_id, vector FROM fused_embeddings "
                "WHERE window_id = ? ORDER BY motion_id",
                (window_id,),
            ).fetchall()
        else:
            rows = conn.execute(
                "SELECT motion_id, window_id, vector FROM fused_embeddings "
                "ORDER BY window_id, motion_id"
            ).fetchall()
    finally:
        # Release the connection even when the query itself fails.
        conn.close()

    motion_ids: List[int] = []
    window_ids: List[str] = []
    vectors: List[np.ndarray] = []
    for motion_id, wid, vec_json in rows:
        try:
            vec = np.asarray(json.loads(vec_json), dtype=float)
            if vec.ndim != 1:
                # Scalars, null and nested lists are not usable embeddings.
                raise ValueError("fused vector is not a flat list of numbers")
            parsed_id = int(motion_id)
        except (TypeError, ValueError):
            _logger.warning("Could not parse fused vector for motion %s", motion_id)
            continue
        # Append only after every conversion succeeded so the three
        # sequences always stay aligned.
        motion_ids.append(parsed_id)
        window_ids.append(wid)
        vectors.append(vec)

    if not vectors:
        return [], [], np.zeros((0, 0))

    # Pad to common length if needed (shouldn't happen if pipeline is consistent)
    max_len = max(vec.size for vec in vectors)
    mat = np.zeros((len(vectors), max_len), dtype=float)
    for i, vec in enumerate(vectors):
        mat[i, : vec.size] = vec
    return motion_ids, window_ids, mat
def run_umap(
    db_path: str,
    window_id: Optional[str] = None,
    n_components: int = 2,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    random_state: int = 42,
) -> Dict:
    """Project fused embeddings to ``n_components`` dimensions with UMAP.

    Embeddings are read via ``_load_fused_vectors(db_path, window_id)``.

    Returns:
        An empty dict when umap-learn cannot be imported or when no
        embeddings were loaded. Otherwise::

            {
                "motion_ids": [...],
                "window_ids": [...],
                "coords": [[x, y], ...],  # one entry per motion
                "n_components": int,
            }
    """
    try:
        import umap
    except ImportError:
        _logger.error("umap-learn is not installed; cannot run UMAP")
        return {}

    motion_ids, window_ids, matrix = _load_fused_vectors(db_path, window_id)
    if matrix.size == 0:
        _logger.warning("No fused embeddings found for window_id=%s", window_id)
        return {}

    n_samples = matrix.shape[0]
    if n_samples < n_neighbors + 1:
        # Too few samples for the requested neighbourhood: shrink it, but
        # never below 2.
        n_neighbors = max(2, n_samples - 1)
        _logger.warning(
            "Reduced n_neighbors to %d due to small dataset (%d samples)",
            n_neighbors,
            n_samples,
        )

    umap_params = {
        "n_components": n_components,
        "n_neighbors": n_neighbors,
        "min_dist": min_dist,
        "random_state": random_state,
    }
    embedding = umap.UMAP(**umap_params).fit_transform(matrix)

    result: Dict = {}
    result["motion_ids"] = motion_ids
    result["window_ids"] = window_ids
    result["coords"] = embedding.tolist()
    result["n_components"] = n_components
    return result
def cluster_kmeans(
    coords: np.ndarray, n_clusters: int = 8, random_state: int = 42
) -> np.ndarray:
    """Run KMeans on the given coordinates and return cluster labels.

    Args:
        coords: One row per sample (e.g. the 2D/3D UMAP coordinates).
        n_clusters: Requested number of clusters; capped at the number of
            samples so KMeans is never asked for more clusters than points.
        random_state: Seed forwarded to ``KMeans``.

    Returns:
        Array of integer cluster labels with ``len(coords)`` entries. An
        empty array is returned for empty input, and an all-zero array is
        returned when scikit-learn is not installed.
    """
    n_samples = len(coords)
    if n_samples == 0:
        # Nothing to cluster. Without this guard n_clusters would be capped
        # to 0 and KMeans would raise.
        return np.zeros(0, dtype=int)

    try:
        from sklearn.cluster import KMeans
    except ImportError:
        _logger.error("scikit-learn is not installed; cannot run KMeans")
        return np.zeros(n_samples, dtype=int)

    n_clusters = min(n_clusters, n_samples)
    km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
    return km.fit_predict(coords)