You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
motief/analysis/right_wing/svd_trajectory_viz.py

366 lines
12 KiB

#!/usr/bin/env python3
"""Visualize SVD spatial drift over 10 annual windows.
Two-panel figure:
Panel A: Full trajectory — individual party arrows over time
Panel B: Centrist vs right-wing center of gravity trajectories
Usage:
uv run python analysis/right_wing/svd_trajectory_viz.py
"""
from __future__ import annotations
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
matplotlib.use("Agg")
ROOT = Path(__file__).parent.parent.parent.resolve()
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from analysis.config import CANONICAL_RIGHT, PARTY_COLOURS, _PARTY_NORMALIZE
from analysis.explorer_data import (
get_uniform_dim_windows,
load_party_scores_all_windows_aligned,
)
# Configure root logging once at import time; every message in this script
# goes through the named logger below.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("svd_trajectory_viz")
# Party names counted towards the centrist centre of gravity (passed to
# _compute_group_center in main). Both "CU" and "ChristenUnie" are listed —
# presumably two spellings of the same party; confirm against _PARTY_NORMALIZE.
CANONICAL_CENTRIST = frozenset(
{"VVD", "D66", "CDA", "NSC", "BBB", "CU", "ChristenUnie"}
)
# Input database and output locations, all resolved relative to ROOT.
DB_PATH = str(ROOT / "data" / "motions.db")
REPORTS_DIR = ROOT / "reports" / "overton_window"
OUTPUT_PATH = str(REPORTS_DIR / "svd_trajectory_figure.png")
# Parties drawn individually in Panel A, in plotting order (centrist list
# first, then the right-wing list).
CENTRIST_DISPLAY = ["VVD", "D66", "CDA", "NSC", "BBB", "CU"]
RIGHT_DISPLAY = ["PVV", "FVD", "JA21", "SGP"]
def _normalize_party(raw: str) -> str:
    """Map a party name through ``_PARTY_NORMALIZE``.

    Names that have no entry in the mapping are returned unchanged.
    """
    mapped = _PARTY_NORMALIZE.get(raw, raw)
    return mapped
def _party_in_set(party: str, canonical_set: frozenset) -> bool:
    """Tell whether ``party`` belongs to ``canonical_set``.

    The name matches either as given or after being passed through
    ``_normalize_party``. The normalised lookup only runs when the direct
    lookup fails.
    """
    # If the direct lookup failed, a normalised name equal to the raw name
    # cannot be in the set either, so no separate "did it change" test is
    # needed.
    return party in canonical_set or _normalize_party(party) in canonical_set
def _build_trajectories(
    scores: Dict[str, List[List[float]]],
    windows: List[str],
) -> Dict[str, Dict[str, List[float | None]]]:
    """Build per-party coordinate lists aligned with ``windows``.

    Args:
        scores: Maps a party name to its per-window score vectors. The entry
            at index ``i`` is paired with ``windows[i]``; only the first two
            components of each vector are used.
        windows: Window labels, in plotting order.

    Returns:
        ``{party: {"x": [...], "y": [...], "windows": [...]}}``.
        ``"x"`` and ``"y"`` always hold ``len(windows)`` entries, with
        ``None`` wherever the party has no usable position for that window.
        ``"windows"`` holds only the labels of windows that *do* have a
        position, so it may be shorter than the coordinate lists.
    """
    result: Dict[str, Dict[str, List[float | None]]] = {}
    for party, window_scores in scores.items():
        xs: List[float | None] = []
        ys: List[float | None] = []
        valid_windows: List[str] = []
        for idx, window in enumerate(windows):
            # Windows past the end of the party's score list count as missing.
            point = window_scores[idx] if idx < len(window_scores) else None
            # A gap inside the list (None entry, fewer than two components,
            # or a None coordinate) is also treated as missing rather than
            # raising, so x and y stay paired: both set or both None.
            if (
                point is None
                or len(point) < 2
                or point[0] is None
                or point[1] is None
            ):
                xs.append(None)
                ys.append(None)
                continue
            xs.append(point[0])
            ys.append(point[1])
            valid_windows.append(window)
        result[party] = {"x": xs, "y": ys, "windows": valid_windows}
    return result
def _compute_group_center(
    trajectories: Dict[str, Dict[str, List[float | None]]],
    party_set: frozenset,
    n_windows: int,
) -> Dict[str, List[float | None]]:
    """Compute the mean (x, y) position per window across a set of parties.

    Args:
        trajectories: Per-party ``{"x": [...], "y": [...]}`` coordinate lists
            indexed by window.
        party_set: Names of the parties that make up the group; membership is
            decided by ``_party_in_set``.
        n_windows: Number of windows to produce a centre for.

    Returns:
        ``{"x": [...], "y": [...]}`` with ``n_windows`` entries each. An entry
        is ``None`` for both axes when no group member has a complete
        position in that window.
    """
    # Membership does not depend on the window, so resolve it once up front.
    members = [
        traj
        for party, traj in trajectories.items()
        if _party_in_set(party, party_set)
    ]

    xs: List[float | None] = []
    ys: List[float | None] = []
    for w_idx in range(n_windows):
        # Only accept a point when BOTH coordinates exist; averaging a lone
        # x with a None y would make np.mean raise.
        points = [
            (traj["x"][w_idx], traj["y"][w_idx])
            for traj in members
            if w_idx < len(traj["x"])
            and w_idx < len(traj["y"])
            and traj["x"][w_idx] is not None
            and traj["y"][w_idx] is not None
        ]
        if points:
            xs.append(float(np.mean([x for x, _ in points])))
            ys.append(float(np.mean([y for _, y in points])))
        else:
            xs.append(None)
            ys.append(None)
    return {"x": xs, "y": ys}
def _plot_party_trajectory(
    ax: plt.Axes,
    traj: Dict[str, List[float | None]],
    windows: List[str],
    party: str,
    colour: str,
) -> None:
    """Draw one party's path on ``ax``: line, arrows, markers, end labels.

    Windows where either coordinate is ``None`` are skipped. Nothing is drawn
    when fewer than two windows remain. The first and last remaining points
    are annotated with their window label, and the scatter markers carry
    ``party`` as the legend label.
    """
    xs = traj["x"]
    ys = traj["y"]
    keep = [i for i in range(len(xs)) if not (xs[i] is None or ys[i] is None)]
    if len(keep) < 2:
        return

    path_x = [xs[i] for i in keep]
    path_y = [ys[i] for i in keep]
    labels = [windows[i] for i in keep]

    # Faint connecting line underneath everything else.
    ax.plot(path_x, path_y, "-", color=colour, linewidth=1.2, alpha=0.5, zorder=1)

    # One arrow per consecutive pair of points, pointing forward in time.
    arrow_style = dict(
        arrowstyle="->",
        color=colour,
        lw=1.0,
        alpha=0.5,
        shrinkA=4,
        shrinkB=4,
    )
    tails = zip(path_x, path_y)
    heads = zip(path_x[1:], path_y[1:])
    for tail, head in zip(tails, heads):
        ax.annotate(
            "",
            xy=head,
            xytext=tail,
            arrowprops=dict(arrow_style),
            zorder=2,
        )

    ax.scatter(path_x, path_y, color=colour, s=25, zorder=3, label=party)

    # Label the start (offset below the point) and the end (offset above it).
    for pos, offset in ((0, (6, -10)), (-1, (6, 6))):
        ax.annotate(
            labels[pos],
            (path_x[pos], path_y[pos]),
            textcoords="offset points",
            xytext=offset,
            fontsize=6,
            color=colour,
            fontweight="bold",
            alpha=0.8,
        )
def _resolve_trajectory_key(
    party: str,
    trajectories: Dict[str, Dict[str, List[float | None]]],
) -> str | None:
    """Return the key in ``trajectories`` that holds ``party``'s data, or None.

    Tries the display name itself, then its ``_normalize_party`` form, then
    any stored key whose normalised form equals the display name. This keeps
    Panel A consistent with the alias-aware group membership used for Panel B.
    """
    if party in trajectories:
        return party
    alias = _normalize_party(party)
    if alias in trajectories:
        return alias
    for key in trajectories:
        if _normalize_party(key) == party:
            return key
    return None


def _finish_panel(ax: plt.Axes, title: str) -> None:
    """Apply the axis decoration shared by both panels."""
    ax.axhline(0, color="#CCCCCC", linewidth=0.5, linestyle="-")
    ax.axvline(0, color="#CCCCCC", linewidth=0.5, linestyle="-")
    ax.set_xlabel("PCA Axis 1 (Procrustes-aligned)")
    ax.set_ylabel("PCA Axis 2 (Procrustes-aligned)")
    ax.set_title(title, fontsize=11)
    ax.set_aspect("equal", adjustable="datalim")
    ax.grid(True, alpha=0.2)
    ax.legend(loc="upper left", fontsize=7, framealpha=0.85)


def _plot_group_center(
    ax: plt.Axes,
    center: Dict[str, List[float | None]],
    windows: List[str],
    *,
    fmt: str,
    colour: str,
    label: str,
    line_kwargs: dict,
    arrow_lw: float,
    arrow_alpha: float,
    label_offset: tuple,
) -> List[int]:
    """Draw one group's centre-of-gravity path and return its valid window indices.

    A window is valid when both centre coordinates are not ``None``. Nothing
    is drawn when no window is valid.
    """
    valid_idx = [
        i
        for i in range(len(windows))
        if center["x"][i] is not None and center["y"][i] is not None
    ]
    if not valid_idx:
        return valid_idx

    xs = [center["x"][i] for i in valid_idx]
    ys = [center["y"][i] for i in valid_idx]

    ax.plot(xs, ys, fmt, color=colour, label=label, zorder=3, **line_kwargs)
    # Arrows between consecutive valid windows show the direction of travel.
    for tail, head in zip(zip(xs, ys), zip(xs[1:], ys[1:])):
        ax.annotate(
            "",
            xy=head,
            xytext=tail,
            arrowprops=dict(
                arrowstyle="->", color=colour, lw=arrow_lw, alpha=arrow_alpha,
            ),
            zorder=2,
        )
    # Every point in this panel is labelled with its window.
    for i, x, y in zip(valid_idx, xs, ys):
        ax.annotate(
            str(windows[i]),
            (x, y),
            textcoords="offset points",
            xytext=label_offset,
            fontsize=7,
            color=colour,
            fontweight="bold",
        )
    return valid_idx


def _log_center_drift(
    name: str,
    center: Dict[str, List[float | None]],
    valid_idx: List[int],
) -> None:
    """Log the displacement of a group centre between its first and last valid window."""
    if not valid_idx:
        return
    first, last = valid_idx[0], valid_idx[-1]
    dx = center["x"][last] - center["x"][first]
    dy = center["y"][last] - center["y"][first]
    logger.info(
        "%s center drift: dx=%.4f dy=%.4f net=%.4f",
        name, dx, dy, float(np.sqrt(dx**2 + dy**2)),
    )


def main() -> None:
    """Render the two-panel trajectory figure and log each group's net drift.

    Loads the window list and the aligned party scores from ``DB_PATH``,
    draws individual party trajectories (Panel A) and the centrist and
    right-wing centres of gravity (Panel B), and writes the figure to
    ``OUTPUT_PATH``. Logs an error and returns without writing anything when
    either the windows or the scores are empty.
    """
    REPORTS_DIR.mkdir(parents=True, exist_ok=True)

    logger.info("Loading aligned party positions...")
    windows = get_uniform_dim_windows(DB_PATH)
    if not windows:
        logger.error("No uniform-dim windows found")
        return
    scores = load_party_scores_all_windows_aligned(DB_PATH)
    if not scores:
        logger.error("No aligned party scores loaded")
        return
    logger.info("Windows: %s", windows)
    logger.info("Parties: %s", sorted(scores.keys()))

    trajectories = _build_trajectories(scores, windows)
    n_windows = len(windows)
    centrist_center = _compute_group_center(
        trajectories, CANONICAL_CENTRIST, n_windows
    )
    right_center = _compute_group_center(
        trajectories, CANONICAL_RIGHT, n_windows
    )

    fig, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(18, 8))
    # Close the figure even if drawing or saving fails, so pyplot does not
    # keep it alive in its global figure registry.
    try:
        # ── Panel A: Full individual party trajectories ──────────────────
        for party in (*CENTRIST_DISPLAY, *RIGHT_DISPLAY):
            key = _resolve_trajectory_key(party, trajectories)
            if key is None:
                continue
            colour = PARTY_COLOURS.get(party, "#888888")
            _plot_party_trajectory(ax_a, trajectories[key], windows, party, colour)
        _finish_panel(ax_a, "Panel A: Party Trajectories (All Windows)")

        # ── Panel B: Centrist vs right-wing center of gravity ────────────
        cent_valid_idx = _plot_group_center(
            ax_b,
            centrist_center,
            windows,
            fmt="o-",
            colour="#1E73BE",
            label="Centrist center (VVD, D66, CDA, NSC, BBB, CU)",
            line_kwargs=dict(linewidth=2, markersize=7),
            arrow_lw=1.5,
            arrow_alpha=0.6,
            label_offset=(6, 6),
        )
        right_valid_idx = _plot_group_center(
            ax_b,
            right_center,
            windows,
            fmt="s--",
            colour="#6A1B9A",
            label="Right-wing center (PVV, FVD, JA21, SGP)",
            line_kwargs=dict(linewidth=1.5, markersize=6, alpha=0.8),
            arrow_lw=1.2,
            arrow_alpha=0.5,
            label_offset=(6, -10),
        )
        _finish_panel(ax_b, "Panel B: Group Center of Gravity Trajectories")

        fig.suptitle(
            "SVD Spatial Drift: 10-Year Parliamentary Party Trajectories",
            fontsize=13,
            fontweight="bold",
        )
        # Leave the top 4% of the canvas free for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches="tight", facecolor="white")
    finally:
        plt.close(fig)
    logger.info("Figure saved to %s", OUTPUT_PATH)

    _log_center_drift("Centrist", centrist_center, cent_valid_idx)
    _log_center_drift("Right-wing", right_center, right_valid_idx)
# Generate the figure only when run as a script, not when imported.
if __name__ == "__main__":
    main()