- compute_svd_for_window now computes explained variance ratio (s²/sum(s²)) and appends it as a metadata row (entity_type='metadata', entity_id='explained_variance') to motion_rows - load_scree_data reads this metadata row from svd_vectors instead of querying the non-existent sv_metadata column - run_svd_for_window counts only entity_type='motion' rows in stored_motion so metadata rows don't inflate the count - Added 5 TDD tests covering load, compute, store, and round-trip All 227 tests pass.main
parent
121c32ae8a
commit
6e36fa2604
@ -0,0 +1,238 @@ |
|||||||
|
"""Tests for storing and loading scree plot (explained variance) data.""" |
||||||
|
|
||||||
|
import json |
||||||
|
|
||||||
|
import pytest |
||||||
|
|
||||||
|
duckdb = pytest.importorskip("duckdb") |
||||||
|
np = pytest.importorskip("numpy") |
||||||
|
|
||||||
|
|
||||||
|
def _setup_svd_vectors(db_path: str, rows: list): |
||||||
|
"""Insert synthetic svd_vectors rows. |
||||||
|
|
||||||
|
Args: |
||||||
|
db_path: Path to DuckDB database. |
||||||
|
rows: List of (window_id, entity_type, entity_id, vector_json_list, model). |
||||||
|
""" |
||||||
|
conn = duckdb.connect(db_path) |
||||||
|
conn.execute( |
||||||
|
""" |
||||||
|
CREATE SEQUENCE IF NOT EXISTS svd_vectors_id_seq START 1; |
||||||
|
CREATE TABLE IF NOT EXISTS svd_vectors ( |
||||||
|
id INTEGER DEFAULT nextval('svd_vectors_id_seq'), |
||||||
|
window_id TEXT NOT NULL, |
||||||
|
entity_type TEXT NOT NULL, |
||||||
|
entity_id TEXT NOT NULL, |
||||||
|
vector JSON NOT NULL, |
||||||
|
model TEXT, |
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, |
||||||
|
PRIMARY KEY (id) |
||||||
|
) |
||||||
|
""" |
||||||
|
) |
||||||
|
for window_id, entity_type, entity_id, vector, model in rows: |
||||||
|
conn.execute( |
||||||
|
"INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model) VALUES (?, ?, ?, ?, ?)", |
||||||
|
(window_id, entity_type, entity_id, json.dumps(vector), model), |
||||||
|
) |
||||||
|
conn.close() |
||||||
|
|
||||||
|
|
||||||
|
class TestLoadScreeData: |
||||||
|
def test_load_scree_data_returns_empty_when_no_metadata(self, tmp_path): |
||||||
|
db_path = str(tmp_path / "test.db") |
||||||
|
conn = duckdb.connect(db_path) |
||||||
|
conn.execute( |
||||||
|
""" |
||||||
|
CREATE SEQUENCE IF NOT EXISTS svd_vectors_id_seq START 1; |
||||||
|
CREATE TABLE IF NOT EXISTS svd_vectors ( |
||||||
|
id INTEGER DEFAULT nextval('svd_vectors_id_seq'), |
||||||
|
window_id TEXT NOT NULL, |
||||||
|
entity_type TEXT NOT NULL, |
||||||
|
entity_id TEXT NOT NULL, |
||||||
|
vector JSON NOT NULL, |
||||||
|
model TEXT, |
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, |
||||||
|
PRIMARY KEY (id) |
||||||
|
) |
||||||
|
""" |
||||||
|
) |
||||||
|
conn.close() |
||||||
|
|
||||||
|
from analysis.explorer_data import load_scree_data |
||||||
|
|
||||||
|
result = load_scree_data(db_path) |
||||||
|
assert result == [] |
||||||
|
|
||||||
|
def test_load_scree_data_reads_metadata_row(self, tmp_path): |
||||||
|
db_path = str(tmp_path / "test.db") |
||||||
|
_setup_svd_vectors( |
||||||
|
db_path, |
||||||
|
[ |
||||||
|
( |
||||||
|
"current_parliament", |
||||||
|
"metadata", |
||||||
|
"explained_variance", |
||||||
|
[0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01], |
||||||
|
None, |
||||||
|
), |
||||||
|
], |
||||||
|
) |
||||||
|
|
||||||
|
from analysis.explorer_data import load_scree_data |
||||||
|
|
||||||
|
result = load_scree_data(db_path) |
||||||
|
assert result == pytest.approx([0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01]) |
||||||
|
|
||||||
|
def test_load_scree_data_ignores_other_windows(self, tmp_path): |
||||||
|
db_path = str(tmp_path / "test.db") |
||||||
|
_setup_svd_vectors( |
||||||
|
db_path, |
||||||
|
[ |
||||||
|
( |
||||||
|
"current_parliament", |
||||||
|
"metadata", |
||||||
|
"explained_variance", |
||||||
|
[0.45, 0.25], |
||||||
|
None, |
||||||
|
), |
||||||
|
( |
||||||
|
"2024", |
||||||
|
"metadata", |
||||||
|
"explained_variance", |
||||||
|
[0.40, 0.30], |
||||||
|
None, |
||||||
|
), |
||||||
|
], |
||||||
|
) |
||||||
|
|
||||||
|
from analysis.explorer_data import load_scree_data |
||||||
|
|
||||||
|
result = load_scree_data(db_path) |
||||||
|
assert result == pytest.approx([0.45, 0.25]) |
||||||
|
|
||||||
|
|
||||||
|
class TestComputeSvdForWindow: |
||||||
|
def test_returns_explained_variance(self, tmp_path): |
||||||
|
db_path = str(tmp_path / "test.db") |
||||||
|
from database import MotionDatabase |
||||||
|
|
||||||
|
db = MotionDatabase(db_path) |
||||||
|
|
||||||
|
# Insert minimal motion data so SVD can run |
||||||
|
conn = duckdb.connect(db_path) |
||||||
|
for mid in range(5): |
||||||
|
conn.execute( |
||||||
|
"INSERT INTO motions (id, title, policy_area, voting_results) VALUES (?, ?, ?, ?)", |
||||||
|
(mid, f"Motion {mid}", "Test", "[]"), |
||||||
|
) |
||||||
|
for name in ["MP A", "MP B", "MP C"]: |
||||||
|
conn.execute( |
||||||
|
"INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)", |
||||||
|
(name, "Party"), |
||||||
|
) |
||||||
|
votes = [ |
||||||
|
(0, "MP A", "Voor"), |
||||||
|
(0, "MP B", "Tegen"), |
||||||
|
(0, "MP C", "Voor"), |
||||||
|
(1, "MP A", "Tegen"), |
||||||
|
(1, "MP B", "Voor"), |
||||||
|
(1, "MP C", "Tegen"), |
||||||
|
(2, "MP A", "Voor"), |
||||||
|
(2, "MP B", "Voor"), |
||||||
|
(2, "MP C", "Voor"), |
||||||
|
(3, "MP A", "Tegen"), |
||||||
|
(3, "MP B", "Tegen"), |
||||||
|
(3, "MP C", "Tegen"), |
||||||
|
(4, "MP A", "Voor"), |
||||||
|
(4, "MP B", "Geen stem"), |
||||||
|
(4, "MP C", "Tegen"), |
||||||
|
] |
||||||
|
for mid, mp, vote in votes: |
||||||
|
conn.execute( |
||||||
|
"INSERT INTO mp_votes (motion_id, mp_name, vote, date) VALUES (?, ?, ?, ?)", |
||||||
|
(mid, mp, vote, "2024-06-01"), |
||||||
|
) |
||||||
|
conn.close() |
||||||
|
|
||||||
|
from pipeline.svd_pipeline import compute_svd_for_window |
||||||
|
|
||||||
|
result = compute_svd_for_window( |
||||||
|
db_path, "test_window", "2024-01-01", "2024-12-31", k=3 |
||||||
|
) |
||||||
|
assert result["k_used"] > 0 |
||||||
|
assert "explained_variance" in result |
||||||
|
ev = result["explained_variance"] |
||||||
|
assert isinstance(ev, list) |
||||||
|
assert len(ev) == result["k_used"] |
||||||
|
assert all(isinstance(v, float) for v in ev) |
||||||
|
assert sum(ev) > 0.99 # Should sum to ~1.0 (or >0.99 due to rounding) |
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineStoresScreeData: |
||||||
|
def test_run_pipeline_includes_explained_variance_row(self, tmp_path): |
||||||
|
db_path = str(tmp_path / "test.db") |
||||||
|
from database import MotionDatabase |
||||||
|
|
||||||
|
db = MotionDatabase(db_path) |
||||||
|
|
||||||
|
# Insert minimal data using the actual schema |
||||||
|
conn = duckdb.connect(db_path) |
||||||
|
for mid in range(5): |
||||||
|
conn.execute( |
||||||
|
"INSERT INTO motions (id, title, policy_area, voting_results) VALUES (?, ?, ?, ?)", |
||||||
|
(mid, f"Motion {mid}", "Test", "[]"), |
||||||
|
) |
||||||
|
for name in ["MP A", "MP B", "MP C"]: |
||||||
|
conn.execute( |
||||||
|
"INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)", |
||||||
|
(name, "Party"), |
||||||
|
) |
||||||
|
votes = [ |
||||||
|
(0, "MP A", "Voor"), |
||||||
|
(0, "MP B", "Tegen"), |
||||||
|
(0, "MP C", "Voor"), |
||||||
|
(1, "MP A", "Tegen"), |
||||||
|
(1, "MP B", "Voor"), |
||||||
|
(1, "MP C", "Tegen"), |
||||||
|
(2, "MP A", "Voor"), |
||||||
|
(2, "MP B", "Voor"), |
||||||
|
(2, "MP C", "Voor"), |
||||||
|
(3, "MP A", "Tegen"), |
||||||
|
(3, "MP B", "Tegen"), |
||||||
|
(3, "MP C", "Tegen"), |
||||||
|
(4, "MP A", "Voor"), |
||||||
|
(4, "MP B", "Geen stem"), |
||||||
|
(4, "MP C", "Tegen"), |
||||||
|
] |
||||||
|
for mid, mp, vote in votes: |
||||||
|
conn.execute( |
||||||
|
"INSERT INTO mp_votes (motion_id, mp_name, vote, date) VALUES (?, ?, ?, ?)", |
||||||
|
(mid, mp, vote, "2024-06-01"), |
||||||
|
) |
||||||
|
conn.close() |
||||||
|
|
||||||
|
from pipeline.svd_pipeline import compute_svd_for_window |
||||||
|
|
||||||
|
result = compute_svd_for_window( |
||||||
|
db_path, "test_window", "2024-01-01", "2024-12-31", k=3 |
||||||
|
) |
||||||
|
assert "explained_variance" in result |
||||||
|
ev = result["explained_variance"] |
||||||
|
assert isinstance(ev, list) and len(ev) > 0 |
||||||
|
|
||||||
|
# Verify the rows can include a metadata row |
||||||
|
rows = result["mp_rows"] + result["motion_rows"] |
||||||
|
metadata_rows = [r for r in rows if r[0] == "metadata" and r[1] == "explained_variance"] |
||||||
|
assert len(metadata_rows) == 1 |
||||||
|
assert metadata_rows[0][2] == ev |
||||||
|
|
||||||
|
# Verify storing works (use current_parliament so load_scree_data finds it) |
||||||
|
db.batch_store_svd_vectors("current_parliament", rows) |
||||||
|
|
||||||
|
# Verify loading works |
||||||
|
from analysis.explorer_data import load_scree_data |
||||||
|
|
||||||
|
loaded = load_scree_data(db_path) |
||||||
|
assert loaded == pytest.approx(ev) |
||||||
Loading…
Reference in new issue