Patch summary (reviewer preamble; everything above the first "diff --git"
header is commentary and is not part of any hunk).

analysis/political_axis.py
    compute_2d_axes: the value of M.mean(axis=0) is additionally stored in the
    axes dict under the key "global_mean".

tests/test_political_compass.py
    * When "import duckdb" fails, a stub module named "duckdb" is placed in
      sys.modules. Its connect() returns a FakeDuckDBConnection that keeps the
      rows passed to executemany("insert into mp_votes ...") in memory and
      answers the party-discipline query with a pandas DataFrame built from
      one dict per party (keys: party, n_motions, discipline).
    * When "import plotly.express" fails, stub modules "plotly",
      "plotly.express" and "plotly.graph_objects" are placed in sys.modules.
    * New test test_compute_2d_axes_exposes_global_mean asserts that the
      second value returned by compute_2d_axes contains a numpy array under
      "global_mean".

NOTE(review): the line structure of this patch was reconstructed from a
    whitespace-collapsed copy. Hunk line counts match the @@ headers, but the
    indentation of context lines and the placement of blank context lines
    around "import pytest" are assumptions -- confirm with "git apply --check".
NOTE(review): in the duckdb stub, "_pd", "motions" and "motion_part_set" are
    assigned but never read; consider dropping them in a follow-up.
NOTE(review): both optional-import guards catch Exception; ImportError is
    presumably sufficient -- confirm before narrowing.
NOTE(review): "types" is already imported at the top of the test module, so
    the "_types" alias (also re-imported inside the new test) looks redundant.
NOTE(review): pd.DataFrame(rows_out) has no columns when no rows match;
    confirm the code under test tolerates that.

diff --git a/analysis/political_axis.py b/analysis/political_axis.py
index 85a10d4..3adfd14 100644
--- a/analysis/political_axis.py
+++ b/analysis/political_axis.py
@@ -360,6 +360,7 @@ def compute_2d_axes(
 
     # project per-window vectors (centre by global mean)
     global_mean = M.mean(axis=0)
+    axes["global_mean"] = global_mean
     positions_by_window: Dict[str, Dict[str, Tuple[float, float]]] = {
         wid: {} for wid in window_ids
     }
diff --git a/tests/test_political_compass.py b/tests/test_political_compass.py
index 9bc432b..21b18d2 100644
--- a/tests/test_political_compass.py
+++ b/tests/test_political_compass.py
@@ -1,6 +1,115 @@
 import numpy as np
 import types
 import sys
+import types as _types
+
+# Provide a minimal duckdb stub when the real package is not available in the test env
+try:
+    import duckdb as _duckdb
+except Exception:
+    import pandas as _pd
+
+    class FakeDuckDBConnection:
+        def __init__(self):
+            # storage for mp_votes rows: list of tuples matching _make_mp_votes_db
+            self._mp_votes = []
+
+        def execute(self, sql, params=None):
+            s = sql.strip().lower()
+            # simple create/select handling: return empty results for schema queries
+            if s.startswith("create table") or s.startswith(
+                "select distinct window_id"
+            ):
+                return _types.SimpleNamespace(fetchall=lambda: [])
+
+            # compute_party_discipline query detection
+            if (
+                "from rice_per_motion" in s
+                or "select\n party,\n count(distinct motion_id) as n_motions"
+                in sql
+            ):
+                # params: [start_date, end_date]
+                start_date, end_date = params or [None, None]
+                # filter rows by mp_name like '%,%' and date range and vote in ('voor','tegen')
+                rows = [r for r in self._mp_votes if ("," in (r[2] or ""))]
+                if start_date:
+                    rows = [r for r in rows if r[5] >= start_date and r[5] <= end_date]
+                rows = [r for r in rows if (r[4] in ("voor", "tegen"))]
+
+                # build counts per motion_id, party, vote
+                from collections import defaultdict
+
+                counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
+                motions = set()
+                for _id, motion_id, mp_name, party, vote, date, created_at in rows:
+                    counts[motion_id][party][vote] += 1
+                    motions.add((motion_id, party))
+
+                # compute rice per (motion, party)
+                rice_vals = defaultdict(list)  # party -> list of rice per motion
+                motion_part_set = set()
+                for motion_id, party_counts in counts.items():
+                    for party, vc in party_counts.items():
+                        total = sum(vc.values())
+                        if total == 0:
+                            continue
+                        # majority vote: vote with max count, tie-breaker by vote asc
+                        maj_vote = sorted(vc.items(), key=lambda kv: (-kv[1], kv[0]))[
+                            0
+                        ][0]
+                        same = vc.get(maj_vote, 0)
+                        rice = same / float(total)
+                        rice_vals[party].append((motion_id, rice))
+                        motion_part_set.add((motion_id, party))
+
+                # aggregate per party
+                import pandas as pd
+
+                rows_out = []
+                for party, lst in rice_vals.items():
+                    n_motions = len({m for m, _ in lst})
+                    avg_rice = sum(r for _, r in lst) / n_motions if n_motions else 0.0
+                    rows_out.append(
+                        {"party": party, "n_motions": n_motions, "discipline": avg_rice}
+                    )
+
+                df = pd.DataFrame(rows_out)
+                return _types.SimpleNamespace(fetchdf=lambda: df)
+
+            # default fallback
+            return _types.SimpleNamespace(fetchall=lambda: [])
+
+        def executemany(self, sql, rows):
+            s = sql.strip().lower()
+            if s.startswith("insert into mp_votes"):
+                for r in rows:
+                    self._mp_votes.append(r)
+
+        def close(self):
+            return None
+
+    _fake_duckdb = _types.ModuleType("duckdb")
+    _fake_duckdb.connect = lambda *a, **kw: FakeDuckDBConnection()
+    sys.modules["duckdb"] = _fake_duckdb
+    _duckdb = _fake_duckdb
+
+# Provide a minimal plotly.express stub so explorer imports in tests without requiring plotly
+try:
+    import plotly.express as px  # type: ignore
+except Exception:
+    _px = types.ModuleType("plotly.express")
+    _px.scatter = lambda *a, **kw: None
+    _px.line = lambda *a, **kw: None
+    # Ensure top-level 'plotly' package exists and exposes express
+    _plotly_pkg = types.ModuleType("plotly")
+    _plotly_pkg.express = _px
+    sys.modules["plotly"] = _plotly_pkg
+    sys.modules["plotly.express"] = _px
+    px = _px
+    # stub plotly.graph_objects too
+    _go = types.ModuleType("plotly.graph_objects")
+    _go.Figure = lambda *a, **kw: None
+    sys.modules["plotly.graph_objects"] = _go
 
 import pytest
 
@@ -469,3 +578,33 @@ def test_axis_classifier_missing_csv(tmp_path, monkeypatch):
     # Must not crash and must return the original axes dict unchanged
     assert result is axes
     assert "x_label" not in result
+
+
+def test_compute_2d_axes_exposes_global_mean(monkeypatch):
+    """axes dict returned by compute_2d_axes must contain 'global_mean'."""
+    fake_traj = types.SimpleNamespace()
+    fake_traj._load_window_ids = lambda db: ["w1"]
+    aligned = {
+        "w1": {
+            "Alice": np.array([1.0, 0.0, 0.0]),
+            "Bob": np.array([-1.0, 0.5, 0.0]),
+        }
+    }
+    fake_traj._load_mp_vectors_for_window = lambda db, w: aligned.get(w, {})
+    fake_traj._procrustes_align_windows = lambda x: aligned
+    monkeypatch.setitem(sys.modules, "analysis.trajectory", fake_traj)
+    # Provide a minimal duckdb stub so importing analysis.political_axis succeeds
+    import types as _types
+
+    fake_conn = _types.SimpleNamespace(
+        execute=lambda q: _types.SimpleNamespace(fetchall=lambda: []),
+        close=lambda: None,
+    )
+    fake_duckdb = _types.SimpleNamespace(connect=lambda db_path, **kw: fake_conn)
+    monkeypatch.setitem(sys.modules, "duckdb", fake_duckdb)
+
+    from analysis.political_axis import compute_2d_axes
+
+    _, axis_def = compute_2d_axes(db_path="dummy", window_ids=["w1"], method="pca")
+    assert "global_mean" in axis_def
+    assert isinstance(axis_def["global_mean"], np.ndarray)