You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_explorer_chart.py

188 lines
6.2 KiB

"""Tests for _build_party_axis_figure and load_party_mp_vectors in explorer.py."""
import numpy as np
import plotly.graph_objects as go
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_party_scores(n_parties=3, dim=50):
    """Build a small, deterministic ``{party_name: score_vector}`` mapping.

    Keys are ``Party0`` .. ``Party{n_parties - 1}`` in that order. Each value
    is a plain list of ``dim`` standard-normal floats from a generator seeded
    with 0, so every call with the same arguments returns identical data.
    """
    generator = np.random.default_rng(0)
    scores = {}
    for idx in range(n_parties):
        scores[f"Party{idx}"] = generator.standard_normal(dim).tolist()
    return scores
def _make_theme(flip=False):
    """Return a fresh axis-theme dict with fixed labels and the given ``flip``."""
    theme = dict(
        label="Test axis",
        explanation="A test axis.",
        positive_pole="Left",
        negative_pole="Right",
    )
    # ``flip`` is the only caller-controlled entry; appended last so the key
    # order matches the other fields above.
    theme["flip"] = flip
    return theme
def _make_bootstrap_data(party_scores, dim=None):
    """Build synthetic bootstrap_data matching party_scores keys.

    The first party in iteration order (``Party0`` for dicts built by
    ``_make_party_scores``) gets ``n_mps=1`` and a zero-width interval
    (single-MP party → diamond marker). Every other party at position ``i``
    gets ``n_mps = 5 + i`` and a real CI spread: ``centroid ± spread`` with
    ``spread`` drawn uniformly from [0.01, 0.05) per dimension and
    ``std = spread / 2``.

    Args:
        party_scores: Mapping of party name to a 1-D score vector.
        dim: Optional expected vector length. By default it is inferred from
            each party's vector, so the helper works for any dimensionality.
            If given, every vector must have exactly this length.

    Returns:
        Dict with the same keys (and order) as ``party_scores``. Each value
        holds NumPy arrays ``centroid``, ``ci_lower``, ``ci_upper``, ``std``
        and the int ``n_mps``.

    Raises:
        ValueError: If ``dim`` is given and a score vector's length differs.
    """
    rng = np.random.default_rng(1)  # fixed seed → reproducible spreads
    result = {}
    for i, (party, scores) in enumerate(party_scores.items()):
        centroid = np.array(scores)  # copy, so the caller's data is never aliased
        n_dims = centroid.shape[0]
        if dim is not None and n_dims != dim:
            # A mismatched dim would give the first party a wrong-size zero
            # ``std`` and make ``centroid ± spread`` fail to broadcast for the
            # rest; fail early with a clear message instead.
            raise ValueError(
                f"party {party!r} has {n_dims} dimensions, expected dim={dim}"
            )
        if i == 0:
            # Single-MP party: no spread, so the CI collapses onto the centroid.
            result[party] = {
                "centroid": centroid,
                "ci_lower": centroid.copy(),
                "ci_upper": centroid.copy(),
                "std": np.zeros(n_dims),
                "n_mps": 1,
            }
        else:
            spread = rng.uniform(0.01, 0.05, size=n_dims)
            result[party] = {
                "centroid": centroid,
                "ci_lower": centroid - spread,
                "ci_upper": centroid + spread,
                "std": spread / 2,
                "n_mps": 5 + i,
            }
    return result
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestBuildPartyAxisFigure:
    """Tests for _build_party_axis_figure (pure Plotly figure construction).

    ``fig.data[0]`` is expected to be the baseline line and ``fig.data[1]``
    the per-party marker scatter. ``_build_party_axis_figure`` is imported
    inside each test rather than at module level (presumably so collecting
    this file does not import ``explorer`` eagerly — confirm).
    """

    def test_returns_figure_without_bootstrap(self):
        """Basic call without bootstrap → returns go.Figure with 2 traces."""
        from explorer import _build_party_axis_figure

        party_scores = _make_party_scores()
        theme = _make_theme()
        fig = _build_party_axis_figure(party_scores, comp_sel=1, theme=theme)
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 2  # baseline + markers
        # First trace is the baseline line
        assert fig.data[0].mode == "lines"
        # Second trace is the marker scatter
        assert "markers" in fig.data[1].mode

    def test_returns_none_for_empty_scores(self):
        """Empty party_scores returns None (no figure)."""
        from explorer import _build_party_axis_figure

        fig = _build_party_axis_figure({}, comp_sel=1, theme=_make_theme())
        assert fig is None

    def test_with_bootstrap_has_diamonds_for_single_mp(self):
        """bootstrap_data present → N=1 party gets diamond, others get circle. No error bars."""
        from explorer import _build_party_axis_figure

        party_scores = _make_party_scores()
        theme = _make_theme()
        bootstrap_data = _make_bootstrap_data(party_scores)
        fig = _build_party_axis_figure(
            party_scores,
            comp_sel=1,
            theme=theme,
            bootstrap_data=bootstrap_data,
        )
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 2
        marker_trace = fig.data[1]
        # No visual error bars — CIs are in hover text only. Plotly draws
        # error bars by default once ``array``/``arrayminus`` is supplied, so
        # an unset (None) ``visible`` does NOT hide them; only an explicit
        # ``visible=False`` does.
        error_x = marker_trace.error_x
        assert error_x.visible is False or (
            error_x.array is None and error_x.arrayminus is None
        )
        # Marker symbols: one per party; first party (N=1) → diamond, others → circle
        symbols = list(marker_trace.marker.symbol)
        assert len(symbols) == len(party_scores)
        assert symbols[0] == "diamond"
        assert all(s == "circle" for s in symbols[1:])

    def test_with_bootstrap_hover_includes_n_and_ci(self):
        """Hover text includes N=<count> and 95%-BI interval for each party."""
        from explorer import _build_party_axis_figure

        party_scores = _make_party_scores()
        theme = _make_theme()
        bootstrap_data = _make_bootstrap_data(party_scores)
        fig = _build_party_axis_figure(
            party_scores,
            comp_sel=1,
            theme=theme,
            bootstrap_data=bootstrap_data,
        )
        hovertexts = list(fig.data[1].hovertext)
        # Guard against a vacuous pass: the loop below checks nothing if
        # hovertext is empty.
        assert len(hovertexts) == len(party_scores)
        for ht in hovertexts:
            assert "(N=" in ht
            assert "95%-BI" in ht

    def test_flip_negates_scores(self):
        """When flip=True, scores are negated relative to flip=False."""
        from explorer import _build_party_axis_figure

        party_scores = _make_party_scores()
        theme_no_flip = _make_theme(flip=False)
        theme_flip = _make_theme(flip=True)
        bootstrap_data = _make_bootstrap_data(party_scores)
        fig_normal = _build_party_axis_figure(
            party_scores,
            comp_sel=1,
            theme=theme_no_flip,
            bootstrap_data=bootstrap_data,
        )
        fig_flipped = _build_party_axis_figure(
            party_scores,
            comp_sel=1,
            theme=theme_flip,
            bootstrap_data=bootstrap_data,
        )
        normal_scores = list(fig_normal.data[1].x)
        flipped_scores = list(fig_flipped.data[1].x)
        assert len(normal_scores) == len(party_scores)
        # Scores should be negated point-for-point. Comparing whole lists
        # also fails on a length mismatch, which zip() would silently truncate.
        assert flipped_scores == pytest.approx([-s for s in normal_scores])

    def test_without_bootstrap_hover_is_score_only(self):
        """Without bootstrap data, hover text is just 'Party: score' with no CI."""
        from explorer import _build_party_axis_figure

        party_scores = _make_party_scores()
        fig = _build_party_axis_figure(party_scores, comp_sel=1, theme=_make_theme())
        hovertexts = list(fig.data[1].hovertext)
        # Guard against a vacuous pass on empty hover text.
        assert len(hovertexts) == len(party_scores)
        for ht in hovertexts:
            assert "95%-BI" not in ht
            assert "(N=" not in ht
class TestLoadPartyMpVectorsImportable:
    """Smoke check that ``explorer`` exposes a callable ``load_party_mp_vectors``."""

    def test_importable(self):
        """Importing the loader succeeds and yields a callable object."""
        from explorer import load_party_mp_vectors as loader

        assert callable(loader)