You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/analysis/trajectory.py

297 lines
9.6 KiB

"""trajectory.py — Compute MP political drift across aligned time windows.
For each MP that appears in multiple windows, computes:
- The aligned SVD vector per window
- The Euclidean distance between consecutive windows (drift)
- Total cumulative drift
Returns a dict keyed by mp_name containing per-window positions and drift scores.
"""
import json
import logging
from typing import Dict, List, Optional
import numpy as np
import duckdb
try:
from scipy.linalg import orthogonal_procrustes as _scipy_procrustes
_HAS_SCIPY = True
except ImportError:
_scipy_procrustes = None # type: ignore[assignment]
_HAS_SCIPY = False
_logger = logging.getLogger(__name__)
def _procrustes_align_windows(
window_vecs: Dict[str, Dict[str, np.ndarray]],
min_overlap: int = 5,
) -> Dict[str, Dict[str, np.ndarray]]:
"""Align SVD vectors across windows using Procrustes rotations.
Takes the first window as reference and aligns each subsequent window
to it via orthogonal Procrustes on the set of common entities.
Args:
window_vecs: {window_id: {entity_id: vector}}
min_overlap: minimum number of common entities needed for alignment
Returns same structure with rotated vectors for windows 1..N.
"""
if not _HAS_SCIPY:
_logger.debug("scipy not available, skipping Procrustes alignment")
return window_vecs
window_ids = list(window_vecs.keys())
if len(window_ids) < 2:
return window_vecs
result = {window_ids[0]: window_vecs[window_ids[0]]}
# Accumulate the aligned reference — each window aligns to the *previous aligned* window
prev_aligned = window_vecs[window_ids[0]]
for wid in window_ids[1:]:
cur = window_vecs[wid]
# Only consider common entities whose vectors share the same dimensionality
common = [
e
for e in cur
if e in prev_aligned and cur[e].shape == prev_aligned[e].shape
]
# If there are common entities but their vector dimensions differ between
# the current and previously aligned window, skip Procrustes alignment
# for this window rather than raising an exception in orthogonal_procrustes.
if any(
e
for e in cur
if e in prev_aligned and cur[e].shape != prev_aligned[e].shape
):
_logger.debug(
"Procrustes skipped for %s: vector dimensionality mismatch between windows",
wid,
)
result[wid] = cur
prev_aligned = cur
continue
if len(common) < min_overlap:
_logger.debug(
"Procrustes skipped for %s: only %d common entities (need %d)",
wid,
len(common),
min_overlap,
)
result[wid] = cur
prev_aligned = cur
continue
ref_mat = np.vstack([prev_aligned[e] for e in common])
cur_mat = np.vstack([cur[e] for e in common])
try:
assert _scipy_procrustes is not None
R, _ = _scipy_procrustes(cur_mat, ref_mat)
aligned = {e: v.dot(R) for e, v in cur.items()}
except Exception:
_logger.exception("Procrustes failed for window %s", wid)
aligned = cur
result[wid] = aligned
prev_aligned = aligned
return result
def _load_window_ids(db_path: str) -> List[str]:
    """Return all distinct MP window IDs from svd_vectors, in lexicographic order.

    Args:
        db_path: Path to the DuckDB database.

    Returns:
        Sorted list of window_id strings that have at least one 'mp' vector.
    """
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT DISTINCT window_id FROM svd_vectors WHERE entity_type = 'mp' ORDER BY window_id"
        ).fetchall()
    finally:
        # BUG FIX: close even when the query raises, so the handle never leaks.
        conn.close()
    return [r[0] for r in rows]
def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Load all MP vectors for one window from svd_vectors.

    Args:
        db_path: Path to the DuckDB database.
        window_id: Window to load.

    Returns:
        {mp_name: vector} with vectors parsed from their JSON encoding.
        MPs whose vector JSON cannot be parsed are skipped with a warning.
    """
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()
    finally:
        # BUG FIX: close even when the query raises, so the handle never leaks.
        conn.close()
    result: Dict[str, np.ndarray] = {}
    for mp_name, vec_json in rows:
        try:
            result[mp_name] = np.array(json.loads(vec_json), dtype=float)
        except Exception:
            # Best-effort load: a single bad row should not abort the window.
            _logger.warning(
                "Could not parse vector for MP %s window %s", mp_name, window_id
            )
    return result
def compute_trajectories(
    db_path: str,
    window_ids: Optional[List[str]] = None,
    normalize: bool = True,
) -> Dict[str, Dict]:
    """Compute per-MP trajectories across windows.

    Args:
        db_path: Path to DuckDB database.
        window_ids: Subset of window IDs to use (default: all, ordered).
        normalize: If True (default), L2-normalise each vector before computing
            drift so that cross-window magnitude differences (caused by
            different numbers of motions per window) don't inflate drift.

    Returns:
        {
            mp_name: {
                "windows": [window_id, ...],
                "vectors": [[...], ...],   # one vector per window (raw, not normalised)
                "drift": [float, ...],     # consecutive Euclidean distances on unit sphere
                "total_drift": float,
            }
        }
        Only MPs present in at least 2 windows are included.
    """
    if window_ids is None:
        window_ids = _load_window_ids(db_path)
    if len(window_ids) < 2:
        _logger.info("Fewer than 2 windows — no trajectories to compute")
        return {}
    # Per-window vectors keyed as {window_id: {entity_id: vector}}.
    raw_window_vecs: Dict[str, Dict[str, np.ndarray]] = {
        wid: _load_mp_vectors_for_window(db_path, wid) for wid in window_ids
    }
    # Remove arbitrary SVD sign/rotation flips before comparing windows.
    aligned = _procrustes_align_windows(raw_window_vecs)
    # Pivot into a per-MP, time-ordered view.
    per_mp: Dict[str, Dict] = {}
    for wid in window_ids:
        for mp_name, vec in aligned[wid].items():
            track = per_mp.setdefault(mp_name, {"windows": [], "vectors": []})
            track["windows"].append(wid)
            track["vectors"].append(vec)
    trajectories: Dict[str, Dict] = {}
    for mp_name, track in per_mp.items():
        if len(track["windows"]) < 2:
            continue  # need at least two windows to measure movement
        raw_vecs = track["vectors"]
        if normalize:
            unit = []
            for v in raw_vecs:
                mag = np.linalg.norm(v)
                # Guard against (near-)zero vectors: leave them unscaled.
                unit.append(v / mag if mag > 1e-10 else v)
        else:
            unit = raw_vecs
        # Drift = Euclidean distance between consecutive (normalised) positions.
        drifts = [
            float(np.linalg.norm(nxt - prev)) for prev, nxt in zip(unit, unit[1:])
        ]
        trajectories[mp_name] = {
            "windows": track["windows"],
            "vectors": [v.tolist() for v in raw_vecs],
            "drift": drifts,
            "total_drift": float(sum(drifts)),
        }
    _logger.info(
        "Trajectories computed for %d MPs across %d windows",
        len(trajectories),
        len(window_ids),
    )
    return trajectories
def compute_2d_trajectories(
    db_path: str,
    method: str = "pca",
    anchor_kwargs: Optional[Dict] = None,
    normalize_vectors: bool = True,
) -> Dict[str, Dict]:
    """Compute 2D trajectory positions for MPs using compute_2d_axes.

    Returns dict keyed by mp_name with:
        {
            'windows': [window_ids...],
            'coords': [[x,y], ...],
            'step_vectors': [[dx,dy], ...],
            'step_magnitudes': [float,...],
            'total_magnitude': float,
        }
    Only MPs present in >=2 windows are included.
    """
    from .political_axis import compute_2d_axes

    window_ids = _load_window_ids(db_path)
    if len(window_ids) < 2:
        _logger.info("Fewer than 2 windows — no 2D trajectories to compute")
        return {}
    positions_by_window, axes = compute_2d_axes(
        db_path,
        window_ids=window_ids,
        method=method,
        anchor_kwargs=anchor_kwargs,
        normalize_vectors=normalize_vectors,
    )
    # Pivot the per-window positions into per-MP, time-ordered tracks.
    tracks: Dict[str, Dict] = {}
    for wid in window_ids:
        for mp_name, coord in positions_by_window.get(wid, {}).items():
            rec = tracks.setdefault(mp_name, {"windows": [], "coords": []})
            rec["windows"].append(wid)
            rec["coords"].append(tuple(coord))
    out: Dict[str, Dict] = {}
    for mp_name, rec in tracks.items():
        if len(rec["windows"]) < 2:
            continue  # a trajectory needs at least two points
        pts = [np.array(c, dtype=float) for c in rec["coords"]]
        # Step vector between each consecutive pair of positions.
        deltas = [nxt - prev for prev, nxt in zip(pts, pts[1:])]
        lengths = [float(np.linalg.norm(d)) for d in deltas]
        out[mp_name] = {
            "windows": rec["windows"],
            "coords": [[float(p[0]), float(p[1])] for p in pts],
            "step_vectors": [[float(d[0]), float(d[1])] for d in deltas],
            "step_magnitudes": lengths,
            "total_magnitude": float(sum(lengths)),
        }
    _logger.info("2D trajectories computed for %d MPs", len(out))
    return out
def top_drifters(trajectories: Dict[str, Dict], n: int = 10) -> List[Dict]:
    """Return the top-n MPs by total drift, sorted descending.

    Each entry: {"mp_name": ..., "total_drift": ..., "windows": [...]}
    """
    # Stable sort of MP names by descending total drift.
    ordered = sorted(
        trajectories,
        key=lambda name: trajectories[name]["total_drift"],
        reverse=True,
    )
    leaders: List[Dict] = []
    for name in ordered[:n]:
        info = trajectories[name]
        leaders.append(
            {
                "mp_name": name,
                "total_drift": info["total_drift"],
                "windows": info["windows"],
            }
        )
    return leaders