You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_analysis.py

195 lines
6.7 KiB

"""Tests for analysis modules: political_axis, trajectory, clustering."""
import json
import numpy as np
import pytest
duckdb = pytest.importorskip("duckdb")
# ── Helpers ──────────────────────────────────────────────────────────────────
def _setup_svd_vectors(db_path: str, window_ids_mp_vecs: dict):
    """Insert synthetic MP SVD vectors into the ``svd_vectors`` table.

    The table is created first if it does not exist. Every vector is stored
    as a JSON-encoded list, with ``entity_type`` fixed to ``'mp'`` and
    ``model`` fixed to ``'test'``. The ``id`` column is not populated.

    Args:
        db_path: Path handed to ``duckdb.connect``.
        window_ids_mp_vecs: ``{window_id: {mp_name: np.ndarray}}``.
    """
    conn = duckdb.connect(db_path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS svd_vectors (
                id INTEGER,
                window_id TEXT,
                entity_type TEXT,
                entity_id TEXT,
                vector JSON,
                model TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
        for wid, mp_vecs in window_ids_mp_vecs.items():
            for mp_name, vec in mp_vecs.items():
                conn.execute(
                    "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model) VALUES (?, 'mp', ?, ?, 'test')",
                    (wid, mp_name, json.dumps(vec.tolist())),
                )
    finally:
        # Close even when a statement raises, so a failed setup does not
        # leave the connection to db_path open for the rest of the test.
        conn.close()
def _setup_mp_metadata(db_path: str, mp_party: dict):
    """Insert synthetic MP metadata rows into the ``mp_metadata`` table.

    The table is created first if it does not exist. Only ``mp_name`` and
    ``party`` are written; ``van``, ``tot_en_met`` and ``persoon_id`` are
    left NULL.

    Args:
        db_path: Path handed to ``duckdb.connect``.
        mp_party: ``{mp_name: party}`` mapping, one row per entry.
    """
    conn = duckdb.connect(db_path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS mp_metadata (
                mp_name TEXT,
                party TEXT,
                van DATE,
                tot_en_met DATE,
                persoon_id TEXT
            )
            """
        )
        for mp_name, party in mp_party.items():
            conn.execute(
                "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)",
                (mp_name, party),
            )
    finally:
        # Close even when a statement raises, so a failed setup does not
        # leave the connection to db_path open for the rest of the test.
        conn.close()
# ── political_axis ────────────────────────────────────────────────────────────
class TestPoliticalAxis:
    """Checks for ``compute_pca_axis`` and ``compute_anchor_axis``."""

    def test_pca_axis_basic(self, tmp_path):
        """PCA scoring yields one float per MP and the scores are not all equal."""
        np.random.seed(42)
        database = str(tmp_path / "test.db")
        mp_count, dims = 20, 5
        # Seeded standard-normal vectors, one row per MP.
        matrix = np.random.randn(mp_count, dims)
        names = [f"MP_{idx}" for idx in range(mp_count)]
        window = dict(zip(names, matrix))
        _setup_svd_vectors(database, {"2024-Q1": window})

        from analysis.political_axis import compute_pca_axis

        scores = compute_pca_axis(database, "2024-Q1")

        assert len(scores) == mp_count
        for value in scores.values():
            assert isinstance(value, float)
        # The scores must show some spread across MPs.
        spread = np.std(list(scores.values()))
        assert spread > 0.0

    def test_pca_axis_too_few_mps(self, tmp_path):
        """A window holding a single MP produces an empty score mapping."""
        database = str(tmp_path / "test.db")
        lone_mp = {"MP_A": np.array([1.0, 0.0])}
        _setup_svd_vectors(database, {"w1": lone_mp})

        from analysis.political_axis import compute_pca_axis

        assert compute_pca_axis(database, "w1") == {}

    def test_anchor_axis_basic(self, tmp_path):
        """Anchor scoring ranks each SP-labelled MP below its VVD counterpart."""
        database = str(tmp_path / "test.db")
        # Two groups placed at -2 and +2 on dimension 0, each member nudged
        # by 0.1 either way, plus one MP at the origin.
        nudge = np.array([0.1, 0.0, 0.0])
        left_centre = np.array([-2.0, 0.0, 0.0])
        right_centre = np.array([2.0, 0.0, 0.0])
        vectors = {
            "Left_A": left_centre + nudge,
            "Left_B": left_centre - nudge,
            "Right_A": right_centre + nudge,
            "Right_B": right_centre - nudge,
            "Centre": np.zeros(3),
        }
        parties = {
            "Left_A": "SP",
            "Left_B": "SP",
            "Right_A": "VVD",
            "Right_B": "VVD",
            "Centre": "D66",
        }
        _setup_svd_vectors(database, {"w1": vectors})
        _setup_mp_metadata(database, parties)

        from analysis.political_axis import compute_anchor_axis

        scores = compute_anchor_axis(
            database, "w1", left_parties=["SP"], right_parties=["VVD"]
        )

        assert len(scores) == 5
        # Only the relative order is checked: left member below right member.
        for suffix in ("A", "B"):
            assert scores[f"Left_{suffix}"] < scores[f"Right_{suffix}"]
# ── trajectory ───────────────────────────────────────────────────────────────
class TestTrajectory:
    """Checks for ``compute_trajectories`` and ``top_drifters``."""

    def test_basic_trajectory(self, tmp_path):
        """MPs present in both windows get a trajectory; the one that moved ranks first."""
        np.random.seed(0)
        database = str(tmp_path / "test.db")
        first_window = {
            "MP_A": np.array([1.0, 0.0]),
            "MP_B": np.array([0.0, 1.0]),
        }
        second_window = {
            "MP_A": np.array([1.5, 0.5]),
            "MP_B": np.array([0.0, 1.0]),
            "MP_C": np.array([2.0, 2.0]),
        }
        windows = {"2024-Q1": first_window, "2024-Q2": second_window}
        _setup_svd_vectors(database, windows)

        from analysis.trajectory import compute_trajectories, top_drifters

        trajectories = compute_trajectories(database)

        # MP_A and MP_B appear in both windows; MP_C only in the second one.
        for present in ("MP_A", "MP_B"):
            assert present in trajectories
        assert "MP_C" not in trajectories

        moved = trajectories["MP_A"]
        assert len(moved["drift"]) == 1
        assert moved["total_drift"] > 0.0

        # MP_B has the same vector in both windows, so no drift is expected.
        assert trajectories["MP_B"]["total_drift"] == pytest.approx(0.0)

        ranking = top_drifters(trajectories, n=5)
        assert ranking[0]["mp_name"] == "MP_A"

    def test_fewer_than_2_windows(self, tmp_path):
        """With only one window stored, no trajectories are returned."""
        database = str(tmp_path / "test.db")
        single_window = {"2024-Q1": {"MP_A": np.array([1.0, 2.0])}}
        _setup_svd_vectors(database, single_window)

        from analysis.trajectory import compute_trajectories

        assert compute_trajectories(database) == {}
# ── clustering ────────────────────────────────────────────────────────────────
class TestClustering:
    """Checks for ``analysis.clustering.cluster_kmeans``."""

    def test_cluster_kmeans_basic(self):
        """Twenty 2-D points with three clusters give 20 labels drawn from {0, 1, 2}."""
        from analysis.clustering import cluster_kmeans

        # Seed the global RNG, as the other tests in this module do, so the
        # input is identical on every run and a failure can be reproduced.
        np.random.seed(0)
        coords = np.random.randn(20, 2)
        labels = cluster_kmeans(coords, n_clusters=3)
        assert len(labels) == 20
        assert set(labels).issubset({0, 1, 2})

    def test_cluster_kmeans_fewer_points_than_clusters(self):
        """Asking for more clusters than points still returns one label per point."""
        from analysis.clustering import cluster_kmeans

        coords = np.array([[0.0, 0.0], [1.0, 1.0]])
        labels = cluster_kmeans(coords, n_clusters=5)
        # Must not raise; presumably n_clusters is clamped to len(coords)
        # inside cluster_kmeans (not visible from this file).
        assert len(labels) == 2