motief/analysis/trajectory.py
"""trajectory.py — Compute MP political drift across aligned time windows.
For each MP that appears in multiple windows, computes:
- The aligned SVD vector per window
- The Euclidean distance between consecutive windows (drift)
- Total cumulative drift
Returns a dict keyed by mp_name containing per-window positions and drift scores.
"""
import json
import logging
import re
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
try:
import duckdb
except (
Exception
): # pragma: no cover - import-time guard for environments without duckdb
duckdb = None # type: ignore
try:
from scipy.linalg import orthogonal_procrustes as _scipy_procrustes
_HAS_SCIPY = True
except ImportError:
_scipy_procrustes = None # type: ignore[assignment]
_HAS_SCIPY = False
_logger = logging.getLogger(__name__)
__all__ = [
"compute_trajectories",
"compute_2d_trajectories",
"top_drifters",
"compute_party_discipline",
"window_to_dates",
"choose_trajectory_title",
]
def _procrustes_align_windows(
window_vecs: Dict[str, Dict[str, np.ndarray]],
min_overlap: int = 5,
) -> Dict[str, Dict[str, np.ndarray]]:
"""Align SVD vectors across windows using Procrustes rotations.
Takes the first window as reference and aligns each subsequent window
to it via orthogonal Procrustes on the set of common entities.
Args:
window_vecs: {window_id: {entity_id: vector}}
min_overlap: minimum number of common entities needed for alignment
Returns same structure with rotated vectors for windows 1..N.
"""
if not _HAS_SCIPY:
_logger.debug("scipy not available, skipping Procrustes alignment")
return window_vecs
window_ids = list(window_vecs.keys())
if len(window_ids) < 2:
return window_vecs
result = {window_ids[0]: window_vecs[window_ids[0]]}
# Accumulate the aligned reference — each window aligns to the *previous aligned* window
prev_aligned = window_vecs[window_ids[0]]
for wid in window_ids[1:]:
cur = window_vecs[wid]
# Only consider common entities whose vectors share the same dimensionality
common = [
e
for e in cur
if e in prev_aligned and cur[e].shape == prev_aligned[e].shape
]
# If there are common entities but their vector dimensions differ between
# the current and previously aligned window, skip Procrustes alignment
# for this window rather than raising an exception in orthogonal_procrustes.
if any(
e
for e in cur
if e in prev_aligned and cur[e].shape != prev_aligned[e].shape
):
_logger.debug(
"Procrustes skipped for %s: vector dimensionality mismatch between windows",
wid,
)
result[wid] = cur
prev_aligned = cur
continue
if len(common) < min_overlap:
_logger.debug(
"Procrustes skipped for %s: only %d common entities (need %d)",
wid,
len(common),
min_overlap,
)
result[wid] = cur
prev_aligned = cur
continue
ref_mat = np.vstack([prev_aligned[e] for e in common])
cur_mat = np.vstack([cur[e] for e in common])
try:
assert _scipy_procrustes is not None
R, _ = _scipy_procrustes(cur_mat, ref_mat)
aligned = {e: v.dot(R) for e, v in cur.items()}
except Exception:
_logger.exception("Procrustes failed for window %s", wid)
aligned = cur
result[wid] = aligned
prev_aligned = aligned
return result
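# Illustrative sketch (hypothetical data): a pure sign flip between two windows
# is an orthogonal transform, so the alignment above should remove it and the
# apparent drift collapses to ~0 (requires scipy; 10 shared MPs > min_overlap).
#
#     rng = np.random.default_rng(0)
#     w1 = {f"mp{i}": rng.normal(size=8) for i in range(10)}
#     w2 = {name: -vec for name, vec in w1.items()}  # sign-flipped copy
#     aligned = _procrustes_align_windows({"2023": w1, "2024": w2})
#     # np.allclose(aligned["2024"]["mp0"], w1["mp0"])  -> True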
def _load_window_ids(db_path: str) -> List[str]:
"""Return all distinct window IDs from svd_vectors, in lexicographic order."""
conn = duckdb.connect(db_path, read_only=True)
rows = conn.execute(
"SELECT DISTINCT window_id FROM svd_vectors WHERE entity_type = 'mp' ORDER BY window_id"
).fetchall()
conn.close()
return [r[0] for r in rows]
def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
conn = duckdb.connect(db_path, read_only=True)
rows = conn.execute(
"SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
(window_id,),
).fetchall()
conn.close()
result = {}
for mp_name, vec_json in rows:
try:
result[mp_name] = np.array(json.loads(vec_json), dtype=float)
except Exception:
_logger.warning(
"Could not parse vector for MP %s window %s", mp_name, window_id
)
return result
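# The two loaders above assume a DuckDB table shaped roughly like this (a sketch;
# column types are inferred from the queries, the real schema lives elsewhere):
#
#     CREATE TABLE svd_vectors (
#         window_id   VARCHAR,  -- e.g. '2024', '2024-Q1', 'current_parliament'
#         entity_type VARCHAR,  -- only 'mp' rows are read here
#         entity_id   VARCHAR,  -- MP name, used as the trajectory key
#         vector      VARCHAR   -- JSON-encoded list of floats
#     );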
def compute_trajectories(
db_path: str,
window_ids: Optional[List[str]] = None,
normalize: bool = True,
) -> Dict[str, Dict]:
"""Compute per-MP trajectories across windows.
Args:
db_path: Path to DuckDB database.
window_ids: Subset of window IDs to use (default: all, ordered).
normalize: If True (default), L2-normalise each vector before computing
drift so that cross-window magnitude differences (caused by
different numbers of motions per window) don't inflate drift.
Returns:
{
mp_name: {
"windows": [window_id, ...],
"vectors": [[...], ...], # one vector per window (raw, not normalised)
"drift": [float, ...], # consecutive Euclidean distances on unit sphere
"total_drift": float,
}
}
Only MPs present in at least 2 windows are included.
"""
if window_ids is None:
window_ids = _load_window_ids(db_path)
if len(window_ids) < 2:
_logger.info("Fewer than 2 windows — no trajectories to compute")
return {}
# Collect per-window vectors keyed as {window_id: {entity_id: vector}}
raw_window_vecs: Dict[str, Dict[str, np.ndarray]] = {}
for wid in window_ids:
raw_window_vecs[wid] = _load_mp_vectors_for_window(db_path, wid)
# Align windows via Procrustes to remove arbitrary SVD sign/rotation flips
aligned_window_vecs = _procrustes_align_windows(raw_window_vecs)
# Reshape into per-MP view
mp_data: Dict[str, Dict] = {}
for wid in window_ids:
for mp_name, vec in aligned_window_vecs[wid].items():
if mp_name not in mp_data:
mp_data[mp_name] = {"windows": [], "vectors": []}
mp_data[mp_name]["windows"].append(wid)
mp_data[mp_name]["vectors"].append(vec)
# Compute drift for MPs with >= 2 windows
result = {}
for mp_name, data in mp_data.items():
if len(data["windows"]) < 2:
continue
vecs = data["vectors"]
if normalize:
normed = []
for v in vecs:
n = np.linalg.norm(v)
normed.append(v / n if n > 1e-10 else v)
else:
normed = vecs
drifts = [
float(np.linalg.norm(normed[i + 1] - normed[i]))
for i in range(len(normed) - 1)
]
result[mp_name] = {
"windows": data["windows"],
"vectors": [v.tolist() for v in vecs],
"drift": drifts,
"total_drift": float(sum(drifts)),
}
_logger.info(
"Trajectories computed for %d MPs across %d windows",
len(result),
len(window_ids),
)
return result
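# Usage sketch (the database path is hypothetical; any DuckDB file produced by
# the pipeline with an svd_vectors table should work):
#
#     trajectories = compute_trajectories("data/motief.duckdb")
#     for mp, data in trajectories.items():
#         print(mp, data["windows"], round(data["total_drift"], 3))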
def compute_2d_trajectories(
db_path: str,
method: str = "pca",
anchor_kwargs: Optional[Dict] = None,
normalize_vectors: bool = True,
) -> Dict[str, Dict]:
"""Compute 2D trajectory positions for MPs using compute_2d_axes.
Returns dict keyed by mp_name with:
{
'windows': [window_ids...],
'coords': [[x,y], ...],
'step_vectors': [[dx,dy], ...],
'step_magnitudes': [float,...],
'total_magnitude': float,
}
Only MPs present in >=2 windows are included.
"""
from .political_axis import compute_2d_axes
window_ids = _load_window_ids(db_path)
if len(window_ids) < 2:
_logger.info("Fewer than 2 windows — no 2D trajectories to compute")
return {}
positions_by_window, axes = compute_2d_axes(
db_path,
window_ids=window_ids,
method=method,
anchor_kwargs=anchor_kwargs,
normalize_vectors=normalize_vectors,
)
# Build per-MP time-ordered coords
mp_data: Dict[str, Dict] = {}
for wid in window_ids:
pos = positions_by_window.get(wid, {})
for mp_name, coord in pos.items():
if mp_name not in mp_data:
mp_data[mp_name] = {"windows": [], "coords": []}
mp_data[mp_name]["windows"].append(wid)
mp_data[mp_name]["coords"].append(tuple(coord))
result: Dict[str, Dict] = {}
for mp_name, data in mp_data.items():
if len(data["windows"]) < 2:
continue
coords = [np.array(c, dtype=float) for c in data["coords"]]
step_vecs = [coords[i + 1] - coords[i] for i in range(len(coords) - 1)]
mags = [float(np.linalg.norm(v)) for v in step_vecs]
result[mp_name] = {
"windows": data["windows"],
"coords": [[float(c[0]), float(c[1])] for c in coords],
"step_vectors": [[float(v[0]), float(v[1])] for v in step_vecs],
"step_magnitudes": mags,
"total_magnitude": float(sum(mags)),
}
_logger.info("2D trajectories computed for %d MPs", len(result))
return result
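# Usage sketch (hypothetical path and MP key; method/anchor_kwargs are passed
# through to compute_2d_axes unchanged):
#
#     t2d = compute_2d_trajectories("data/motief.duckdb", method="pca")
#     xs, ys = zip(*t2d["Jansen, P."]["coords"])  # plot as a path over windows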
def top_drifters(trajectories: Dict[str, Dict], n: int = 10) -> List[Dict]:
"""Return the top-n MPs by total drift, sorted descending.
Each entry: {"mp_name": ..., "total_drift": ..., "windows": [...]}
"""
ranked = sorted(
trajectories.items(), key=lambda kv: kv[1]["total_drift"], reverse=True
)
return [
{
"mp_name": mp,
"total_drift": data["total_drift"],
"windows": data["windows"],
}
for mp, data in ranked[:n]
]
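# Example (continues the compute_trajectories sketch above):
#
#     movers = top_drifters(trajectories, n=5)
#     # -> [{"mp_name": ..., "total_drift": ..., "windows": [...]}, ...]
#     #    sorted by total_drift, largest first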
def compute_party_discipline(
db_path: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""Compute per-party voting discipline (Rice index) for roll-call votes in a date range.
Only individual MP vote rows are used (mp_name LIKE '%,%').
Returns a DataFrame with columns [party, n_motions, discipline] sorted by discipline ascending.
    Returns an empty DataFrame if there are no qualifying motions or on any DB error.

    The per-motion, per-party score is the fraction of the party's MPs voting with
    the party majority (a majority-agreement variant of the classic Rice index,
    equal to (1 + Rice) / 2). The per-party discipline is the average of this score
    across all motions in the date range. Only 'voor' (for) and 'tegen' (against)
    votes are counted; absent and abstaining MPs are excluded.
"""
conn = None
try:
conn = duckdb.connect(db_path, read_only=True)
result = conn.execute(
"""
WITH individual_votes AS (
SELECT
motion_id,
party,
LOWER(vote) AS vote
FROM mp_votes
WHERE mp_name LIKE '%,%'
AND date >= CAST(? AS DATE)
AND date <= CAST(? AS DATE)
AND vote IN ('voor', 'tegen')
),
vote_counts AS (
SELECT
motion_id,
party,
vote,
COUNT(*) AS cnt
FROM individual_votes
GROUP BY motion_id, party, vote
),
majority_vote AS (
SELECT
motion_id,
party,
FIRST(vote ORDER BY cnt DESC, vote ASC) AS maj_vote,
SUM(cnt) AS total_mp_votes
FROM vote_counts
GROUP BY motion_id, party
),
rice_per_motion AS (
SELECT
mv.motion_id,
mv.party,
SUM(CASE WHEN vc.vote = mv.maj_vote THEN vc.cnt ELSE 0 END)
* 1.0 / mv.total_mp_votes AS rice
FROM majority_vote mv
JOIN vote_counts vc
ON mv.motion_id = vc.motion_id AND mv.party = vc.party
GROUP BY mv.motion_id, mv.party, mv.total_mp_votes
)
SELECT
party,
COUNT(DISTINCT motion_id) AS n_motions,
AVG(rice) AS discipline
FROM rice_per_motion
GROUP BY party
ORDER BY discipline ASC
""",
[start_date, end_date],
).fetchdf()
return result
except Exception as exc:
_logger.warning("compute_party_discipline failed: %s", exc)
return pd.DataFrame(columns=["party", "n_motions", "discipline"])
finally:
if conn is not None:
try:
conn.close()
except Exception:
pass
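# Worked example of the per-motion score defined above: if a party casts 12 'voor'
# and 3 'tegen' votes on one motion, the majority vote is 'voor' and the motion
# scores 12 / 15 = 0.8 for that party; the discipline column averages these scores
# over all motions in the range. Usage sketch (hypothetical path):
#
#     df = compute_party_discipline("data/motief.duckdb", "2024-01-01", "2024-12-31")
#     # -> DataFrame[party, n_motions, discipline], least disciplined party first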
def window_to_dates(window_id: str) -> Tuple[str, str]:
"""Return (start_date, end_date) ISO strings for a given window_id.
    Annual windows like '2024' → ('2024-01-01', '2024-12-31').
    Quarterly windows like '2024-Q2' → ('2024-04-01', '2024-06-30').
    'current_parliament' → ('2023-11-22', '2099-12-31') (2023 election date, open end).
    Unknown formats → ('2000-01-01', '2099-12-31') (effectively all time).
"""
if window_id == "current_parliament":
return ("2023-11-22", "2099-12-31")
if re.fullmatch(r"\d{4}", window_id):
return (f"{window_id}-01-01", f"{window_id}-12-31")
m = re.fullmatch(r"(\d{4})-Q([1-4])", window_id)
if m:
year, q = int(m.group(1)), int(m.group(2))
starts = {1: "01-01", 2: "04-01", 3: "07-01", 4: "10-01"}
ends = {1: "03-31", 2: "06-30", 3: "09-30", 4: "12-31"}
return (f"{year}-{starts[q]}", f"{year}-{ends[q]}")
return ("2000-01-01", "2099-12-31")
def choose_trajectory_title(axis_def: dict, axis: str, threshold: float = 0.65) -> str:
"""Choose a short trajectory axis title based on aggregated confidence.
    axis: 'x' or 'y'. Returns the axis_def label when its mean confidence is
    >= threshold; otherwise falls back to a compact label via
    analysis.axis_classifier.display_label_for_modal, or to 'As 1' / 'As 2'
    if that import fails.
    """
conf_map = axis_def.get(f"{axis}_label_confidence", {}) or {}
vals = [v for v in conf_map.values() if v is not None]
mean = float(sum(vals) / len(vals)) if vals else None
label = axis_def.get(f"{axis}_label")
if mean is not None and mean >= threshold and label:
return label
try:
from analysis.axis_classifier import display_label_for_modal
fallback_modal = "As 1" if axis == "x" else "As 2"
return display_label_for_modal(fallback_modal, axis)
except Exception:
return "As 1" if axis == "x" else "As 2"