diff --git a/analysis/trajectory.py b/analysis/trajectory.py index 44bbdd5..ef4e782 100644 --- a/analysis/trajectory.py +++ b/analysis/trajectory.py @@ -15,9 +15,77 @@ from typing import Dict, List, Optional import numpy as np import duckdb +try: + from scipy.linalg import orthogonal_procrustes as _scipy_procrustes + + _HAS_SCIPY = True +except ImportError: + _scipy_procrustes = None # type: ignore[assignment] + _HAS_SCIPY = False + _logger = logging.getLogger(__name__) +def _procrustes_align_windows( + window_vecs: Dict[str, Dict[str, np.ndarray]], + min_overlap: int = 5, +) -> Dict[str, Dict[str, np.ndarray]]: + """Align SVD vectors across windows using Procrustes rotations. + + Takes the first window as reference and aligns each subsequent window + to it via orthogonal Procrustes on the set of common entities. + + Args: + window_vecs: {window_id: {entity_id: vector}} + min_overlap: minimum number of common entities needed for alignment + + Returns same structure with rotated vectors for windows 1..N. + """ + if not _HAS_SCIPY: + _logger.debug("scipy not available, skipping Procrustes alignment") + return window_vecs + + window_ids = list(window_vecs.keys()) + if len(window_ids) < 2: + return window_vecs + + result = {window_ids[0]: window_vecs[window_ids[0]]} + + # Accumulate the aligned reference — each window aligns to the *previous aligned* window + prev_aligned = window_vecs[window_ids[0]] + + for wid in window_ids[1:]: + cur = window_vecs[wid] + common = [e for e in cur if e in prev_aligned] + + if len(common) < min_overlap: + _logger.debug( + "Procrustes skipped for %s: only %d common entities (need %d)", + wid, + len(common), + min_overlap, + ) + result[wid] = cur + prev_aligned = cur + continue + + ref_mat = np.vstack([prev_aligned[e] for e in common]) + cur_mat = np.vstack([cur[e] for e in common]) + + try: + assert _scipy_procrustes is not None + R, _ = _scipy_procrustes(cur_mat, ref_mat) + aligned = {e: v.dot(R) for e, v in cur.items()} + except Exception: + _logger.exception("Procrustes failed for window %s", wid) + aligned = cur + + result[wid] = aligned + prev_aligned = aligned + + return result + + def _load_window_ids(db_path: str) -> List[str]: """Return all distinct window IDs from svd_vectors, in lexicographic order.""" conn = duckdb.connect(db_path) @@ -49,15 +117,23 @@ def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.nd def compute_trajectories( db_path: str, window_ids: Optional[List[str]] = None, + normalize: bool = True, ) -> Dict[str, Dict]: """Compute per-MP trajectories across windows. + Args: + db_path: Path to DuckDB database. + window_ids: Subset of window IDs to use (default: all, ordered). + normalize: If True (default), L2-normalise each vector before computing + drift so that cross-window magnitude differences (caused by + different numbers of motions per window) don't inflate drift. + Returns: { mp_name: { "windows": [window_id, ...], - "vectors": [[...], ...], # one vector per window - "drift": [float, ...], # consecutive Euclidean distances + "vectors": [[...], ...], # one vector per window (raw, not normalised) + "drift": [float, ...], # consecutive Euclidean distances on unit sphere "total_drift": float, } } @@ -70,12 +146,18 @@ def compute_trajectories( _logger.info("Fewer than 2 windows — no trajectories to compute") return {} - # Collect per-window vectors for each MP - mp_data: Dict[str, Dict] = {} + # Collect per-window vectors keyed as {window_id: {entity_id: vector}} + raw_window_vecs: Dict[str, Dict[str, np.ndarray]] = {} + for wid in window_ids: + raw_window_vecs[wid] = _load_mp_vectors_for_window(db_path, wid) + + # Align windows via Procrustes to remove arbitrary SVD sign/rotation flips + aligned_window_vecs = _procrustes_align_windows(raw_window_vecs) + # Reshape into per-MP view + mp_data: Dict[str, Dict] = {} for wid in window_ids: - vecs = _load_mp_vectors_for_window(db_path, wid) - for mp_name, vec in vecs.items(): + for mp_name, vec in aligned_window_vecs[wid].items(): if mp_name not in mp_data: mp_data[mp_name] = {"windows": [], "vectors": []} mp_data[mp_name]["windows"].append(wid) @@ -87,8 +169,16 @@ def compute_trajectories( if len(data["windows"]) < 2: continue vecs = data["vectors"] + if normalize: + normed = [] + for v in vecs: + n = np.linalg.norm(v) + normed.append(v / n if n > 1e-10 else v) + else: + normed = vecs drifts = [ - float(np.linalg.norm(vecs[i + 1] - vecs[i])) for i in range(len(vecs) - 1) + float(np.linalg.norm(normed[i + 1] - normed[i])) + for i in range(len(normed) - 1) ] result[mp_name] = { "windows": data["windows"], diff --git a/outputs/trajectories_normalized_top15.html b/outputs/trajectories_normalized_top15.html new file mode 100644 index 0000000..2124e2e --- /dev/null +++ b/outputs/trajectories_normalized_top15.html @@ -0,0 +1,7 @@ + + + +
+
+ + \ No newline at end of file diff --git a/outputs/trajectories_party_aligned.html b/outputs/trajectories_party_aligned.html new file mode 100644 index 0000000..c659e10 --- /dev/null +++ b/outputs/trajectories_party_aligned.html @@ -0,0 +1,7 @@ + + + +
+
+ + \ No newline at end of file