You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
motief/analysis/right_wing/svd_trajectory_viz.py

365 lines
11 KiB

#!/usr/bin/env python3
"""Visualize SVD spatial drift over 10 annual windows.
Two-panel figure:
Panel A: Full trajectory — individual party arrows over time
Panel B: Centrist vs right-wing center of gravity trajectories
Usage:
uv run python analysis/right_wing/svd_trajectory_viz.py
"""
from __future__ import annotations
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
# Select the non-interactive Agg backend; the figure is only written to disk
# (see fig.savefig in main).
# NOTE(review): this call runs after `import matplotlib.pyplot as plt` above;
# the conventional order is matplotlib.use() before importing pyplot — confirm
# the backend really switches with the installed matplotlib version.
matplotlib.use("Agg")
from analysis.right_wing.common import ROOT, DB_PATH, REPORTS_DIR

# Prepend ROOT to sys.path (if absent) before the remaining `analysis.*` imports.
# NOTE(review): `analysis.right_wing.common` is imported *before* this guard, so
# the `analysis` package must already be importable at this point; the guard
# looks redundant — confirm before removing.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from analysis.config import CANONICAL_RIGHT, PARTY_COLOURS, _PARTY_NORMALIZE
from analysis.explorer_data import (
    get_uniform_dim_windows,
    load_party_scores_all_windows_aligned,
)

# Configures the root logger at import time (module-level call).
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("svd_trajectory_viz")

# Parties averaged into the centrist center of gravity (Panel B). Membership is
# tested with _party_in_set, which also tries the normalized name; both "CU" and
# "ChristenUnie" are listed so either spelling matches directly.
CANONICAL_CENTRIST = frozenset(
    {"VVD", "D66", "CDA", "NSC", "BBB", "CU", "ChristenUnie"}
)
# Destination of the saved two-panel figure.
OUTPUT_PATH = str(REPORTS_DIR / "svd_trajectory_figure.png")
# Party keys drawn individually in Panel A, in this order; each is looked up by
# exact key in the trajectories dict and skipped when absent.
# NOTE(review): the Panel B right-wing center is computed from CANONICAL_RIGHT
# (analysis.config), not RIGHT_DISPLAY, while its legend text hard-codes
# "PVV, FVD, JA21, SGP" — keep these in sync.
CENTRIST_DISPLAY = ["VVD", "D66", "CDA", "NSC", "BBB", "CU"]
RIGHT_DISPLAY = ["PVV", "FVD", "JA21", "SGP"]
def _normalize_party(raw: str) -> str:
    """Return the canonical spelling of a party label.

    Labels with no entry in ``_PARTY_NORMALIZE`` are passed through unchanged.
    """
    fallback = raw
    return _PARTY_NORMALIZE.get(raw, fallback)
def _party_in_set(party: str, canonical_set: frozenset) -> bool:
if party in canonical_set:
return True
normalized = _normalize_party(party)
return normalized != party and normalized in canonical_set
def _build_trajectories(
scores: Dict[str, List[List[float]]],
windows: List[str],
) -> Dict[str, Dict[str, List[float | None]]]:
"""Build per-party (x, y) lists aligned with windows.
Returns {party: {"x": [...], "y": [...], "windows": [...]}}
where each list has one entry per window (None if party missing).
"""
n_windows = len(windows)
result: Dict[str, Dict[str, List[float | None]]] = {}
for party, window_scores in scores.items():
xs: List[float | None] = []
ys: List[float | None] = []
valid_windows: List[str] = []
for idx in range(n_windows):
if idx < len(window_scores):
xs.append(window_scores[idx][0])
ys.append(window_scores[idx][1])
valid_windows.append(windows[idx])
else:
xs.append(None)
ys.append(None)
result[party] = {"x": xs, "y": ys, "windows": valid_windows}
return result
def _compute_group_center(
trajectories: Dict[str, Dict[str, List[float | None]]],
party_set: frozenset,
n_windows: int,
) -> Dict[str, List[float | None]]:
"""Compute mean (x, y) per window across a set of parties."""
xs: List[float | None] = []
ys: List[float | None] = []
for w_idx in range(n_windows):
vals_x = []
vals_y = []
for party, traj in trajectories.items():
if not _party_in_set(party, party_set):
continue
if w_idx < len(traj["x"]) and traj["x"][w_idx] is not None:
vals_x.append(traj["x"][w_idx])
vals_y.append(traj["y"][w_idx])
if vals_x:
xs.append(float(np.mean(vals_x)))
ys.append(float(np.mean(vals_y)))
else:
xs.append(None)
ys.append(None)
return {"x": xs, "y": ys}
def _plot_party_trajectory(
ax: plt.Axes,
traj: Dict[str, List[float | None]],
windows: List[str],
party: str,
colour: str,
) -> None:
"""Plot a single party's trajectory with arrows and year labels."""
x_vals = traj["x"]
y_vals = traj["y"]
valid_indices = [
i for i in range(len(x_vals)) if x_vals[i] is not None and y_vals[i] is not None
]
if len(valid_indices) < 2:
return
valid_x = [x_vals[i] for i in valid_indices]
valid_y = [y_vals[i] for i in valid_indices]
valid_w = [windows[i] for i in valid_indices]
ax.plot(valid_x, valid_y, "-", color=colour, linewidth=1.2, alpha=0.5, zorder=1)
for i in range(len(valid_x) - 1):
ax.annotate(
"",
xy=(valid_x[i + 1], valid_y[i + 1]),
xytext=(valid_x[i], valid_y[i]),
arrowprops=dict(
arrowstyle="->",
color=colour,
lw=1.0,
alpha=0.5,
shrinkA=4,
shrinkB=4,
),
zorder=2,
)
ax.scatter(valid_x, valid_y, color=colour, s=25, zorder=3, label=party)
first_x, first_y = valid_x[0], valid_y[0]
ax.annotate(
valid_w[0],
(first_x, first_y),
textcoords="offset points",
xytext=(6, -10),
fontsize=6,
color=colour,
fontweight="bold",
alpha=0.8,
)
last_x, last_y = valid_x[-1], valid_y[-1]
ax.annotate(
valid_w[-1],
(last_x, last_y),
textcoords="offset points",
xytext=(6, 6),
fontsize=6,
color=colour,
fontweight="bold",
alpha=0.8,
)
def _valid_indices(center: Dict[str, List[float | None]]) -> List[int]:
    """Return the window indices where *center* has both an x and a y value."""
    return [
        idx
        for idx, (x, y) in enumerate(zip(center["x"], center["y"]))
        if x is not None and y is not None
    ]


def _plot_group_center(
    ax: plt.Axes,
    center: Dict[str, List[float | None]],
    valid_idx: List[int],
    windows: List[str],
    *,
    colour: str,
    fmt: str,
    legend_label: str,
    line_kwargs: Dict[str, float],
    arrow_lw: float,
    arrow_alpha: float,
    label_offset: tuple[int, int],
) -> None:
    """Draw one group's center-of-gravity path: line with markers, step arrows, window labels.

    Only the windows listed in *valid_idx* are drawn; nothing is drawn when it
    is empty.
    """
    if not valid_idx:
        return
    xs = [center["x"][i] for i in valid_idx]
    ys = [center["y"][i] for i in valid_idx]
    ax.plot(xs, ys, fmt, color=colour, label=legend_label, zorder=3, **line_kwargs)
    # One arrow per consecutive pair of positioned windows shows the drift direction.
    for i in range(len(xs) - 1):
        ax.annotate(
            "",
            xy=(xs[i + 1], ys[i + 1]),
            xytext=(xs[i], ys[i]),
            arrowprops=dict(
                arrowstyle="->", color=colour, lw=arrow_lw, alpha=arrow_alpha,
            ),
            zorder=2,
        )
    for x, y, idx in zip(xs, ys, valid_idx):
        ax.annotate(
            str(windows[idx]),
            (x, y),
            textcoords="offset points",
            xytext=label_offset,
            fontsize=7,
            color=colour,
            fontweight="bold",
        )


def _style_panel(ax: plt.Axes, title: str) -> None:
    """Apply the origin lines, axis labels, title, aspect, grid and legend shared by both panels."""
    ax.axhline(0, color="#CCCCCC", linewidth=0.5, linestyle="-")
    ax.axvline(0, color="#CCCCCC", linewidth=0.5, linestyle="-")
    ax.set_xlabel("PCA Axis 1 (Procrustes-aligned)")
    ax.set_ylabel("PCA Axis 2 (Procrustes-aligned)")
    ax.set_title(title, fontsize=11)
    ax.set_aspect("equal", adjustable="datalim")
    ax.grid(True, alpha=0.2)
    ax.legend(loc="upper left", fontsize=7, framealpha=0.85)


def _log_group_drift(
    group: str,
    center: Dict[str, List[float | None]],
    valid_idx: List[int],
) -> None:
    """Log a group center's displacement from its first to its last valid window.

    Logs nothing when *valid_idx* is empty.
    """
    if not valid_idx:
        return
    first, last = valid_idx[0], valid_idx[-1]
    dx = center["x"][last] - center["x"][first]
    dy = center["y"][last] - center["y"][first]
    logger.info(
        "%s center drift: dx=%.4f dy=%.4f net=%.4f",
        group, dx, dy, float(np.hypot(dx, dy)),
    )


def main() -> None:
    """Load aligned party positions, draw the two-panel figure and log group drift.

    Panel A shows individual party trajectories for CENTRIST_DISPLAY and
    RIGHT_DISPLAY. Panel B shows the centrist and right-wing centers of gravity.
    The figure is written to OUTPUT_PATH. Returns early, logging an error, when
    no uniform-dim windows or no aligned scores are available.
    """
    os.makedirs(str(REPORTS_DIR), exist_ok=True)
    logger.info("Loading aligned party positions...")
    windows = get_uniform_dim_windows(DB_PATH)
    if not windows:
        logger.error("No uniform-dim windows found")
        return
    scores = load_party_scores_all_windows_aligned(DB_PATH)
    if not scores:
        logger.error("No aligned party scores loaded")
        return
    logger.info("Windows: %s", windows)
    logger.info("Parties: %s", sorted(scores.keys()))

    trajectories = _build_trajectories(scores, windows)
    n_windows = len(windows)
    centrist_center = _compute_group_center(trajectories, CANONICAL_CENTRIST, n_windows)
    right_center = _compute_group_center(trajectories, CANONICAL_RIGHT, n_windows)
    cent_valid_idx = _valid_indices(centrist_center)
    right_valid_idx = _valid_indices(right_center)

    fig, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(18, 8))
    try:
        # ── Panel A: Full individual party trajectories ──────────────────
        for party in (*CENTRIST_DISPLAY, *RIGHT_DISPLAY):
            if party not in trajectories:
                continue
            colour = PARTY_COLOURS.get(party, "#888888")
            _plot_party_trajectory(ax_a, trajectories[party], windows, party, colour)
        _style_panel(ax_a, "Panel A: Party Trajectories (All Windows)")

        # ── Panel B: Centrist vs right-wing center of gravity ────────────
        _plot_group_center(
            ax_b, centrist_center, cent_valid_idx, windows,
            colour="#1E73BE",
            fmt="o-",
            legend_label="Centrist center (VVD, D66, CDA, NSC, BBB, CU)",
            line_kwargs={"linewidth": 2, "markersize": 7},
            arrow_lw=1.5,
            arrow_alpha=0.6,
            label_offset=(6, 6),
        )
        _plot_group_center(
            ax_b, right_center, right_valid_idx, windows,
            colour="#6A1B9A",
            fmt="s--",
            legend_label="Right-wing center (PVV, FVD, JA21, SGP)",
            line_kwargs={"linewidth": 1.5, "markersize": 6, "alpha": 0.8},
            arrow_lw=1.2,
            arrow_alpha=0.5,
            label_offset=(6, -10),
        )
        _style_panel(ax_b, "Panel B: Group Center of Gravity Trajectories")

        fig.suptitle(
            "SVD Spatial Drift: 10-Year Parliamentary Party Trajectories",
            fontsize=13,
            fontweight="bold",
        )
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches="tight", facecolor="white")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    logger.info("Figure saved to %s", OUTPUT_PATH)

    _log_group_drift("Centrist", centrist_center, cent_valid_idx)
    _log_group_drift("Right-wing", right_center, right_valid_idx)
# Script entry point: build the figure only when executed directly, not on import.
if __name__ == "__main__":
    main()