diff --git a/explorer.py b/explorer.py index 3533b9d..4864465 100644 --- a/explorer.py +++ b/explorer.py @@ -245,6 +245,81 @@ def load_active_mps(db_path: str) -> set: return set() +def compute_party_discipline( + db_path: str, + start_date: str, + end_date: str, +) -> pd.DataFrame: + """Compute per-party voting discipline (Rice index) for roll-call votes in a date range. + + Only individual MP vote rows are used (mp_name LIKE '%,%'). + Returns a DataFrame with columns [party, n_motions, discipline] sorted by discipline ascending. + Returns an empty DataFrame if fewer than 1 qualifying motion exists or on any DB error. + + Rice index per motion per party = fraction of party MPs voting with the party majority. + The per-party score is the average Rice index across all motions in the date range. + """ + try: + conn = duckdb.connect(db_path, read_only=True) + result = conn.execute( + """ + WITH individual_votes AS ( + SELECT + motion_id, + party, + LOWER(vote) AS vote + FROM mp_votes + WHERE mp_name LIKE '%,%' + AND date >= CAST(? AS DATE) + AND date <= CAST(? AS DATE) + AND vote IN ('voor', 'tegen', 'afwezig', 'onthouden') + ), + vote_counts AS ( + SELECT + motion_id, + party, + vote, + COUNT(*) AS cnt + FROM individual_votes + GROUP BY motion_id, party, vote + ), + majority_vote AS ( + SELECT + motion_id, + party, + FIRST(vote ORDER BY cnt DESC, vote ASC) AS maj_vote, + SUM(cnt) AS total_mp_votes + FROM vote_counts + GROUP BY motion_id, party + ), + rice_per_motion AS ( + SELECT + mv.motion_id, + mv.party, + SUM(CASE WHEN vc.vote = mv.maj_vote THEN vc.cnt ELSE 0 END) + * 1.0 / mv.total_mp_votes AS rice + FROM majority_vote mv + JOIN vote_counts vc + ON mv.motion_id = vc.motion_id AND mv.party = vc.party + GROUP BY mv.motion_id, mv.party, mv.total_mp_votes + ) + SELECT + party, + COUNT(DISTINCT motion_id) AS n_motions, + AVG(rice) AS discipline + FROM rice_per_motion + GROUP BY party + ORDER BY discipline ASC + """, + [start_date, end_date], + ).fetchdf() + conn.close() + return result + except Exception as exc: + logger.warning("compute_party_discipline failed: %s", exc) + return pd.DataFrame(columns=["party", "n_motions", "discipline"]) + + @st.cache_data(show_spinner="Partijposities op SVD-assen laden…") def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: """Return per-party SVD vectors, computed as mean of individual MP vectors. @@ -704,6 +779,26 @@ def _add_y_direction_annotations(fig: go.Figure) -> None: fig.add_annotation(**common, y=-0.06, text="▼ Conservatief", xanchor="center") +def _window_to_dates(window_id: str) -> tuple[str, str]: + """Return (start_date, end_date) ISO strings for a given window_id. + + Annual windows like '2024' → ('2024-01-01', '2024-12-31'). + 'current_parliament' → ('2023-11-22', '2099-12-31') (2023 formation date, open end). + Unknown formats → ('2000-01-01', '2099-12-31') (effectively all time). + """ + if window_id == "current_parliament": + return ("2023-11-22", "2099-12-31") + if re.fullmatch(r"\d{4}", window_id): + return (f"{window_id}-01-01", f"{window_id}-12-31") + m = re.fullmatch(r"(\d{4})-Q([1-4])", window_id) + if m: + year, q = int(m.group(1)), int(m.group(2)) + starts = {1: "01-01", 2: "04-01", 3: "07-01", 4: "10-01"} + ends = {1: "03-31", 2: "06-30", 3: "09-30", 4: "12-31"} + return (f"{year}-{starts[q]}", f"{year}-{ends[q]}") + return ("2000-01-01", "2099-12-31") + + def build_compass_tab(db_path: str, window_size: str) -> None: st.subheader("Politiek Kompas") st.markdown( @@ -854,6 +949,81 @@ def build_compass_tab(db_path: str, window_size: str) -> None: with col1: st.plotly_chart(fig, use_container_width=True) + # --- Voting discipline section --- + _MIN_MOTIONS_FOR_DISCIPLINE = 5 + start_date, end_date = _window_to_dates(window_idx) + disc_df = compute_party_discipline(db_path, start_date, end_date) + + st.subheader("Stemgedrag cohesie") + if disc_df.empty or disc_df["n_motions"].max() < _MIN_MOTIONS_FOR_DISCIPLINE: + st.caption( + "Te weinig hoofdelijke stemmingen in dit venster voor een cohesieanalyse." + ) + else: + compass_parties = set(df_pos["party"].unique()) + disc_df = disc_df[disc_df["party"].isin(compass_parties)].copy() + + if disc_df.empty: + st.caption("Geen overlappende partijen tussen kompas en stemmingsdata.") + else: + disc_df["discipline_pct"] = (disc_df["discipline"] * 100).round(1) + disc_df["party_label"] = disc_df.apply( + lambda r: f"{r['party']} ({int(r['n_motions'])} moties)", axis=1 + ) + + bar_fig = px.bar( + disc_df.sort_values("discipline"), + x="discipline_pct", + y="party_label", + orientation="h", + color="discipline_pct", + color_continuous_scale="RdYlGn", + range_color=[80, 100], + labels={"discipline_pct": "Cohesie (%)", "party_label": "Partij"}, + title="Cohesie bij hoofdelijke stemmingen", + ) + bar_fig.update_layout( + height=max(300, len(disc_df) * 35 + 80), + showlegend=False, + coloraxis_showscale=False, + yaxis_title="", + ) + st.plotly_chart(bar_fig, use_container_width=True) + + top3 = disc_df.nlargest(3, "discipline")[ + ["party", "discipline_pct", "n_motions"] + ] + bot3 = disc_df.nsmallest(3, "discipline")[ + ["party", "discipline_pct", "n_motions"] + ] + col_a, col_b = st.columns(2) + with col_a: + st.markdown("**Meest eensgezind**") + st.dataframe( + top3.rename( + columns={ + "party": "Partij", + "discipline_pct": "Cohesie (%)", + "n_motions": "Moties", + } + ), + hide_index=True, + use_container_width=True, + ) + with col_b: + st.markdown("**Meest verdeeld**") + st.dataframe( + bot3.rename( + columns={ + "party": "Partij", + "discipline_pct": "Cohesie (%)", + "n_motions": "Moties", + } + ), + hide_index=True, + use_container_width=True, + ) + # --------------------------------------------------------------------------- # Tab 2: Partij Trajectories diff --git a/tests/test_political_compass.py b/tests/test_political_compass.py index d5c3f28..c97d1d0 100644 --- a/tests/test_political_compass.py +++ b/tests/test_political_compass.py @@ -237,3 +237,131 @@ def test_pca_axis_orientation(monkeypatch): assert prog_y > cons_y, ( f"Expected progressive parties (y={prog_y:.3f}) > conservative parties (y={cons_y:.3f}) on Y-axis" ) + + +# --------------------------------------------------------------------------- +# Tests for compute_party_discipline +# --------------------------------------------------------------------------- + + +def _make_mp_votes_db(): + """Create an in-memory DuckDB with mp_votes fixture data. + + 6 motions, 2 parties (SP, VVD), each with 4 MPs. + SP is perfectly disciplined (all 4 vote the same each time). + VVD has 1 dissident on 2 of 6 motions → Rice index = (4+4+4+4+3+3)/6/4 ≈ 0.917. + Dates span 2023-01-01 to 2023-12-31. + """ + import duckdb + + conn = duckdb.connect(":memory:") + conn.execute(""" + CREATE TABLE mp_votes ( + id INTEGER, + motion_id VARCHAR, + mp_name VARCHAR, + party VARCHAR, + vote VARCHAR, + date DATE, + created_at TIMESTAMP + ) + """) + rows = [] + dates = [ + "2023-01-10", + "2023-03-15", + "2023-05-20", + "2023-07-25", + "2023-09-30", + "2023-11-05", + ] + sp_mps = ["Janssen, A.", "Pietersen, B.", "Willemsen, C.", "Hendriksen, D."] + vvd_mps = ["Adams, E.", "Bakker, F.", "Claassen, G.", "Dekker, H."] + for i, date in enumerate(dates, start=1): + m_id = f"M{i:03d}" + for mp in sp_mps: + rows.append((i * 10 + 1, m_id, mp, "SP", "voor", date, "2023-01-01")) + if i <= 4: + for mp in vvd_mps: + rows.append((i * 10 + 2, m_id, mp, "VVD", "voor", date, "2023-01-01")) + else: + for mp in vvd_mps[:3]: + rows.append((i * 10 + 2, m_id, mp, "VVD", "voor", date, "2023-01-01")) + rows.append( + (i * 10 + 3, m_id, vvd_mps[3], "VVD", "tegen", date, "2023-01-01") + ) + conn.executemany("INSERT INTO mp_votes VALUES (?, ?, ?, ?, ?, ?, ?)", rows) + return conn + + +def test_compute_party_discipline_basic(monkeypatch): + """compute_party_discipline returns correct Rice index for fixture data.""" + import duckdb as _duckdb + + fixture_conn = _make_mp_votes_db() + + monkeypatch.setattr(_duckdb, "connect", lambda path, **kw: fixture_conn) + + import importlib + import sys + + if "streamlit" not in sys.modules: + import types + + st_stub = types.ModuleType("streamlit") + st_stub.cache_data = lambda **kw: lambda f: f + sys.modules["streamlit"] = st_stub + + import explorer as _explorer + + importlib.reload(_explorer) + + df = _explorer.compute_party_discipline( + db_path="dummy", + start_date="2023-01-01", + end_date="2023-12-31", + ) + + assert not df.empty + assert set(df.columns) >= {"party", "n_motions", "discipline"} + + sp_row = df[df["party"] == "SP"].iloc[0] + vvd_row = df[df["party"] == "VVD"].iloc[0] + + assert sp_row["n_motions"] == 6 + assert sp_row["discipline"] == pytest.approx(1.0, abs=1e-6) + + assert vvd_row["n_motions"] == 6 + expected_vvd = (4 * 1.0 + 2 * 0.75) / 6 + assert vvd_row["discipline"] == pytest.approx(expected_vvd, abs=1e-4) + + assert (df["discipline"] >= 0).all() and (df["discipline"] <= 1).all() + + +def test_compute_party_discipline_empty_range(monkeypatch): + """Returns empty DataFrame when no motions fall in the date range.""" + import duckdb as _duckdb + + fixture_conn = _make_mp_votes_db() + monkeypatch.setattr(_duckdb, "connect", lambda path, **kw: fixture_conn) + + import importlib, sys + + if "streamlit" not in sys.modules: + import types + + st_stub = types.ModuleType("streamlit") + st_stub.cache_data = lambda **kw: lambda f: f + sys.modules["streamlit"] = st_stub + + import explorer as _explorer + + importlib.reload(_explorer) + + df = _explorer.compute_party_discipline( + db_path="dummy", + start_date="2000-01-01", + end_date="2000-12-31", + ) + + assert df.empty