@ -15,9 +15,77 @@ from typing import Dict, List, Optional
import numpy as np
import duckdb
try :
from scipy . linalg import orthogonal_procrustes as _scipy_procrustes
_HAS_SCIPY = True
except ImportError :
_scipy_procrustes = None # type: ignore[assignment]
_HAS_SCIPY = False
_logger = logging . getLogger ( __name__ )
def _procrustes_align_windows (
window_vecs : Dict [ str , Dict [ str , np . ndarray ] ] ,
min_overlap : int = 5 ,
) - > Dict [ str , Dict [ str , np . ndarray ] ] :
""" Align SVD vectors across windows using Procrustes rotations.
Takes the first window as reference and aligns each subsequent window
to it via orthogonal Procrustes on the set of common entities .
Args :
window_vecs : { window_id : { entity_id : vector } }
min_overlap : minimum number of common entities needed for alignment
Returns same structure with rotated vectors for windows 1. . N .
"""
if not _HAS_SCIPY :
_logger . debug ( " scipy not available, skipping Procrustes alignment " )
return window_vecs
window_ids = list ( window_vecs . keys ( ) )
if len ( window_ids ) < 2 :
return window_vecs
result = { window_ids [ 0 ] : window_vecs [ window_ids [ 0 ] ] }
# Accumulate the aligned reference — each window aligns to the *previous aligned* window
prev_aligned = window_vecs [ window_ids [ 0 ] ]
for wid in window_ids [ 1 : ] :
cur = window_vecs [ wid ]
common = [ e for e in cur if e in prev_aligned ]
if len ( common ) < min_overlap :
_logger . debug (
" Procrustes skipped for %s : only %d common entities (need %d ) " ,
wid ,
len ( common ) ,
min_overlap ,
)
result [ wid ] = cur
prev_aligned = cur
continue
ref_mat = np . vstack ( [ prev_aligned [ e ] for e in common ] )
cur_mat = np . vstack ( [ cur [ e ] for e in common ] )
try :
assert _scipy_procrustes is not None
R , _ = _scipy_procrustes ( cur_mat , ref_mat )
aligned = { e : v . dot ( R ) for e , v in cur . items ( ) }
except Exception :
_logger . exception ( " Procrustes failed for window %s " , wid )
aligned = cur
result [ wid ] = aligned
prev_aligned = aligned
return result
def _load_window_ids ( db_path : str ) - > List [ str ] :
""" Return all distinct window IDs from svd_vectors, in lexicographic order. """
conn = duckdb . connect ( db_path )
@ -49,15 +117,23 @@ def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.nd
def compute_trajectories (
db_path : str ,
window_ids : Optional [ List [ str ] ] = None ,
normalize : bool = True ,
) - > Dict [ str , Dict ] :
""" Compute per-MP trajectories across windows.
Args :
db_path : Path to DuckDB database .
window_ids : Subset of window IDs to use ( default : all , ordered ) .
normalize : If True ( default ) , L2 - normalise each vector before computing
drift so that cross - window magnitude differences ( caused by
different numbers of motions per window ) don ' t inflate drift.
Returns :
{
mp_name : {
" windows " : [ window_id , . . . ] ,
" vectors " : [ [ . . . ] , . . . ] , # one vector per window
" drift " : [ float , . . . ] , # consecutive Euclidean distances
" vectors " : [ [ . . . ] , . . . ] , # one vector per window (raw, not normalised)
" drift " : [ float , . . . ] , # consecutive Euclidean distances on unit sphere
" total_drift " : float ,
}
}
@ -70,12 +146,18 @@ def compute_trajectories(
_logger . info ( " Fewer than 2 windows — no trajectories to compute " )
return { }
# Collect per-window vectors for each MP
mp_data : Dict [ str , Dict ] = { }
# Collect per-window vectors keyed as {window_id: {entity_id: vector}}
raw_window_vecs : Dict [ str , Dict [ str , np . ndarray ] ] = { }
for wid in window_ids :
raw_window_vecs [ wid ] = _load_mp_vectors_for_window ( db_path , wid )
# Align windows via Procrustes to remove arbitrary SVD sign/rotation flips
aligned_window_vecs = _procrustes_align_windows ( raw_window_vecs )
# Reshape into per-MP view
mp_data : Dict [ str , Dict ] = { }
for wid in window_ids :
vecs = _load_mp_vectors_for_window ( db_path , wid )
for mp_name , vec in vecs . items ( ) :
for mp_name , vec in aligned_window_vecs [ wid ] . items ( ) :
if mp_name not in mp_data :
mp_data [ mp_name ] = { " windows " : [ ] , " vectors " : [ ] }
mp_data [ mp_name ] [ " windows " ] . append ( wid )
@ -87,8 +169,16 @@ def compute_trajectories(
if len ( data [ " windows " ] ) < 2 :
continue
vecs = data [ " vectors " ]
if normalize :
normed = [ ]
for v in vecs :
n = np . linalg . norm ( v )
normed . append ( v / n if n > 1e-10 else v )
else :
normed = vecs
drifts = [
float ( np . linalg . norm ( vecs [ i + 1 ] - vecs [ i ] ) ) for i in range ( len ( vecs ) - 1 )
float ( np . linalg . norm ( normed [ i + 1 ] - normed [ i ] ) )
for i in range ( len ( normed ) - 1 )
]
result [ mp_name ] = {
" windows " : data [ " windows " ] ,