- Add load_party_mp_vectors() to return raw per-MP SVD vectors by party - Extract _build_party_axis_figure() as pure function for testability - Modify _render_party_axis_chart to accept bootstrap_data and delegate to the new builder - When bootstrap_data present: show error_x bars, diamond markers for N=1 parties, and N=count in hover text - Wire up bootstrap computation in build_svd_components_tab via cached _cached_bootstrap_cis wrapper - Add 6 tests covering figure construction, bootstrap rendering, flip behavior, and importabilitymain
parent
88110b0aaa
commit
3938eecc53
@ -0,0 +1,175 @@ |
||||
"""Tests for _build_party_axis_figure and load_party_mp_vectors in explorer.py.""" |
||||
|
||||
import numpy as np |
||||
import plotly.graph_objects as go |
||||
import pytest |
||||
|
||||
|
||||
# --------------------------------------------------------------------------- |
||||
# Helpers |
||||
# --------------------------------------------------------------------------- |
||||
|
||||
|
||||
def _make_party_scores(n_parties=3, dim=50): |
||||
"""Return a minimal party_scores dict for testing.""" |
||||
rng = np.random.default_rng(0) |
||||
names = [f"Party{i}" for i in range(n_parties)] |
||||
return {name: rng.standard_normal(dim).tolist() for name in names} |
||||
|
||||
|
||||
def _make_theme(flip=False): |
||||
return { |
||||
"label": "Test axis", |
||||
"explanation": "A test axis.", |
||||
"positive_pole": "Left", |
||||
"negative_pole": "Right", |
||||
"flip": flip, |
||||
} |
||||
|
||||
|
||||
def _make_bootstrap_data(party_scores, dim=50): |
||||
"""Build synthetic bootstrap_data matching party_scores keys. |
||||
|
||||
Party0 gets n_mps=1 (single-MP party → diamond marker). |
||||
Others get n_mps > 1 with a real CI spread. |
||||
""" |
||||
rng = np.random.default_rng(1) |
||||
result = {} |
||||
for i, party in enumerate(party_scores): |
||||
centroid = np.array(party_scores[party]) |
||||
if i == 0: |
||||
# Single-MP party |
||||
result[party] = { |
||||
"centroid": centroid, |
||||
"ci_lower": centroid.copy(), |
||||
"ci_upper": centroid.copy(), |
||||
"std": np.zeros(dim), |
||||
"n_mps": 1, |
||||
} |
||||
else: |
||||
spread = rng.uniform(0.01, 0.05, size=dim) |
||||
result[party] = { |
||||
"centroid": centroid, |
||||
"ci_lower": centroid - spread, |
||||
"ci_upper": centroid + spread, |
||||
"std": spread / 2, |
||||
"n_mps": 5 + i, |
||||
} |
||||
return result |
||||
|
||||
|
||||
# --------------------------------------------------------------------------- |
||||
# Tests |
||||
# --------------------------------------------------------------------------- |
||||
|
||||
|
||||
class TestBuildPartyAxisFigure: |
||||
"""Tests for _build_party_axis_figure (pure Plotly figure construction).""" |
||||
|
||||
def test_returns_figure_without_bootstrap(self): |
||||
"""Basic call without bootstrap → returns go.Figure with 2 traces.""" |
||||
from explorer import _build_party_axis_figure |
||||
|
||||
party_scores = _make_party_scores() |
||||
theme = _make_theme() |
||||
fig = _build_party_axis_figure(party_scores, comp_sel=1, theme=theme) |
||||
|
||||
assert isinstance(fig, go.Figure) |
||||
assert len(fig.data) == 2 # baseline + markers |
||||
# First trace is the baseline line |
||||
assert fig.data[0].mode == "lines" |
||||
# Second trace is the marker scatter |
||||
assert "markers" in fig.data[1].mode |
||||
|
||||
def test_returns_none_for_empty_scores(self): |
||||
"""Empty party_scores returns None (no figure).""" |
||||
from explorer import _build_party_axis_figure |
||||
|
||||
fig = _build_party_axis_figure({}, comp_sel=1, theme=_make_theme()) |
||||
assert fig is None |
||||
|
||||
def test_with_bootstrap_has_error_x_and_diamonds(self): |
||||
"""Call WITH bootstrap_data → error_x on marker trace, diamond for N=1.""" |
||||
from explorer import _build_party_axis_figure |
||||
|
||||
party_scores = _make_party_scores() |
||||
theme = _make_theme() |
||||
bootstrap_data = _make_bootstrap_data(party_scores) |
||||
fig = _build_party_axis_figure( |
||||
party_scores, comp_sel=1, theme=theme, bootstrap_data=bootstrap_data |
||||
) |
||||
|
||||
assert isinstance(fig, go.Figure) |
||||
assert len(fig.data) == 2 |
||||
|
||||
marker_trace = fig.data[1] |
||||
|
||||
# error_x should be present and visible |
||||
assert marker_trace.error_x is not None |
||||
assert marker_trace.error_x.visible is True |
||||
assert marker_trace.error_x.type == "data" |
||||
assert len(marker_trace.error_x.array) == 3 # 3 parties |
||||
|
||||
# All error bar values should be non-negative |
||||
for err in marker_trace.error_x.array: |
||||
assert err >= 0.0 |
||||
|
||||
# Marker symbols: first party (N=1) → diamond, others → circle |
||||
symbols = list(marker_trace.marker.symbol) |
||||
assert symbols[0] == "diamond" |
||||
assert all(s == "circle" for s in symbols[1:]) |
||||
|
||||
def test_with_bootstrap_hover_includes_n(self): |
||||
"""Hover text includes N=<count> for each party.""" |
||||
from explorer import _build_party_axis_figure |
||||
|
||||
party_scores = _make_party_scores() |
||||
theme = _make_theme() |
||||
bootstrap_data = _make_bootstrap_data(party_scores) |
||||
fig = _build_party_axis_figure( |
||||
party_scores, comp_sel=1, theme=theme, bootstrap_data=bootstrap_data |
||||
) |
||||
|
||||
marker_trace = fig.data[1] |
||||
for ht in marker_trace.hovertext: |
||||
assert "(N=" in ht |
||||
|
||||
def test_flip_negates_scores_but_error_bars_stay_positive(self): |
||||
"""When flip=True, scores are negated but error bar magnitudes stay positive.""" |
||||
from explorer import _build_party_axis_figure |
||||
|
||||
party_scores = _make_party_scores() |
||||
theme_no_flip = _make_theme(flip=False) |
||||
theme_flip = _make_theme(flip=True) |
||||
bootstrap_data = _make_bootstrap_data(party_scores) |
||||
|
||||
fig_normal = _build_party_axis_figure( |
||||
party_scores, comp_sel=1, theme=theme_no_flip, bootstrap_data=bootstrap_data |
||||
) |
||||
fig_flipped = _build_party_axis_figure( |
||||
party_scores, comp_sel=1, theme=theme_flip, bootstrap_data=bootstrap_data |
||||
) |
||||
|
||||
normal_scores = list(fig_normal.data[1].x) |
||||
flipped_scores = list(fig_flipped.data[1].x) |
||||
|
||||
# Scores should be negated |
||||
for ns, fs in zip(normal_scores, flipped_scores): |
||||
assert pytest.approx(ns) == -fs |
||||
|
||||
# Error bars should be the same (positive) in both cases |
||||
normal_errors = list(fig_normal.data[1].error_x.array) |
||||
flipped_errors = list(fig_flipped.data[1].error_x.array) |
||||
for ne, fe in zip(normal_errors, flipped_errors): |
||||
assert ne >= 0.0 |
||||
assert fe >= 0.0 |
||||
assert pytest.approx(ne) == fe |
||||
|
||||
|
||||
class TestLoadPartyMpVectorsImportable: |
||||
"""Smoke test: verify load_party_mp_vectors is importable.""" |
||||
|
||||
def test_importable(self): |
||||
from explorer import load_party_mp_vectors |
||||
|
||||
assert callable(load_party_mp_vectors) |
||||
Loading…
Reference in new issue