diff --git a/explorer.py b/explorer.py index ea94ad5..b1955de 100644 --- a/explorer.py +++ b/explorer.py @@ -485,6 +485,77 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: pass +@st.cache_data(show_spinner="Partij-MP vectoren laden…") +def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]: + """Return per-party lists of individual MP SVD vectors. + + Same MP→party mapping as load_party_axis_scores(), but returns the raw + per-MP vectors instead of averaging them. Suitable for bootstrap CI + computation. + + Returns: + {party_name: [np.ndarray(50,), ...]} — one array per MP. + """ + try: + con = duckdb.connect(database=db_path, read_only=True) + + # Build mp → party mapping (same logic as load_party_axis_scores) + meta_ordered = con.execute( + "SELECT mp_name, party FROM mp_metadata " + "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' " + "ORDER BY van ASC" + ).fetchall() + mp_party: Dict[str, str] = {} + for mp_name, party in meta_ordered: + if mp_name and party: + mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party) + + # Individual MP vectors from current_parliament + rows = con.execute( + "SELECT entity_id, vector FROM svd_vectors " + "WHERE entity_type='mp' AND window_id='current_parliament'" + ).fetchall() + + party_vecs: Dict[str, List[np.ndarray]] = {} + for entity_id, raw_vec in rows: + party = mp_party.get(entity_id) + if party is None or party not in CURRENT_PARLIAMENT_PARTIES: + continue + if isinstance(raw_vec, str): + vec = json.loads(raw_vec) + elif isinstance(raw_vec, (bytes, bytearray)): + vec = json.loads(raw_vec.decode()) + elif isinstance(raw_vec, list): + vec = raw_vec + else: + try: + vec = list(raw_vec) + except Exception: + continue + fvec = np.array([float(v) if v is not None else 0.0 for v in vec]) + party_vecs.setdefault(party, []).append(fvec) + + return party_vecs + except Exception: + logger.exception("Failed to load party MP vectors") + return {} + finally: + try: + con.close() + except Exception: + pass + + +@st.cache_data(show_spinner="Bootstrap CI berekenen…") +def _cached_bootstrap_cis( + _party_mp_vectors: Dict[str, List[np.ndarray]], +) -> Dict[str, Dict]: + """Thin caching wrapper around compute_party_bootstrap_cis.""" + from analysis.political_axis import compute_party_bootstrap_cis + + return compute_party_bootstrap_cis(_party_mp_vectors) + + @st.cache_data(show_spinner="Scree-plot laden…") def load_scree_data(db_path: str) -> List[float]: """Return explained variance ratios (%) for all SVD components, sorted descending. @@ -621,18 +692,28 @@ def _render_scree_plot(importances: List[float], n_show: int = 15) -> None: st.plotly_chart(fig, use_container_width=True) -def _render_party_axis_chart( - party_scores: Dict[str, List[float]], comp_sel: int, theme: dict -) -> None: - """Render a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`. +def _build_party_axis_figure( + party_scores: Dict[str, List[float]], + comp_sel: int, + theme: dict, + bootstrap_data: Optional[Dict[str, Dict]] = None, +) -> Optional[go.Figure]: + """Build a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`. - Each party is plotted at its score on a single horizontal axis (y=0). - When theme['flip'] is True the scores are negated so that the progressive/left - side always appears on the left of the chart. + Pure function that returns a go.Figure (no Streamlit calls). + + Args: + party_scores: {party_name: [float*k]} — mean SVD vectors per party. + comp_sel: 1-indexed SVD axis number. + theme: dict with keys label, explanation, positive_pole, negative_pole, flip. + bootstrap_data: optional output from compute_party_bootstrap_cis — + {party: {centroid, ci_lower, ci_upper, std, n_mps}}. + + Returns: + go.Figure, or None if no data available. """ if not party_scores: - st.caption("_Partijdata niet beschikbaar voor deze as._") - return + return None axis_idx = comp_sel - 1 # 0-based index into the 50-dim vector flip = theme.get("flip", False) @@ -645,13 +726,21 @@ def _render_party_axis_chart( data.append({"party": party, "score": score}) if not data: - st.caption("_Geen partijscores voor deze as._") - return + return None scores = [d["score"] for d in data] parties = [d["party"] for d in data] colours = [PARTY_COLOURS.get(p, "#9E9E9E") for p in parties] - hover = [f"{p}: {s:.3f}" for p, s in zip(parties, scores)] + + # Build hover text: include N when bootstrap data available + if bootstrap_data: + hover = [] + for p, s in zip(parties, scores): + bd = bootstrap_data.get(p) + n_mps = bd["n_mps"] if bd else "?" + hover.append(f"{p}: {s:.3f} (N={n_mps})") + else: + hover = [f"{p}: {s:.3f}" for p, s in zip(parties, scores)] # Determine axis labels: left = progressive pole, right = conservative pole pos_pole = theme.get("positive_pole", "") @@ -674,20 +763,43 @@ def _render_party_axis_chart( showlegend=False, ) ) + + # Build marker kwargs — bootstrap data adds error bars and diamond markers + marker_kwargs: dict = {"size": 18, "color": colours} + error_x_kwargs: Optional[dict] = None + + if bootstrap_data: + error_array = [] + symbols = [] + for p in parties: + bd = bootstrap_data.get(p) + if bd: + err = (bd["ci_upper"][axis_idx] - bd["ci_lower"][axis_idx]) / 2 + error_array.append(abs(float(err))) + symbols.append("diamond" if bd["n_mps"] == 1 else "circle") + else: + error_array.append(0.0) + symbols.append("circle") + marker_kwargs["symbol"] = symbols + error_x_kwargs = {"type": "data", "array": error_array, "visible": True} + # Party markers - fig.add_trace( - go.Scatter( - x=scores, - y=[0] * len(scores), - mode="markers+text", - text=parties, - textposition="top center", - marker={"size": 18, "color": colours}, - hovertext=hover, - hoverinfo="text", - showlegend=False, - ) - ) + scatter_kwargs: dict = { + "x": scores, + "y": [0] * len(scores), + "mode": "markers+text", + "text": parties, + "textposition": "top center", + "marker": marker_kwargs, + "hovertext": hover, + "hoverinfo": "text", + "showlegend": False, + } + if error_x_kwargs is not None: + scatter_kwargs["error_x"] = error_x_kwargs + + fig.add_trace(go.Scatter(**scatter_kwargs)) + fig.update_layout( height=160, margin={"l": 10, "r": 10, "t": 10, "b": 30}, @@ -702,6 +814,24 @@ def _render_party_axis_chart( plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", ) + return fig + + +def _render_party_axis_chart( + party_scores: Dict[str, List[float]], + comp_sel: int, + theme: dict, + bootstrap_data: Optional[Dict[str, Dict]] = None, +) -> None: + """Render a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`. + + Delegates figure construction to _build_party_axis_figure, then renders via + st.plotly_chart. + """ + fig = _build_party_axis_figure(party_scores, comp_sel, theme, bootstrap_data) + if fig is None: + st.caption("_Partijdata niet beschikbaar voor deze as._") + return st.plotly_chart(fig, use_container_width=True) @@ -1648,7 +1778,13 @@ def build_svd_components_tab(db_path: str) -> None: # Party axis chart party_scores = load_party_axis_scores(db_path) - _render_party_axis_chart(party_scores, comp_sel, theme) + party_mp_vectors = load_party_mp_vectors(db_path) + bootstrap_data = ( + _cached_bootstrap_cis(party_mp_vectors) if party_mp_vectors else None + ) + _render_party_axis_chart( + party_scores, comp_sel, theme, bootstrap_data=bootstrap_data + ) # Batch-fetch motion details (title, date, policy_area, url, body_text, voting_results) motion_ids = [m.get("motion_id") for m in motions if m.get("motion_id") is not None] diff --git a/tests/test_explorer_chart.py b/tests/test_explorer_chart.py new file mode 100644 index 0000000..c1629b1 --- /dev/null +++ b/tests/test_explorer_chart.py @@ -0,0 +1,175 @@ +"""Tests for _build_party_axis_figure and load_party_mp_vectors in explorer.py.""" + +import numpy as np +import plotly.graph_objects as go +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_party_scores(n_parties=3, dim=50): + """Return a minimal party_scores dict for testing.""" + rng = np.random.default_rng(0) + names = [f"Party{i}" for i in range(n_parties)] + return {name: rng.standard_normal(dim).tolist() for name in names} + + +def _make_theme(flip=False): + return { + "label": "Test axis", + "explanation": "A test axis.", + "positive_pole": "Left", + "negative_pole": "Right", + "flip": flip, + } + + +def _make_bootstrap_data(party_scores, dim=50): + """Build synthetic bootstrap_data matching party_scores keys. + + Party0 gets n_mps=1 (single-MP party → diamond marker). + Others get n_mps > 1 with a real CI spread. + """ + rng = np.random.default_rng(1) + result = {} + for i, party in enumerate(party_scores): + centroid = np.array(party_scores[party]) + if i == 0: + # Single-MP party + result[party] = { + "centroid": centroid, + "ci_lower": centroid.copy(), + "ci_upper": centroid.copy(), + "std": np.zeros(dim), + "n_mps": 1, + } + else: + spread = rng.uniform(0.01, 0.05, size=dim) + result[party] = { + "centroid": centroid, + "ci_lower": centroid - spread, + "ci_upper": centroid + spread, + "std": spread / 2, + "n_mps": 5 + i, + } + return result + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestBuildPartyAxisFigure: + """Tests for _build_party_axis_figure (pure Plotly figure construction).""" + + def test_returns_figure_without_bootstrap(self): + """Basic call without bootstrap → returns go.Figure with 2 traces.""" + from explorer import _build_party_axis_figure + + party_scores = _make_party_scores() + theme = _make_theme() + fig = _build_party_axis_figure(party_scores, comp_sel=1, theme=theme) + + assert isinstance(fig, go.Figure) + assert len(fig.data) == 2 # baseline + markers + # First trace is the baseline line + assert fig.data[0].mode == "lines" + # Second trace is the marker scatter + assert "markers" in fig.data[1].mode + + def test_returns_none_for_empty_scores(self): + """Empty party_scores returns None (no figure).""" + from explorer import _build_party_axis_figure + + fig = _build_party_axis_figure({}, comp_sel=1, theme=_make_theme()) + assert fig is None + + def test_with_bootstrap_has_error_x_and_diamonds(self): + """Call WITH bootstrap_data → error_x on marker trace, diamond for N=1.""" + from explorer import _build_party_axis_figure + + party_scores = _make_party_scores() + theme = _make_theme() + bootstrap_data = _make_bootstrap_data(party_scores) + fig = _build_party_axis_figure( + party_scores, comp_sel=1, theme=theme, bootstrap_data=bootstrap_data + ) + + assert isinstance(fig, go.Figure) + assert len(fig.data) == 2 + + marker_trace = fig.data[1] + + # error_x should be present and visible + assert marker_trace.error_x is not None + assert marker_trace.error_x.visible is True + assert marker_trace.error_x.type == "data" + assert len(marker_trace.error_x.array) == 3 # 3 parties + + # All error bar values should be non-negative + for err in marker_trace.error_x.array: + assert err >= 0.0 + + # Marker symbols: first party (N=1) → diamond, others → circle + symbols = list(marker_trace.marker.symbol) + assert symbols[0] == "diamond" + assert all(s == "circle" for s in symbols[1:]) + + def test_with_bootstrap_hover_includes_n(self): + """Hover text includes N= for each party.""" + from explorer import _build_party_axis_figure + + party_scores = _make_party_scores() + theme = _make_theme() + bootstrap_data = _make_bootstrap_data(party_scores) + fig = _build_party_axis_figure( + party_scores, comp_sel=1, theme=theme, bootstrap_data=bootstrap_data + ) + + marker_trace = fig.data[1] + for ht in marker_trace.hovertext: + assert "(N=" in ht + + def test_flip_negates_scores_but_error_bars_stay_positive(self): + """When flip=True, scores are negated but error bar magnitudes stay positive.""" + from explorer import _build_party_axis_figure + + party_scores = _make_party_scores() + theme_no_flip = _make_theme(flip=False) + theme_flip = _make_theme(flip=True) + bootstrap_data = _make_bootstrap_data(party_scores) + + fig_normal = _build_party_axis_figure( + party_scores, comp_sel=1, theme=theme_no_flip, bootstrap_data=bootstrap_data + ) + fig_flipped = _build_party_axis_figure( + party_scores, comp_sel=1, theme=theme_flip, bootstrap_data=bootstrap_data + ) + + normal_scores = list(fig_normal.data[1].x) + flipped_scores = list(fig_flipped.data[1].x) + + # Scores should be negated + for ns, fs in zip(normal_scores, flipped_scores): + assert pytest.approx(ns) == -fs + + # Error bars should be the same (positive) in both cases + normal_errors = list(fig_normal.data[1].error_x.array) + flipped_errors = list(fig_flipped.data[1].error_x.array) + for ne, fe in zip(normal_errors, flipped_errors): + assert ne >= 0.0 + assert fe >= 0.0 + assert pytest.approx(ne) == fe + + +class TestLoadPartyMpVectorsImportable: + """Smoke test: verify load_party_mp_vectors is importable.""" + + def test_importable(self): + from explorer import load_party_mp_vectors + + assert callable(load_party_mp_vectors)