You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
motief/tests/test_scree_data.py

238 lines
7.8 KiB

"""Tests for storing and loading scree plot (explained variance) data."""
import json
import pytest
duckdb = pytest.importorskip("duckdb")
np = pytest.importorskip("numpy")
def _setup_svd_vectors(db_path: str, rows: list) -> None:
    """Create the ``svd_vectors`` table (if missing) and insert synthetic rows.

    Args:
        db_path: Path to the DuckDB database file.
        rows: List of ``(window_id, entity_type, entity_id, vector, model)``
            tuples. ``vector`` is stored as ``json.dumps(vector)``. Passing an
            empty list only creates the sequence and table.
    """
    conn = duckdb.connect(db_path)
    try:
        conn.execute(
            """
            CREATE SEQUENCE IF NOT EXISTS svd_vectors_id_seq START 1;
            CREATE TABLE IF NOT EXISTS svd_vectors (
                id INTEGER DEFAULT nextval('svd_vectors_id_seq'),
                window_id TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                entity_id TEXT NOT NULL,
                vector JSON NOT NULL,
                model TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (id)
            )
            """
        )
        # Skip the insert entirely for "table only" setups rather than relying
        # on executemany accepting an empty parameter list.
        if rows:
            conn.executemany(
                "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model) VALUES (?, ?, ?, ?, ?)",
                [
                    (window_id, entity_type, entity_id, json.dumps(vector), model)
                    for window_id, entity_type, entity_id, vector, model in rows
                ],
            )
    finally:
        # Always release the connection so later connections in the test can
        # open the same file, even if the DDL or an INSERT failed.
        conn.close()
class TestLoadScreeData:
    """Tests for ``analysis.explorer_data.load_scree_data``."""

    def test_load_scree_data_returns_empty_when_no_metadata(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        # An empty row list creates the svd_vectors table with no rows at all,
        # reusing the module's single copy of the DDL.
        _setup_svd_vectors(db_path, [])
        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == []

    def test_load_scree_data_reads_metadata_row(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        _setup_svd_vectors(
            db_path,
            [
                (
                    "current_parliament",
                    "metadata",
                    "explained_variance",
                    [0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01],
                    None,
                ),
            ],
        )
        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == pytest.approx([0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01])

    def test_load_scree_data_ignores_other_windows(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        # Two windows carry explained-variance metadata; only the
        # "current_parliament" one is expected back.
        _setup_svd_vectors(
            db_path,
            [
                (
                    "current_parliament",
                    "metadata",
                    "explained_variance",
                    [0.45, 0.25],
                    None,
                ),
                (
                    "2024",
                    "metadata",
                    "explained_variance",
                    [0.40, 0.30],
                    None,
                ),
            ],
        )
        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == pytest.approx([0.45, 0.25])
# Shared roll-call fixture for the SVD tests: (motion_id, mp_name, vote).
_FIXTURE_MPS = ["MP A", "MP B", "MP C"]
_FIXTURE_VOTES = [
    (0, "MP A", "Voor"),
    (0, "MP B", "Tegen"),
    (0, "MP C", "Voor"),
    (1, "MP A", "Tegen"),
    (1, "MP B", "Voor"),
    (1, "MP C", "Tegen"),
    (2, "MP A", "Voor"),
    (2, "MP B", "Voor"),
    (2, "MP C", "Voor"),
    (3, "MP A", "Tegen"),
    (3, "MP B", "Tegen"),
    (3, "MP C", "Tegen"),
    (4, "MP A", "Voor"),
    (4, "MP B", "Geen stem"),
    (4, "MP C", "Tegen"),
]
# Falls inside the 2024-01-01..2024-12-31 window the tests request.
_FIXTURE_VOTE_DATE = "2024-06-01"


def _seed_minimal_vote_data(db_path: str):
    """Build the project schema and insert a tiny motions/MPs/votes fixture.

    Inserts 5 motions (ids 0-4), the 3 MPs in ``_FIXTURE_MPS`` (all with party
    "Party"), and the 15 votes in ``_FIXTURE_VOTES``, all dated
    ``_FIXTURE_VOTE_DATE``.

    Args:
        db_path: Path to the DuckDB database file.

    Returns:
        The ``MotionDatabase`` instance constructed on ``db_path``.
    """
    from database import MotionDatabase

    # Presumably MotionDatabase creates the motions/mp_metadata/mp_votes
    # tables on construction; the inserts below rely on them existing.
    db = MotionDatabase(db_path)
    conn = duckdb.connect(db_path)
    try:
        conn.executemany(
            "INSERT INTO motions (id, title, policy_area, voting_results) VALUES (?, ?, ?, ?)",
            [(mid, f"Motion {mid}", "Test", "[]") for mid in range(5)],
        )
        conn.executemany(
            "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)",
            [(name, "Party") for name in _FIXTURE_MPS],
        )
        conn.executemany(
            "INSERT INTO mp_votes (motion_id, mp_name, vote, date) VALUES (?, ?, ?, ?)",
            [(mid, mp, vote, _FIXTURE_VOTE_DATE) for mid, mp, vote in _FIXTURE_VOTES],
        )
    finally:
        # Release the file even if an insert fails.
        conn.close()
    return db


class TestComputeSvdForWindow:
    """Tests for ``pipeline.svd_pipeline.compute_svd_for_window``."""

    def test_returns_explained_variance(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        _seed_minimal_vote_data(db_path)
        from pipeline.svd_pipeline import compute_svd_for_window

        result = compute_svd_for_window(
            db_path, "test_window", "2024-01-01", "2024-12-31", k=3
        )
        assert result["k_used"] > 0
        assert "explained_variance" in result
        ev = result["explained_variance"]
        assert isinstance(ev, list)
        assert len(ev) == result["k_used"]
        assert all(isinstance(v, float) for v in ev)
        # Should sum to ~1.0; bound it on both sides so an over-count fails too.
        assert sum(ev) == pytest.approx(1.0, abs=0.01)


class TestPipelineStoresScreeData:
    """Round-trip: compute explained variance, store it, and load it back."""

    def test_run_pipeline_includes_explained_variance_row(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        db = _seed_minimal_vote_data(db_path)
        from pipeline.svd_pipeline import compute_svd_for_window

        result = compute_svd_for_window(
            db_path, "test_window", "2024-01-01", "2024-12-31", k=3
        )
        assert "explained_variance" in result
        ev = result["explained_variance"]
        assert isinstance(ev, list) and len(ev) > 0

        # Exactly one row should carry the scree data; rows appear to be
        # (entity_type, entity_id, vector, ...) tuples.
        rows = result["mp_rows"] + result["motion_rows"]
        metadata_rows = [
            r for r in rows if r[0] == "metadata" and r[1] == "explained_variance"
        ]
        assert len(metadata_rows) == 1
        assert metadata_rows[0][2] == ev

        # Store under "current_parliament" so load_scree_data finds it.
        db.batch_store_svd_vectors("current_parliament", rows)

        from analysis.explorer_data import load_scree_data

        loaded = load_scree_data(db_path)
        assert loaded == pytest.approx(ev)