Add bootstrap CIs to party axis chart with error bars and diamond markers

- Add load_party_mp_vectors() to return raw per-MP SVD vectors by party
- Extract _build_party_axis_figure() as pure function for testability
- Modify _render_party_axis_chart to accept bootstrap_data and delegate
  to the new builder
- When bootstrap_data present: show error_x bars, diamond markers for
  N=1 parties, and N=count in hover text
- Wire up bootstrap computation in build_svd_components_tab via cached
  _cached_bootstrap_cis wrapper
- Add 6 tests covering figure construction, bootstrap rendering, flip
  behavior, and importability
main
Sven Geboers 1 month ago
parent 88110b0aaa
commit 3938eecc53
  1. 188
      explorer.py
  2. 175
      tests/test_explorer_chart.py

@ -485,6 +485,77 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
pass
def _parse_mp_vector(raw_vec) -> Optional[np.ndarray]:
    """Decode one stored SVD vector into a 1-D float array.

    Accepts a JSON string, JSON bytes, a list, or any other iterable of
    numbers. ``None`` elements are replaced by 0.0.

    Returns:
        The decoded array, or None when the value cannot be decoded
        (invalid JSON, non-iterable value, non-numeric element).
    """
    try:
        if isinstance(raw_vec, (bytes, bytearray)):
            raw_vec = raw_vec.decode()
        if isinstance(raw_vec, str):
            raw_vec = json.loads(raw_vec)
        values = raw_vec if isinstance(raw_vec, list) else list(raw_vec)
        return np.array([float(v) if v is not None else 0.0 for v in values])
    except Exception:
        # Deliberately broad: a single undecodable row must never abort
        # the whole load, whatever the driver handed back for that row.
        return None


@st.cache_data(show_spinner="Partij-MP vectoren laden…")
def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]:
    """Return per-party lists of individual MP SVD vectors.

    Uses an MP → party mapping built from ``mp_metadata`` (rows ordered by
    ``van`` ascending, so a later row overrides an earlier one for the same
    MP) and returns the raw per-MP vectors of the ``current_parliament``
    window without averaging them. Suitable for bootstrap CI computation.

    MPs without a known party, or whose (normalized) party is not in
    CURRENT_PARLIAMENT_PARTIES, are left out. Rows whose vector cannot be
    decoded are skipped and counted in a warning log line.

    Args:
        db_path: path of the DuckDB database, opened read-only.

    Returns:
        {party_name: [np.ndarray, ...]} — one 1-D float array per MP.
        An empty dict when the database cannot be opened or queried
        (the failure is logged, not raised).
    """
    # Bind the name up front so the cleanup below is safe even when
    # duckdb.connect() itself raises.
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)

        # Build mp → party mapping; later rows win for a repeated mp_name.
        meta_ordered = con.execute(
            "SELECT mp_name, party FROM mp_metadata "
            "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
            "ORDER BY van ASC"
        ).fetchall()
        mp_party: Dict[str, str] = {}
        for mp_name, party in meta_ordered:
            if mp_name and party:
                mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)

        # Individual MP vectors from current_parliament
        rows = con.execute(
            "SELECT entity_id, vector FROM svd_vectors "
            "WHERE entity_type='mp' AND window_id='current_parliament'"
        ).fetchall()

        party_vecs: Dict[str, List[np.ndarray]] = {}
        skipped = 0
        for entity_id, raw_vec in rows:
            party = mp_party.get(entity_id)
            if party is None or party not in CURRENT_PARLIAMENT_PARTIES:
                continue
            fvec = _parse_mp_vector(raw_vec)
            if fvec is None:
                skipped += 1
                continue
            party_vecs.setdefault(party, []).append(fvec)

        if skipped:
            logger.warning(
                "Skipped %d MP vector(s) that could not be decoded", skipped
            )
        return party_vecs
    except Exception:
        logger.exception("Failed to load party MP vectors")
        return {}
    finally:
        if con is not None:
            try:
                con.close()
            except Exception:
                # Best-effort cleanup: a failing close must not mask the result.
                logger.debug("Ignoring error while closing DuckDB connection", exc_info=True)
@st.cache_data(show_spinner="Bootstrap CI berekenen…")
def _cached_bootstrap_cis(
    _party_mp_vectors: Dict[str, List[np.ndarray]],
) -> Dict[str, Dict]:
    """Compute party bootstrap CIs behind Streamlit's data cache.

    Forwards the per-party MP vectors unchanged to
    ``analysis.political_axis.compute_party_bootstrap_cis`` and returns
    its result as-is.

    NOTE(review): Streamlit excludes parameters whose name starts with an
    underscore from the cache key. This function has no other parameter,
    so a cached result would presumably be reused even when the vectors
    change — confirm this is intended.
    """
    # Imported lazily, inside the function, as in the original wrapper.
    from analysis import political_axis

    bootstrap_cis = political_axis.compute_party_bootstrap_cis(_party_mp_vectors)
    return bootstrap_cis
@st.cache_data(show_spinner="Scree-plot laden…")
def load_scree_data(db_path: str) -> List[float]:
"""Return explained variance ratios (%) for all SVD components, sorted descending.
@ -621,18 +692,28 @@ def _render_scree_plot(importances: List[float], n_show: int = 15) -> None:
st.plotly_chart(fig, use_container_width=True)
def _render_party_axis_chart(
party_scores: Dict[str, List[float]], comp_sel: int, theme: dict
) -> None:
"""Render a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`.
def _build_party_axis_figure(
party_scores: Dict[str, List[float]],
comp_sel: int,
theme: dict,
bootstrap_data: Optional[Dict[str, Dict]] = None,
) -> Optional[go.Figure]:
"""Build a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`.
Each party is plotted at its score on a single horizontal axis (y=0).
When theme['flip'] is True the scores are negated so that the progressive/left
side always appears on the left of the chart.
Pure function that returns a go.Figure (no Streamlit calls).
Args:
party_scores: {party_name: [float*k]} mean SVD vectors per party.
comp_sel: 1-indexed SVD axis number.
theme: dict with keys label, explanation, positive_pole, negative_pole, flip.
bootstrap_data: optional output from compute_party_bootstrap_cis
{party: {centroid, ci_lower, ci_upper, std, n_mps}}.
Returns:
go.Figure, or None if no data available.
"""
if not party_scores:
st.caption("_Partijdata niet beschikbaar voor deze as._")
return
return None
axis_idx = comp_sel - 1 # 0-based index into the 50-dim vector
flip = theme.get("flip", False)
@ -645,13 +726,21 @@ def _render_party_axis_chart(
data.append({"party": party, "score": score})
if not data:
st.caption("_Geen partijscores voor deze as._")
return
return None
scores = [d["score"] for d in data]
parties = [d["party"] for d in data]
colours = [PARTY_COLOURS.get(p, "#9E9E9E") for p in parties]
hover = [f"{p}: {s:.3f}" for p, s in zip(parties, scores)]
# Build hover text: include N when bootstrap data available
if bootstrap_data:
hover = []
for p, s in zip(parties, scores):
bd = bootstrap_data.get(p)
n_mps = bd["n_mps"] if bd else "?"
hover.append(f"{p}: {s:.3f} (N={n_mps})")
else:
hover = [f"{p}: {s:.3f}" for p, s in zip(parties, scores)]
# Determine axis labels: left = progressive pole, right = conservative pole
pos_pole = theme.get("positive_pole", "")
@ -674,20 +763,43 @@ def _render_party_axis_chart(
showlegend=False,
)
)
# Build marker kwargs — bootstrap data adds error bars and diamond markers
marker_kwargs: dict = {"size": 18, "color": colours}
error_x_kwargs: Optional[dict] = None
if bootstrap_data:
error_array = []
symbols = []
for p in parties:
bd = bootstrap_data.get(p)
if bd:
err = (bd["ci_upper"][axis_idx] - bd["ci_lower"][axis_idx]) / 2
error_array.append(abs(float(err)))
symbols.append("diamond" if bd["n_mps"] == 1 else "circle")
else:
error_array.append(0.0)
symbols.append("circle")
marker_kwargs["symbol"] = symbols
error_x_kwargs = {"type": "data", "array": error_array, "visible": True}
# Party markers
fig.add_trace(
go.Scatter(
x=scores,
y=[0] * len(scores),
mode="markers+text",
text=parties,
textposition="top center",
marker={"size": 18, "color": colours},
hovertext=hover,
hoverinfo="text",
showlegend=False,
)
)
scatter_kwargs: dict = {
"x": scores,
"y": [0] * len(scores),
"mode": "markers+text",
"text": parties,
"textposition": "top center",
"marker": marker_kwargs,
"hovertext": hover,
"hoverinfo": "text",
"showlegend": False,
}
if error_x_kwargs is not None:
scatter_kwargs["error_x"] = error_x_kwargs
fig.add_trace(go.Scatter(**scatter_kwargs))
fig.update_layout(
height=160,
margin={"l": 10, "r": 10, "t": 10, "b": 30},
@ -702,6 +814,24 @@ def _render_party_axis_chart(
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
)
return fig
def _render_party_axis_chart(
    party_scores: Dict[str, List[float]],
    comp_sel: int,
    theme: dict,
    bootstrap_data: Optional[Dict[str, Dict]] = None,
) -> None:
    """Show the party positions on SVD axis `comp_sel` in the Streamlit page.

    The figure itself comes from _build_party_axis_figure; this function only
    handles the Streamlit side. When the builder returns None a caption is
    shown instead of a chart.
    """
    figure = _build_party_axis_figure(party_scores, comp_sel, theme, bootstrap_data)
    if figure is not None:
        st.plotly_chart(figure, use_container_width=True)
    else:
        # Nothing to plot: tell the user instead of rendering an empty chart.
        st.caption("_Partijdata niet beschikbaar voor deze as._")
@ -1648,7 +1778,13 @@ def build_svd_components_tab(db_path: str) -> None:
# Party axis chart
party_scores = load_party_axis_scores(db_path)
_render_party_axis_chart(party_scores, comp_sel, theme)
party_mp_vectors = load_party_mp_vectors(db_path)
bootstrap_data = (
_cached_bootstrap_cis(party_mp_vectors) if party_mp_vectors else None
)
_render_party_axis_chart(
party_scores, comp_sel, theme, bootstrap_data=bootstrap_data
)
# Batch-fetch motion details (title, date, policy_area, url, body_text, voting_results)
motion_ids = [m.get("motion_id") for m in motions if m.get("motion_id") is not None]

@ -0,0 +1,175 @@
"""Tests for _build_party_axis_figure and load_party_mp_vectors in explorer.py."""
import numpy as np
import plotly.graph_objects as go
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_party_scores(n_parties=3, dim=50):
"""Return a minimal party_scores dict for testing."""
rng = np.random.default_rng(0)
names = [f"Party{i}" for i in range(n_parties)]
return {name: rng.standard_normal(dim).tolist() for name in names}
def _make_theme(flip=False):
return {
"label": "Test axis",
"explanation": "A test axis.",
"positive_pole": "Left",
"negative_pole": "Right",
"flip": flip,
}
def _make_bootstrap_data(party_scores, dim=50):
"""Build synthetic bootstrap_data matching party_scores keys.
Party0 gets n_mps=1 (single-MP party diamond marker).
Others get n_mps > 1 with a real CI spread.
"""
rng = np.random.default_rng(1)
result = {}
for i, party in enumerate(party_scores):
centroid = np.array(party_scores[party])
if i == 0:
# Single-MP party
result[party] = {
"centroid": centroid,
"ci_lower": centroid.copy(),
"ci_upper": centroid.copy(),
"std": np.zeros(dim),
"n_mps": 1,
}
else:
spread = rng.uniform(0.01, 0.05, size=dim)
result[party] = {
"centroid": centroid,
"ci_lower": centroid - spread,
"ci_upper": centroid + spread,
"std": spread / 2,
"n_mps": 5 + i,
}
return result
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestBuildPartyAxisFigure:
    """Exercise _build_party_axis_figure, the Streamlit-free figure builder."""

    def test_returns_figure_without_bootstrap(self):
        """No bootstrap data → a go.Figure holding a baseline and a marker trace."""
        from explorer import _build_party_axis_figure

        fig = _build_party_axis_figure(
            _make_party_scores(), comp_sel=1, theme=_make_theme()
        )

        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 2  # baseline + markers
        baseline_trace, marker_trace = fig.data[0], fig.data[1]
        # The baseline line comes first, the party markers second.
        assert baseline_trace.mode == "lines"
        assert "markers" in marker_trace.mode

    def test_returns_none_for_empty_scores(self):
        """An empty party_scores dict → None instead of a figure."""
        from explorer import _build_party_axis_figure

        assert _build_party_axis_figure({}, comp_sel=1, theme=_make_theme()) is None

    def test_with_bootstrap_has_error_x_and_diamonds(self):
        """With bootstrap_data → error_x on the marker trace, diamond for N=1."""
        from explorer import _build_party_axis_figure

        party_scores = _make_party_scores()
        fig = _build_party_axis_figure(
            party_scores,
            comp_sel=1,
            theme=_make_theme(),
            bootstrap_data=_make_bootstrap_data(party_scores),
        )

        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 2

        marker_trace = fig.data[1]
        error_x = marker_trace.error_x

        # Error bars must be present, visible and given as data values.
        assert error_x is not None
        assert error_x.visible is True
        assert error_x.type == "data"
        assert len(error_x.array) == 3  # one entry per party

        # Error bar magnitudes are never negative.
        assert all(err >= 0.0 for err in error_x.array)

        # The first party is the single-MP one → diamond; the rest → circle.
        symbols = list(marker_trace.marker.symbol)
        assert symbols[0] == "diamond"
        assert all(symbol == "circle" for symbol in symbols[1:])

    def test_with_bootstrap_hover_includes_n(self):
        """Every hover text carries an "(N=<count>)" suffix."""
        from explorer import _build_party_axis_figure

        party_scores = _make_party_scores()
        fig = _build_party_axis_figure(
            party_scores,
            comp_sel=1,
            theme=_make_theme(),
            bootstrap_data=_make_bootstrap_data(party_scores),
        )

        hover_texts = fig.data[1].hovertext
        assert all("(N=" in text for text in hover_texts)

    def test_flip_negates_scores_but_error_bars_stay_positive(self):
        """flip=True negates the scores; error bar magnitudes are unchanged."""
        from explorer import _build_party_axis_figure

        party_scores = _make_party_scores()
        bootstrap_data = _make_bootstrap_data(party_scores)

        def build(flip):
            # Same inputs for both figures; only the theme's flip flag differs.
            return _build_party_axis_figure(
                party_scores,
                comp_sel=1,
                theme=_make_theme(flip=flip),
                bootstrap_data=bootstrap_data,
            )

        normal_trace = build(False).data[1]
        flipped_trace = build(True).data[1]

        # Scores are mirrored around zero.
        for normal, flipped in zip(list(normal_trace.x), list(flipped_trace.x)):
            assert pytest.approx(normal) == -flipped

        # Error bars are identical and non-negative in both figures.
        error_pairs = zip(
            list(normal_trace.error_x.array), list(flipped_trace.error_x.array)
        )
        for normal_err, flipped_err in error_pairs:
            assert normal_err >= 0.0
            assert flipped_err >= 0.0
            assert pytest.approx(normal_err) == flipped_err
class TestLoadPartyMpVectorsImportable:
    """Smoke test: load_party_mp_vectors can be imported from explorer."""

    def test_importable(self):
        """The imported name is a callable; the function itself is not run."""
        from explorer import load_party_mp_vectors as loader

        assert callable(loader)
Loading…
Cancel
Save