#!/usr/bin/env python3 """Visualize SVD spatial drift over 10 annual windows. Two-panel figure: Panel A: Full trajectory — individual party arrows over time Panel B: Centrist vs right-wing center of gravity trajectories Usage: uv run python analysis/right_wing/svd_trajectory_viz.py """ from __future__ import annotations import logging import os import sys from pathlib import Path from typing import Dict, List import matplotlib import matplotlib.pyplot as plt import numpy as np matplotlib.use("Agg") from analysis.right_wing.common import ROOT, DB_PATH, REPORTS_DIR if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from analysis.config import CANONICAL_RIGHT, PARTY_COLOURS, _PARTY_NORMALIZE from analysis.explorer_data import ( get_uniform_dim_windows, load_party_scores_all_windows_aligned, ) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("svd_trajectory_viz") CANONICAL_CENTRIST = frozenset( {"VVD", "D66", "CDA", "NSC", "BBB", "CU", "ChristenUnie"} ) OUTPUT_PATH = str(REPORTS_DIR / "svd_trajectory_figure.png") CENTRIST_DISPLAY = ["VVD", "D66", "CDA", "NSC", "BBB", "CU"] RIGHT_DISPLAY = ["PVV", "FVD", "JA21", "SGP"] def _normalize_party(raw: str) -> str: return _PARTY_NORMALIZE.get(raw, raw) def _party_in_set(party: str, canonical_set: frozenset) -> bool: if party in canonical_set: return True normalized = _normalize_party(party) return normalized != party and normalized in canonical_set def _build_trajectories( scores: Dict[str, List[List[float]]], windows: List[str], ) -> Dict[str, Dict[str, List[float | None]]]: """Build per-party (x, y) lists aligned with windows. Returns {party: {"x": [...], "y": [...], "windows": [...]}} where each list has one entry per window (None if party missing). """ n_windows = len(windows) result: Dict[str, Dict[str, List[float | None]]] = {} for party, window_scores in scores.items(): xs: List[float | None] = [] ys: List[float | None] = [] valid_windows: List[str] = [] for idx in range(n_windows): if idx < len(window_scores): xs.append(window_scores[idx][0]) ys.append(window_scores[idx][1]) valid_windows.append(windows[idx]) else: xs.append(None) ys.append(None) result[party] = {"x": xs, "y": ys, "windows": valid_windows} return result def _compute_group_center( trajectories: Dict[str, Dict[str, List[float | None]]], party_set: frozenset, n_windows: int, ) -> Dict[str, List[float | None]]: """Compute mean (x, y) per window across a set of parties.""" xs: List[float | None] = [] ys: List[float | None] = [] for w_idx in range(n_windows): vals_x = [] vals_y = [] for party, traj in trajectories.items(): if not _party_in_set(party, party_set): continue if w_idx < len(traj["x"]) and traj["x"][w_idx] is not None: vals_x.append(traj["x"][w_idx]) vals_y.append(traj["y"][w_idx]) if vals_x: xs.append(float(np.mean(vals_x))) ys.append(float(np.mean(vals_y))) else: xs.append(None) ys.append(None) return {"x": xs, "y": ys} def _plot_party_trajectory( ax: plt.Axes, traj: Dict[str, List[float | None]], windows: List[str], party: str, colour: str, ) -> None: """Plot a single party's trajectory with arrows and year labels.""" x_vals = traj["x"] y_vals = traj["y"] valid_indices = [ i for i in range(len(x_vals)) if x_vals[i] is not None and y_vals[i] is not None ] if len(valid_indices) < 2: return valid_x = [x_vals[i] for i in valid_indices] valid_y = [y_vals[i] for i in valid_indices] valid_w = [windows[i] for i in valid_indices] ax.plot(valid_x, valid_y, "-", color=colour, linewidth=1.2, alpha=0.5, zorder=1) for i in range(len(valid_x) - 1): ax.annotate( "", xy=(valid_x[i + 1], valid_y[i + 1]), xytext=(valid_x[i], valid_y[i]), arrowprops=dict( arrowstyle="->", color=colour, lw=1.0, alpha=0.5, shrinkA=4, shrinkB=4, ), zorder=2, ) ax.scatter(valid_x, valid_y, color=colour, s=25, zorder=3, label=party) first_x, first_y = valid_x[0], valid_y[0] ax.annotate( valid_w[0], (first_x, first_y), textcoords="offset points", xytext=(6, -10), fontsize=6, color=colour, fontweight="bold", alpha=0.8, ) last_x, last_y = valid_x[-1], valid_y[-1] ax.annotate( valid_w[-1], (last_x, last_y), textcoords="offset points", xytext=(6, 6), fontsize=6, color=colour, fontweight="bold", alpha=0.8, ) def main() -> None: os.makedirs(str(REPORTS_DIR), exist_ok=True) logger.info("Loading aligned party positions...") windows = get_uniform_dim_windows(DB_PATH) if not windows: logger.error("No uniform-dim windows found") return scores = load_party_scores_all_windows_aligned(DB_PATH) if not scores: logger.error("No aligned party scores loaded") return logger.info("Windows: %s", windows) logger.info("Parties: %s", sorted(scores.keys())) trajectories = _build_trajectories(scores, windows) n_windows = len(windows) centrist_center = _compute_group_center( trajectories, CANONICAL_CENTRIST, n_windows ) right_center = _compute_group_center( trajectories, CANONICAL_RIGHT, n_windows ) fig, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(18, 8)) # ── Panel A: Full individual party trajectories ────────────────────── for party in CENTRIST_DISPLAY: if party not in trajectories: continue colour = PARTY_COLOURS.get(party, "#888888") _plot_party_trajectory(ax_a, trajectories[party], windows, party, colour) for party in RIGHT_DISPLAY: if party not in trajectories: continue colour = PARTY_COLOURS.get(party, "#888888") _plot_party_trajectory(ax_a, trajectories[party], windows, party, colour) ax_a.axhline(0, color="#CCCCCC", linewidth=0.5, linestyle="-") ax_a.axvline(0, color="#CCCCCC", linewidth=0.5, linestyle="-") ax_a.set_xlabel("PCA Axis 1 (Procrustes-aligned)") ax_a.set_ylabel("PCA Axis 2 (Procrustes-aligned)") ax_a.set_title("Panel A: Party Trajectories (All Windows)", fontsize=11) ax_a.set_aspect("equal", adjustable="datalim") ax_a.grid(True, alpha=0.2) ax_a.legend(loc="upper left", fontsize=7, framealpha=0.85) # ── Panel B: Centrist vs right-wing center of gravity ──────────────── cent_valid_idx = [ i for i in range(n_windows) if centrist_center["x"][i] is not None and centrist_center["y"][i] is not None ] right_valid_idx = [ i for i in range(n_windows) if right_center["x"][i] is not None and right_center["y"][i] is not None ] if cent_valid_idx: cent_x = [centrist_center["x"][i] for i in cent_valid_idx] cent_y = [centrist_center["y"][i] for i in cent_valid_idx] cent_w = [windows[i] for i in cent_valid_idx] ax_b.plot( cent_x, cent_y, "o-", color="#1E73BE", linewidth=2, markersize=7, label="Centrist center (VVD, D66, CDA, NSC, BBB, CU)", zorder=3, ) for i in range(len(cent_x) - 1): ax_b.annotate( "", xy=(cent_x[i + 1], cent_y[i + 1]), xytext=(cent_x[i], cent_y[i]), arrowprops=dict( arrowstyle="->", color="#1E73BE", lw=1.5, alpha=0.6, ), zorder=2, ) for i, label in enumerate(cent_w): ax_b.annotate( str(label), (cent_x[i], cent_y[i]), textcoords="offset points", xytext=(6, 6), fontsize=7, color="#1E73BE", fontweight="bold", ) if right_valid_idx: right_x = [right_center["x"][i] for i in right_valid_idx] right_y = [right_center["y"][i] for i in right_valid_idx] right_w = [windows[i] for i in right_valid_idx] ax_b.plot( right_x, right_y, "s--", color="#6A1B9A", linewidth=1.5, markersize=6, alpha=0.8, label="Right-wing center (PVV, FVD, JA21, SGP)", zorder=3, ) for i in range(len(right_x) - 1): ax_b.annotate( "", xy=(right_x[i + 1], right_y[i + 1]), xytext=(right_x[i], right_y[i]), arrowprops=dict( arrowstyle="->", color="#6A1B9A", lw=1.2, alpha=0.5, ), zorder=2, ) for i, label in enumerate(right_w): ax_b.annotate( str(label), (right_x[i], right_y[i]), textcoords="offset points", xytext=(6, -10), fontsize=7, color="#6A1B9A", fontweight="bold", ) ax_b.axhline(0, color="#CCCCCC", linewidth=0.5, linestyle="-") ax_b.axvline(0, color="#CCCCCC", linewidth=0.5, linestyle="-") ax_b.set_xlabel("PCA Axis 1 (Procrustes-aligned)") ax_b.set_ylabel("PCA Axis 2 (Procrustes-aligned)") ax_b.set_title("Panel B: Group Center of Gravity Trajectories", fontsize=11) ax_b.set_aspect("equal", adjustable="datalim") ax_b.grid(True, alpha=0.2) ax_b.legend(loc="upper left", fontsize=7, framealpha=0.85) fig.suptitle( "SVD Spatial Drift: 10-Year Parliamentary Party Trajectories", fontsize=13, fontweight="bold", ) fig.tight_layout(rect=[0, 0, 1, 0.96]) fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches="tight", facecolor="white") plt.close(fig) logger.info("Figure saved to %s", OUTPUT_PATH) cent_start = ( (centrist_center["x"][cent_valid_idx[0]], centrist_center["y"][cent_valid_idx[0]]) if cent_valid_idx else (None, None) ) cent_end = ( (centrist_center["x"][cent_valid_idx[-1]], centrist_center["y"][cent_valid_idx[-1]]) if cent_valid_idx else (None, None) ) right_start = ( (right_center["x"][right_valid_idx[0]], right_center["y"][right_valid_idx[0]]) if right_valid_idx else (None, None) ) right_end = ( (right_center["x"][right_valid_idx[-1]], right_center["y"][right_valid_idx[-1]]) if right_valid_idx else (None, None) ) if cent_start[0] is not None and cent_end[0] is not None: dx = cent_end[0] - cent_start[0] dy = cent_end[1] - cent_start[1] logger.info( "Centrist center drift: dx=%.4f dy=%.4f net=%.4f", dx, dy, float(np.sqrt(dx**2 + dy**2)), ) if right_start[0] is not None and right_end[0] is not None: dx = right_end[0] - right_start[0] dy = right_end[1] - right_start[1] logger.info( "Right-wing center drift: dx=%.4f dy=%.4f net=%.4f", dx, dy, float(np.sqrt(dx**2 + dy**2)), ) if __name__ == "__main__": main()