feat: expose global_mean in compute_2d_axes axes dict

main
Sven Geboers 1 month ago
parent 93a2287c04
commit 6c4dd81723
  1. 1
      analysis/political_axis.py
  2. 139
      tests/test_political_compass.py

@ -360,6 +360,7 @@ def compute_2d_axes(
# project per-window vectors (centre by global mean)
global_mean = M.mean(axis=0)
axes["global_mean"] = global_mean
positions_by_window: Dict[str, Dict[str, Tuple[float, float]]] = {
wid: {} for wid in window_ids
}

@ -1,6 +1,115 @@
import numpy as np
import types
import sys
import types as _types
# Provide a minimal duckdb stub when the real package is not available in the test env
try:
    import duckdb as _duckdb
except Exception:
    # Any failure to import duckdb (not only ImportError) falls back to the
    # in-memory fake below, which is then registered in sys.modules.
    import pandas as _pd

    class FakeDuckDBConnection:
        """In-memory stand-in for a duckdb connection.

        Only the statements recognised in ``execute``/``executemany`` are
        emulated; every other statement yields an empty ``fetchall()``.
        Stored ``mp_votes`` rows are 7-tuples:
        ``(id, motion_id, mp_name, party, vote, date, created_at)``.
        """

        # Column order of the party-discipline result, also used when empty.
        _DISCIPLINE_COLUMNS = ["party", "n_motions", "discipline"]

        def __init__(self):
            # storage for mp_votes rows appended by executemany()
            self._mp_votes = []

        def execute(self, sql, params=None):
            """Emulate ``connection.execute`` for the statements the tests issue.

            Args:
                sql: SQL text; matched by prefix / substring, never parsed.
                params: for the party-discipline query, ``[start_date, end_date]``.

            Returns:
                A namespace exposing ``fetchdf()`` for the party-discipline
                query and ``fetchall()`` (always ``[]``) for everything else.
            """
            s = sql.strip().lower()
            # schema / window queries: nothing stored, so nothing to return
            if s.startswith(("create table", "select distinct window_id")):
                return _types.SimpleNamespace(fetchall=lambda: [])
            # compute_party_discipline query detection
            if (
                "from rice_per_motion" in s
                or "select\n party,\n count(distinct motion_id) as n_motions" in sql
            ):
                start_date, end_date = params or [None, None]
                df = self._party_discipline(start_date, end_date)
                return _types.SimpleNamespace(fetchdf=lambda: df)
            # default fallback
            return _types.SimpleNamespace(fetchall=lambda: [])

        def _party_discipline(self, start_date, end_date):
            """Aggregate stored votes into one row per party.

            Keeps rows whose mp_name contains a comma, whose date lies in
            ``[start_date, end_date]`` (only when ``start_date`` is truthy)
            and whose vote is 'voor' or 'tegen'. For each (motion, party)
            the score is the share of votes cast for that party's most
            common vote; ``discipline`` is the mean score over the party's
            motions and ``n_motions`` the number of such motions.
            """
            from collections import defaultdict

            rows = [r for r in self._mp_votes if "," in (r[2] or "")]
            if start_date:
                rows = [r for r in rows if start_date <= r[5] <= end_date]
            rows = [r for r in rows if r[4] in ("voor", "tegen")]

            # counts[motion_id][party][vote] -> number of matching rows
            counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
            for _id, motion_id, _mp_name, party, vote, _date, _created_at in rows:
                counts[motion_id][party][vote] += 1

            # Every (motion, party) entry holds at least one vote, so the
            # total is never zero. Ties between votes do not change the
            # score: only the size of the largest group matters.
            scores_by_party = defaultdict(list)
            for party_counts in counts.values():
                for party, vote_counts in party_counts.items():
                    total = sum(vote_counts.values())
                    scores_by_party[party].append(max(vote_counts.values()) / total)

            records = [
                {
                    "party": party,
                    "n_motions": len(scores),
                    "discipline": sum(scores) / len(scores),
                }
                for party, scores in scores_by_party.items()
            ]
            # Explicit columns keep the frame's schema stable even when no
            # row survives the filters.
            return _pd.DataFrame(records, columns=self._DISCIPLINE_COLUMNS)

        def executemany(self, sql, rows):
            """Store ``rows`` when ``sql`` inserts into mp_votes; ignore anything else."""
            if sql.strip().lower().startswith("insert into mp_votes"):
                self._mp_votes.extend(rows)

        def close(self):
            """No-op; present so callers can close the connection."""
            return None

    # Register the fake so later `import duckdb` statements resolve to it.
    # Each connect() call returns a fresh connection with its own storage.
    _fake_duckdb = _types.ModuleType("duckdb")
    _fake_duckdb.connect = lambda *a, **kw: FakeDuckDBConnection()
    sys.modules["duckdb"] = _fake_duckdb
    _duckdb = _fake_duckdb
# Provide a minimal plotly stub so explorer imports in tests without requiring plotly
try:
    import plotly.express as px  # type: ignore
except Exception:
    # Any failure to import plotly.express falls back to inert stand-ins:
    # every stubbed callable accepts arbitrary arguments and returns None.
    _px = types.ModuleType("plotly.express")
    _px.scatter = lambda *a, **kw: None
    _px.line = lambda *a, **kw: None

    # stub plotly.graph_objects too
    _go = types.ModuleType("plotly.graph_objects")
    _go.Figure = lambda *a, **kw: None

    # Ensure top-level 'plotly' package exists and exposes both submodules
    # as attributes, so `plotly.express` / `plotly.graph_objects` attribute
    # access works as well as `import plotly.<submodule>`.
    _plotly_pkg = types.ModuleType("plotly")
    _plotly_pkg.express = _px
    _plotly_pkg.graph_objects = _go

    sys.modules["plotly"] = _plotly_pkg
    sys.modules["plotly.express"] = _px
    sys.modules["plotly.graph_objects"] = _go
    px = _px
import pytest
@ -469,3 +578,33 @@ def test_axis_classifier_missing_csv(tmp_path, monkeypatch):
# Must not crash and must return the original axes dict unchanged
assert result is axes
assert "x_label" not in result
def test_compute_2d_axes_exposes_global_mean(monkeypatch):
    """compute_2d_axes must return an axes dict holding 'global_mean' as an ndarray.

    analysis.trajectory and duckdb are replaced in sys.modules with
    lightweight stand-ins before analysis.political_axis is imported, so
    the test serves a single window ("w1") with two fixed 3-d vectors.
    """
    vectors_by_window = {
        "w1": {
            "Alice": np.array([1.0, 0.0, 0.0]),
            "Bob": np.array([-1.0, 0.5, 0.0]),
        }
    }

    # Stand-in for analysis.trajectory: one window, pre-"aligned" vectors.
    trajectory_stub = types.SimpleNamespace()
    trajectory_stub._load_window_ids = lambda db: ["w1"]
    trajectory_stub._load_mp_vectors_for_window = (
        lambda db, w: vectors_by_window.get(w, {})
    )
    trajectory_stub._procrustes_align_windows = lambda x: vectors_by_window
    monkeypatch.setitem(sys.modules, "analysis.trajectory", trajectory_stub)

    # Stand-in for duckdb: every query yields an empty fetchall() result.
    connection_stub = types.SimpleNamespace(
        execute=lambda q: types.SimpleNamespace(fetchall=lambda: []),
        close=lambda: None,
    )
    duckdb_stub = types.SimpleNamespace(
        connect=lambda db_path, **kw: connection_stub
    )
    monkeypatch.setitem(sys.modules, "duckdb", duckdb_stub)

    # Import only after both stubs are installed.
    from analysis.political_axis import compute_2d_axes

    _, axis_def = compute_2d_axes(db_path="dummy", window_ids=["w1"], method="pca")

    assert "global_mean" in axis_def
    assert isinstance(axis_def["global_mean"], np.ndarray)

Loading…
Cancel
Save