diff --git a/explorer.py b/explorer.py index ea10452..f312222 100644 --- a/explorer.py +++ b/explorer.py @@ -189,16 +189,21 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: Returns: {party_name: [float * k]} — k = 50 for the canonical 2025 window """ + con = None try: + # Use a deterministic, ordered list for parameter binding + party_list = sorted(CURRENT_PARLIAMENT_PARTIES) + if not party_list: + return {} + con = duckdb.connect(database=db_path, read_only=True) - placeholders = ", ".join("?" for _ in CURRENT_PARLIAMENT_PARTIES) + placeholders = ", ".join("?" for _ in party_list) rows = con.execute( f"SELECT entity_id, vector FROM svd_vectors " f"WHERE entity_type='mp' AND window_id='2025' " f"AND entity_id IN ({placeholders})", - list(CURRENT_PARLIAMENT_PARTIES), + party_list, ).fetchall() - con.close() return { row[0]: json.loads(row[1]) if isinstance(row[1], str) else list(row[1]) for row in rows @@ -206,6 +211,9 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: except Exception: logger.exception("Failed to load party axis scores") return {} + finally: + if con is not None: + con.close() @st.cache_data(show_spinner="Moties laden…") @@ -231,6 +239,97 @@ def load_motions_df(db_path: str) -> pd.DataFrame: con.close() +def _render_party_axis_chart( + party_scores: Dict[str, List[float]], comp_sel: int +) -> None: + """Render a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`. + + party_scores: mapping party -> list-like vector (50-dim) + comp_sel: 1-based component index + """ + # Validate component selection + if not isinstance(comp_sel, int) or comp_sel < 1: + st.caption("_Ongeldige SVD-as geselecteerd._") + return + + if not party_scores: + st.caption("_Partijdata niet beschikbaar_") + return + + axis_idx = comp_sel - 1 + + parties: List[str] = [] + xs: List[float] = [] + + for party, vec in party_scores.items(): + # Ensure vec is indexable/sequence-like + if not isinstance(vec, (list, tuple, np.ndarray)): + # skip malformed entries + continue + try: + raw = vec[axis_idx] + # Convert to float safely + val = float(raw) + except Exception: + # skip entries that cannot be indexed or converted + continue + parties.append(party) + xs.append(val) + + if not xs: + st.caption("_Partijdata niet beschikbaar_") + return + + try: + x_min = min(xs) + x_max = max(xs) + except Exception: + st.caption("_Onvoldoende gegevens om asbereik te berekenen_") + return + + # If min == max, apply symmetric padding around the value. + if x_min == x_max: + padding = 0.5 if x_min == 0 else abs(x_min) * 0.1 + if padding <= 0: + padding = 0.5 + x_min = x_min - padding + x_max = x_max + padding + else: + # Expand range slightly for visual padding + x_min = x_min * 1.15 + x_max = x_max * 1.15 + + # Build horizontal scatter: y is constant (0) but offset for label placement + ys = [0 for _ in xs] + + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=xs, + y=ys, + mode="markers+text", + text=parties, + textposition="top center", + marker=dict( + size=10, color=[PARTY_COLOURS.get(p, "#9E9E9E") for p in parties] + ), + hovertemplate="%{text}
x: %{x:.3f}", + ) + ) + + fig.update_layout( + title=f"Partijposities op SVD-as {comp_sel}", + xaxis_title="Negatief ← — → Positief", + yaxis=dict(visible=False), + xaxis=dict(range=[x_min, x_max]), + height=300, + margin=dict(t=40, b=40, l=40, r=40), + showlegend=False, + ) + + st.plotly_chart(fig, use_container_width=True) + + def query_similar( db_path: str, source_motion_id: int,