From a5e95c33d70eb6fc3679fd5e9ccf48f32685cc40 Mon Sep 17 00:00:00 2001 From: Sven Geboers Date: Thu, 2 Apr 2026 21:39:09 +0200 Subject: [PATCH] refactor: use scatter plot format for SVD components 3-10 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Changed _render_party_axis_chart_1d from horizontal bar chart to scatter plot - Same format as components 1-2: markers on horizontal line with axis arrows- Axis labels now show correct direction with arrows (← left | right →) - Ensures consistent visualization across all SVD components --- explorer.py | 101 +++++++++++++++++++++++------------ tests/test_explorer_chart.py | 10 ++-- 2 files changed, 75 insertions(+), 36 deletions(-) diff --git a/explorer.py b/explorer.py index 399e35d..6e6397e 100644 --- a/explorer.py +++ b/explorer.py @@ -1357,7 +1357,10 @@ def _render_party_axis_chart_1d( comp_sel: int, theme: dict, ) -> None: - """Render a 1D horizontal bar chart of party positions on SVD component `comp_sel`. + """Render a 1D horizontal scatter of party positions on SVD component `comp_sel`. + + Uses the same format as components 1-2: parties as markers on a horizontal line + with axis title showing poles with arrows. Args: party_coords: Dict mapping party name to tuple of scores (score_for_comp,) @@ -1371,54 +1374,86 @@ def _render_party_axis_chart_1d( return # Extract scores and parties - parties = list(party_coords.keys()) - scores = [coords[0] for coords in party_coords.values()] + parties = [] + scores = [] + colours = [] + + for party, coords in party_coords.items(): + try: + score = float(coords[0]) + parties.append(party) + scores.append(score) + colours.append(PARTY_COLOURS.get(party, "#9E9E9E")) + except Exception: + continue + + if not scores: + st.caption("_Partijdata niet beschikbaar voor deze as._") + return - # Apply flip if needed + # Apply flip if needed (ensures right parties appear on right side) flip = theme.get("flip", False) if flip: scores = [-s for s in scores] - # Get party colors - party_colors = [PARTY_COLOURS.get(p, "#9E9E9E") for p in parties] + # Build hover text + hover = [f"{p}: {s:.3f}" for p, s in zip(parties, scores)] - # Sort by score for better visualization - sorted_indices = np.argsort(scores) - sorted_parties = [parties[i] for i in sorted_indices] - sorted_scores = [scores[i] for i in sorted_indices] - sorted_colors = [party_colors[i] for i in sorted_indices] - - # Create horizontal bar chart + # Create figure with same format as components 1-2 fig = go.Figure() + x_min, x_max = min(scores) * 1.15, max(scores) * 1.15 + if x_min == x_max: + x_min, x_max = x_min - 1, x_max + 1 + # Add horizontal axis line fig.add_trace( - go.Bar( - y=sorted_parties, - x=sorted_scores, - orientation="h", - marker_color=sorted_colors, - text=[f"{s:.2f}" for s in sorted_scores], - textposition="outside", + go.Scatter( + x=[x_min, x_max], + y=[0, 0], + mode="lines", + line={"color": "#cccccc", "width": 1}, + hoverinfo="skip", + showlegend=False, ) ) - # Update layout - label = theme.get("label", f"As {comp_sel}") - positive_pole = theme.get("positive_pole", "Positief") - negative_pole = theme.get("negative_pole", "Negatief") + # Add party markers + fig.add_trace( + go.Scatter( + x=scores, + y=[0] * len(scores), + mode="markers+text", + text=parties, + textposition="top center", + marker={"size": 14, "color": colours}, + hovertext=hover, + hoverinfo="text", + showlegend=False, + ) + ) + + # Determine pole labels based on flip + pos_pole = theme.get("positive_pole", "") + neg_pole = theme.get("negative_pole", "") + left_label = pos_pole if flip else neg_pole + right_label = neg_pole if flip else pos_pole + # Update layout with same format as components 1-2 fig.update_layout( - title=f"Partijposities — {label}", - xaxis_title=f"{negative_pole} ← → {positive_pole}", - yaxis_title="", - height=max(400, len(parties) * 25), - margin=dict(l=150), - showlegend=False, + height=160, + margin={"l": 10, "r": 10, "t": 10, "b": 30}, + xaxis={ + "title": f"← {left_label} | {right_label} →", + "showticklabels": False, + "showline": False, + "showgrid": False, + "zeroline": False, + }, + yaxis={"visible": False, "range": [-1, 2]}, + plot_bgcolor="rgba(0,0,0,0)", + paper_bgcolor="rgba(0,0,0,0)", ) - # Add vertical line at x=0 - fig.add_vline(x=0, line_dash="dash", line_color="gray", opacity=0.5) - st.plotly_chart(fig, use_container_width=True) diff --git a/tests/test_explorer_chart.py b/tests/test_explorer_chart.py index b016db8..9dbe803 100644 --- a/tests/test_explorer_chart.py +++ b/tests/test_explorer_chart.py @@ -281,7 +281,7 @@ def test_partial_party_traces(): def test_render_party_axis_chart_1d_renders(): - """Test that _render_party_axis_chart_1d creates a figure with proper structure.""" + """Test that _render_party_axis_chart_1d creates a scatter plot with markers (same format as components 1-2).""" from unittest.mock import MagicMock, patch from explorer import _render_party_axis_chart_1d @@ -310,8 +310,12 @@ def test_render_party_axis_chart_1d_renders(): # Get the figure passed to plotly_chart fig = mock_plotly_chart.call_args[0][0] assert fig is not None, "Figure should not be None" - # Check that figure has traces (the bar chart) - assert len(fig.data) > 0, "Figure should have traces" + # Check that figure has 2 traces (baseline line + markers) + assert len(fig.data) == 2, "Figure should have 2 traces (baseline + markers)" + # First trace is the baseline line + assert fig.data[0].mode == "lines", "First trace should be a line" + # Second trace is the marker scatter + assert "markers" in fig.data[1].mode, "Second trace should have markers" def test_render_party_axis_chart_1d_empty_coords():