You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
238 lines
7.8 KiB
238 lines
7.8 KiB
"""Tests for storing and loading scree plot (explained variance) data."""
|
|
|
|
import json
|
|
|
|
import pytest
|
|
|
|
duckdb = pytest.importorskip("duckdb")
|
|
np = pytest.importorskip("numpy")
|
|
|
|
|
|
def _setup_svd_vectors(db_path: str, rows: list):
    """Create the ``svd_vectors`` table if missing and insert synthetic rows.

    The id sequence and table are created with ``IF NOT EXISTS``. The helper
    can therefore be called with an empty ``rows`` list just to create the
    schema, or called again on a database that already has it.

    Args:
        db_path: Path to the DuckDB database file.
        rows: List of ``(window_id, entity_type, entity_id, vector, model)``
            tuples. ``vector`` is stored via ``json.dumps`` (e.g. a list of
            floats). ``model`` may be ``None``.
    """
    conn = duckdb.connect(db_path)
    # try/finally so the connection is closed even if a statement raises.
    try:
        conn.execute(
            """
            CREATE SEQUENCE IF NOT EXISTS svd_vectors_id_seq START 1;
            CREATE TABLE IF NOT EXISTS svd_vectors (
                id INTEGER DEFAULT nextval('svd_vectors_id_seq'),
                window_id TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                entity_id TEXT NOT NULL,
                vector JSON NOT NULL,
                model TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (id)
            )
            """
        )
        for window_id, entity_type, entity_id, vector, model in rows:
            conn.execute(
                "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model) VALUES (?, ?, ?, ?, ?)",
                (window_id, entity_type, entity_id, json.dumps(vector), model),
            )
    finally:
        conn.close()
|
|
|
|
|
|
class TestLoadScreeData:
    """Tests for ``analysis.explorer_data.load_scree_data``."""

    def test_load_scree_data_returns_empty_when_no_metadata(self, tmp_path):
        """An existing but empty ``svd_vectors`` table yields ``[]``."""
        db_path = str(tmp_path / "test.db")
        # Schema only, no rows: reuse the shared helper rather than
        # duplicating the svd_vectors DDL here.
        _setup_svd_vectors(db_path, [])

        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == []

    def test_load_scree_data_reads_metadata_row(self, tmp_path):
        """The stored explained-variance vector is returned as a list."""
        db_path = str(tmp_path / "test.db")
        _setup_svd_vectors(
            db_path,
            [
                (
                    "current_parliament",
                    "metadata",
                    "explained_variance",
                    [0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01],
                    None,
                ),
            ],
        )

        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == pytest.approx([0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01])

    def test_load_scree_data_ignores_other_windows(self, tmp_path):
        """With only ``db_path`` given, the "current_parliament" row wins."""
        db_path = str(tmp_path / "test.db")
        _setup_svd_vectors(
            db_path,
            [
                (
                    "current_parliament",
                    "metadata",
                    "explained_variance",
                    [0.45, 0.25],
                    None,
                ),
                (
                    "2024",
                    "metadata",
                    "explained_variance",
                    [0.40, 0.30],
                    None,
                ),
            ],
        )

        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == pytest.approx([0.45, 0.25])
|
|
|
|
|
|
def _seed_minimal_voting_data(db_path: str):
    """Initialise a ``MotionDatabase`` at *db_path* and seed a tiny vote matrix.

    Inserts 5 motions (ids 0-4), 3 MPs ("MP A".."MP C") of one party, and
    one vote record per (motion, MP) pair dated 2024-06-01. This is enough
    data for the SVD pipeline to run.

    Returns:
        The ``MotionDatabase`` instance, for tests that need to store results.
    """
    from database import MotionDatabase

    # Created before the raw inserts below, which write into motions /
    # mp_metadata / mp_votes without creating them. Presumably
    # MotionDatabase sets up that schema.
    db = MotionDatabase(db_path)

    mp_names = ["MP A", "MP B", "MP C"]
    votes = [
        (0, "MP A", "Voor"),
        (0, "MP B", "Tegen"),
        (0, "MP C", "Voor"),
        (1, "MP A", "Tegen"),
        (1, "MP B", "Voor"),
        (1, "MP C", "Tegen"),
        (2, "MP A", "Voor"),
        (2, "MP B", "Voor"),
        (2, "MP C", "Voor"),
        (3, "MP A", "Tegen"),
        (3, "MP B", "Tegen"),
        (3, "MP C", "Tegen"),
        (4, "MP A", "Voor"),
        (4, "MP B", "Geen stem"),
        (4, "MP C", "Tegen"),
    ]

    conn = duckdb.connect(db_path)
    # try/finally so the connection is closed even if an insert raises.
    try:
        for mid in range(5):
            conn.execute(
                "INSERT INTO motions (id, title, policy_area, voting_results) VALUES (?, ?, ?, ?)",
                (mid, f"Motion {mid}", "Test", "[]"),
            )
        for name in mp_names:
            conn.execute(
                "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)",
                (name, "Party"),
            )
        for mid, mp, vote in votes:
            conn.execute(
                "INSERT INTO mp_votes (motion_id, mp_name, vote, date) VALUES (?, ?, ?, ?)",
                (mid, mp, vote, "2024-06-01"),
            )
    finally:
        conn.close()
    return db


class TestComputeSvdForWindow:
    """Tests for ``pipeline.svd_pipeline.compute_svd_for_window``."""

    def test_returns_explained_variance(self, tmp_path):
        """The result carries one explained-variance ratio per component."""
        db_path = str(tmp_path / "test.db")
        _seed_minimal_voting_data(db_path)

        from pipeline.svd_pipeline import compute_svd_for_window

        result = compute_svd_for_window(
            db_path, "test_window", "2024-01-01", "2024-12-31", k=3
        )
        assert result["k_used"] > 0
        assert "explained_variance" in result
        ev = result["explained_variance"]
        assert isinstance(ev, list)
        assert len(ev) == result["k_used"]
        assert all(isinstance(v, float) for v in ev)
        # Variance ratios cannot be negative.
        assert all(v >= 0.0 for v in ev)
        # Ratios should sum to ~1.0. The tolerance absorbs rounding, and
        # bounding from above also catches ratios that were not normalised.
        assert sum(ev) == pytest.approx(1.0, abs=0.01)


class TestPipelineStoresScreeData:
    """Round trip: compute -> store -> load of the explained-variance row."""

    def test_run_pipeline_includes_explained_variance_row(self, tmp_path):
        """The SVD output includes a storable metadata row that loads back."""
        db_path = str(tmp_path / "test.db")
        db = _seed_minimal_voting_data(db_path)

        from pipeline.svd_pipeline import compute_svd_for_window

        result = compute_svd_for_window(
            db_path, "test_window", "2024-01-01", "2024-12-31", k=3
        )
        assert "explained_variance" in result
        ev = result["explained_variance"]
        assert isinstance(ev, list) and len(ev) > 0

        # Exactly one metadata row (entity_type "metadata", entity_id
        # "explained_variance") must be among the rows to store, carrying
        # the same values. Rows appear to be (entity_type, entity_id,
        # vector, ...) tuples.
        rows = result["mp_rows"] + result["motion_rows"]
        metadata_rows = [
            r for r in rows if r[0] == "metadata" and r[1] == "explained_variance"
        ]
        assert len(metadata_rows) == 1
        assert metadata_rows[0][2] == ev

        # Store under "current_parliament" so load_scree_data finds it.
        db.batch_store_svd_vectors("current_parliament", rows)

        from analysis.explorer_data import load_scree_data

        loaded = load_scree_data(db_path)
        assert loaded == pytest.approx(ev)
|
|
|