"""trajectory.py — Compute MP political drift across aligned time windows.

For each MP that appears in multiple windows, computes:

- The aligned SVD vector per window
- The Euclidean distance between consecutive windows (drift)
- Total cumulative drift

Returns a dict keyed by mp_name containing per-window positions and drift
scores.
"""

import json
import logging
from typing import Dict, List, Optional

import numpy as np
import duckdb

try:
    from scipy.linalg import orthogonal_procrustes as _scipy_procrustes

    _HAS_SCIPY = True
except ImportError:
    _scipy_procrustes = None  # type: ignore[assignment]
    _HAS_SCIPY = False

_logger = logging.getLogger(__name__)


def _procrustes_align_windows(
    window_vecs: Dict[str, Dict[str, np.ndarray]],
    min_overlap: int = 5,
) -> Dict[str, Dict[str, np.ndarray]]:
    """Align SVD vectors across windows using chained Procrustes rotations.

    The first window is kept as-is.  Every later window is rotated onto the
    *previous window's aligned vectors* (not directly onto the first window)
    using an orthogonal Procrustes fit over the entities both windows share.

    A window is passed through unrotated, and becomes the reference for the
    next window, when any of the following holds:

    - a shared entity has a different vector shape in the two windows,
    - fewer than ``min_overlap`` entities are shared (at least one shared
      entity is always required, whatever ``min_overlap`` is),
    - the Procrustes fit or the rotation itself raises.

    Args:
        window_vecs: ``{window_id: {entity_id: vector}}``; iteration order of
            the outer dict defines the window order.
        min_overlap: Minimum number of common entities needed for alignment.

    Returns:
        A dict with the same structure.  If scipy is unavailable or fewer
        than two windows are given, ``window_vecs`` itself is returned.
    """
    # Explicit guard instead of an ``assert`` so it survives ``python -O``.
    if not _HAS_SCIPY or _scipy_procrustes is None:
        _logger.debug("scipy not available, skipping Procrustes alignment")
        return window_vecs

    window_ids = list(window_vecs)
    if len(window_ids) < 2:
        return window_vecs

    first_id = window_ids[0]
    prev_aligned = window_vecs[first_id]
    result = {first_id: prev_aligned}

    # np.vstack() cannot stack an empty list, so never fit on zero entities.
    required = max(min_overlap, 1)

    for wid in window_ids[1:]:
        cur = window_vecs[wid]
        shared = [e for e in cur if e in prev_aligned]

        # Test the shape condition itself, not the truthiness of the entity
        # id: a falsy id (e.g. "") must not hide a mismatch.
        if any(cur[e].shape != prev_aligned[e].shape for e in shared):
            _logger.debug(
                "Procrustes skipped for %s: vector dimensionality mismatch between windows",
                wid,
            )
            result[wid] = cur
            prev_aligned = cur
            continue

        if len(shared) < required:
            _logger.debug(
                "Procrustes skipped for %s: only %d common entities (need %d)",
                wid,
                len(shared),
                required,
            )
            result[wid] = cur
            prev_aligned = cur
            continue

        ref_mat = np.vstack([prev_aligned[e] for e in shared])
        cur_mat = np.vstack([cur[e] for e in shared])
        try:
            rotation, _ = _scipy_procrustes(cur_mat, ref_mat)
            # Entities that are not shared may have a different dimension,
            # so the rotation itself can fail; keep it inside the try.
            aligned = {e: v.dot(rotation) for e, v in cur.items()}
        except Exception:
            # Deliberate best-effort: log and fall back to unrotated vectors.
            _logger.exception("Procrustes failed for window %s", wid)
            aligned = cur

        result[wid] = aligned
        prev_aligned = aligned

    return result


def _load_window_ids(db_path: str) -> List[str]:
    """Return all distinct MP window IDs from svd_vectors, ordered by window_id.

    The connection is always closed, including when the query raises.
    """
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT DISTINCT window_id FROM svd_vectors WHERE entity_type = 'mp' ORDER BY window_id"
        ).fetchall()
    finally:
        conn.close()
    return [row[0] for row in rows]


def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Load the MP vectors stored for one window.

    Each row's ``vector`` value is decoded with ``json.loads`` and converted
    to a float array.  Rows that cannot be decoded are skipped with a
    warning rather than aborting the whole load.

    Args:
        db_path: Path to the DuckDB database.
        window_id: Window whose MP vectors should be loaded.

    Returns:
        ``{entity_id: vector}`` for every row that parsed successfully.
    """
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()
    finally:
        conn.close()

    vectors: Dict[str, np.ndarray] = {}
    for mp_name, vec_json in rows:
        try:
            vectors[mp_name] = np.array(json.loads(vec_json), dtype=float)
        except Exception:
            # Best-effort on purpose: one bad row must not sink the window.
            _logger.warning(
                "Could not parse vector for MP %s window %s", mp_name, window_id
            )
    return vectors


def _unit_or_unchanged(vec: np.ndarray) -> np.ndarray:
    """Return ``vec`` scaled to unit L2 norm; near-zero vectors are returned unchanged."""
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 1e-10 else vec


def compute_trajectories(
    db_path: str,
    window_ids: Optional[List[str]] = None,
    normalize: bool = True,
) -> Dict[str, Dict]:
    """Compute per-MP trajectories across windows.

    Args:
        db_path: Path to DuckDB database.
        window_ids: Subset of window IDs to use (default: all, ordered).
        normalize: If True (default), L2-normalise each vector before
            computing drift so that cross-window magnitude differences
            (caused by different numbers of motions per window) don't
            inflate drift.

    Returns:
        {
            mp_name: {
                "windows": [window_id, ...],
                "vectors": [[...], ...],  # aligned vector per window, not normalised
                "drift": [float, ...],    # consecutive Euclidean distances
                "total_drift": float,
            }
        }
        Drift distances are measured on the unit sphere when ``normalize``
        is True, and on the aligned vectors otherwise.  Only MPs present in
        at least 2 windows are included.

    Note:
        Drift is a plain array subtraction, so an MP whose vectors have
        incompatible shapes in consecutive windows makes NumPy raise
        ``ValueError``.
    """
    if window_ids is None:
        window_ids = _load_window_ids(db_path)

    if len(window_ids) < 2:
        _logger.info("Fewer than 2 windows — no trajectories to compute")
        return {}

    # Collect per-window vectors keyed as {window_id: {entity_id: vector}}
    raw_window_vecs: Dict[str, Dict[str, np.ndarray]] = {
        wid: _load_mp_vectors_for_window(db_path, wid) for wid in window_ids
    }

    # Align windows via Procrustes to remove arbitrary SVD sign/rotation flips
    aligned_window_vecs = _procrustes_align_windows(raw_window_vecs)

    # Reshape into per-MP view, preserving the order of window_ids
    mp_data: Dict[str, Dict] = {}
    for wid in window_ids:
        for mp_name, vec in aligned_window_vecs[wid].items():
            entry = mp_data.setdefault(mp_name, {"windows": [], "vectors": []})
            entry["windows"].append(wid)
            entry["vectors"].append(vec)

    # Compute drift for MPs with >= 2 windows
    result: Dict[str, Dict] = {}
    for mp_name, data in mp_data.items():
        if len(data["windows"]) < 2:
            continue

        vecs = data["vectors"]
        points = [_unit_or_unchanged(v) for v in vecs] if normalize else vecs
        drifts = [
            float(np.linalg.norm(nxt - cur)) for cur, nxt in zip(points, points[1:])
        ]

        result[mp_name] = {
            "windows": data["windows"],
            "vectors": [v.tolist() for v in vecs],
            "drift": drifts,
            "total_drift": float(sum(drifts)),
        }

    _logger.info(
        "Trajectories computed for %d MPs across %d windows",
        len(result),
        len(window_ids),
    )
    return result


def compute_2d_trajectories(
    db_path: str,
    method: str = "pca",
    anchor_kwargs: Optional[Dict] = None,
    normalize_vectors: bool = True,
) -> Dict[str, Dict]:
    """Compute 2D trajectory positions for MPs using compute_2d_axes.

    All window IDs found in the database are used.  ``method``,
    ``anchor_kwargs`` and ``normalize_vectors`` are forwarded unchanged to
    ``compute_2d_axes``.

    Returns dict keyed by mp_name with:
        {
            'windows': [window_ids...],
            'coords': [[x, y], ...],
            'step_vectors': [[dx, dy], ...],
            'step_magnitudes': [float, ...],
            'total_magnitude': float,
        }
    Only MPs present in >=2 windows are included.
    """
    from .political_axis import compute_2d_axes

    window_ids = _load_window_ids(db_path)
    if len(window_ids) < 2:
        _logger.info("Fewer than 2 windows — no 2D trajectories to compute")
        return {}

    # NOTE(review): positions_by_window is consumed below as
    # {window_id: {mp_name: (x, y)}} — confirm against compute_2d_axes.
    # The second return value (the axes) is not needed here.
    positions_by_window, _axes = compute_2d_axes(
        db_path,
        window_ids=window_ids,
        method=method,
        anchor_kwargs=anchor_kwargs,
        normalize_vectors=normalize_vectors,
    )

    # Build per-MP time-ordered coords; windows without positions are skipped
    mp_data: Dict[str, Dict] = {}
    for wid in window_ids:
        for mp_name, coord in positions_by_window.get(wid, {}).items():
            entry = mp_data.setdefault(mp_name, {"windows": [], "coords": []})
            entry["windows"].append(wid)
            entry["coords"].append(tuple(coord))

    result: Dict[str, Dict] = {}
    for mp_name, data in mp_data.items():
        if len(data["windows"]) < 2:
            continue

        coords = [np.array(c, dtype=float) for c in data["coords"]]
        step_vecs = [nxt - cur for cur, nxt in zip(coords, coords[1:])]
        mags = [float(np.linalg.norm(step)) for step in step_vecs]

        result[mp_name] = {
            "windows": data["windows"],
            "coords": [[float(c[0]), float(c[1])] for c in coords],
            "step_vectors": [[float(s[0]), float(s[1])] for s in step_vecs],
            "step_magnitudes": mags,
            "total_magnitude": float(sum(mags)),
        }

    _logger.info("2D trajectories computed for %d MPs", len(result))
    return result


def top_drifters(trajectories: Dict[str, Dict], n: int = 10) -> List[Dict]:
    """Return the top-n MPs by total drift, sorted descending.

    Args:
        trajectories: Mapping of mp_name to a dict holding at least
            ``"total_drift"`` and ``"windows"`` (the shape produced by
            ``compute_trajectories``).
        n: Maximum number of entries to return; values below 1 yield an
            empty list.

    Returns:
        A list of ``{"mp_name": ..., "total_drift": ..., "windows": [...]}``.
        Ties keep the iteration order of ``trajectories``.
    """
    ranked = sorted(
        trajectories.items(), key=lambda kv: kv[1]["total_drift"], reverse=True
    )
    # Clamp so a negative n cannot turn into a "drop the last |n|" slice.
    limit = max(n, 0)
    return [
        {
            "mp_name": mp,
            "total_drift": data["total_drift"],
            "windows": data["windows"],
        }
        for mp, data in ranked[:limit]
    ]