From 6e36fa2604722e8c4be12502ab6de33f6c285187 Mon Sep 17 00:00:00 2001 From: Sven Geboers Date: Fri, 1 May 2026 10:34:31 +0200 Subject: [PATCH] feat: persist and load explained variance for scree plots MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - compute_svd_for_window now computes explained variance ratio (s²/sum(s²)) and appends it as a metadata row (entity_type='metadata', entity_id='explained_variance') to motion_rows - load_scree_data reads this metadata row from svd_vectors instead of querying the non-existent sv_metadata column - run_svd_for_window counts only entity_type='motion' rows in stored_motion so metadata rows don't inflate the count - Added 5 TDD tests covering load, compute, store, and round-trip All 227 tests pass. --- analysis/explorer_data.py | 28 +++-- pipeline/svd_pipeline.py | 9 +- tests/test_scree_data.py | 238 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 267 insertions(+), 8 deletions(-) create mode 100644 tests/test_scree_data.py diff --git a/analysis/explorer_data.py b/analysis/explorer_data.py index 035b7ad..55da83f 100644 --- a/analysis/explorer_data.py +++ b/analysis/explorer_data.py @@ -346,14 +346,28 @@ def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]: def load_scree_data(db_path: str) -> List[float]: - """Load scree plot data (explained variance) for current_parliament. + """Load scree plot data (explained variance) for current_parliament.""" + try: + con = duckdb.connect(database=db_path, read_only=True) + row = con.execute( + """ + SELECT vector FROM svd_vectors + WHERE window_id = 'current_parliament' + AND entity_type = 'metadata' + AND entity_id = 'explained_variance' + LIMIT 1 + """ + ).fetchone() + con.close() - TODO: Scree data requires SVD metadata (singular values / explained - variance ratios) to be stored in the database. Currently only - transformed vectors are stored in svd_vectors.vector, not the - decomposition metadata needed for a scree plot. - """ - return [] + if row and row[0]: + import json + + return json.loads(row[0]) + return [] + except Exception: + logger.exception("Failed to load scree data") + return [] def load_motions_df(db_path: str) -> pd.DataFrame: diff --git a/pipeline/svd_pipeline.py b/pipeline/svd_pipeline.py index 6392fbe..93f96ce 100644 --- a/pipeline/svd_pipeline.py +++ b/pipeline/svd_pipeline.py @@ -409,11 +409,16 @@ def compute_svd_for_window( for j, mid in enumerate(motion_ids) ] + # Persist explained variance ratio as a metadata row for scree plots + evr = (s ** 2 / np.sum(s ** 2)).tolist() + motion_rows.append(("metadata", "explained_variance", evr, None)) + return { "window_id": window_id, "k_used": k_used, "mp_rows": mp_rows, "motion_rows": motion_rows, + "explained_variance": evr, } except Exception: @@ -438,8 +443,10 @@ def run_svd_for_window( rows = result["mp_rows"] + result["motion_rows"] stored = db.batch_store_svd_vectors(window_id, rows) + # motion_rows may include metadata rows (e.g. explained_variance) + motion_entity_rows = [r for r in result["motion_rows"] if r[0] == "motion"] return { "k_used": result["k_used"], "stored_mp": len(result["mp_rows"]), - "stored_motion": len(result["motion_rows"]), + "stored_motion": len(motion_entity_rows), } diff --git a/tests/test_scree_data.py b/tests/test_scree_data.py new file mode 100644 index 0000000..e693799 --- /dev/null +++ b/tests/test_scree_data.py @@ -0,0 +1,238 @@ +"""Tests for storing and loading scree plot (explained variance) data.""" + +import json + +import pytest + +duckdb = pytest.importorskip("duckdb") +np = pytest.importorskip("numpy") + + +def _setup_svd_vectors(db_path: str, rows: list): + """Insert synthetic svd_vectors rows. + + Args: + db_path: Path to DuckDB database. + rows: List of (window_id, entity_type, entity_id, vector_json_list, model). + """ + conn = duckdb.connect(db_path) + conn.execute( + """ + CREATE SEQUENCE IF NOT EXISTS svd_vectors_id_seq START 1; + CREATE TABLE IF NOT EXISTS svd_vectors ( + id INTEGER DEFAULT nextval('svd_vectors_id_seq'), + window_id TEXT NOT NULL, + entity_type TEXT NOT NULL, + entity_id TEXT NOT NULL, + vector JSON NOT NULL, + model TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) + ) + """ + ) + for window_id, entity_type, entity_id, vector, model in rows: + conn.execute( + "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model) VALUES (?, ?, ?, ?, ?)", + (window_id, entity_type, entity_id, json.dumps(vector), model), + ) + conn.close() + + +class TestLoadScreeData: + def test_load_scree_data_returns_empty_when_no_metadata(self, tmp_path): + db_path = str(tmp_path / "test.db") + conn = duckdb.connect(db_path) + conn.execute( + """ + CREATE SEQUENCE IF NOT EXISTS svd_vectors_id_seq START 1; + CREATE TABLE IF NOT EXISTS svd_vectors ( + id INTEGER DEFAULT nextval('svd_vectors_id_seq'), + window_id TEXT NOT NULL, + entity_type TEXT NOT NULL, + entity_id TEXT NOT NULL, + vector JSON NOT NULL, + model TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) + ) + """ + ) + conn.close() + + from analysis.explorer_data import load_scree_data + + result = load_scree_data(db_path) + assert result == [] + + def test_load_scree_data_reads_metadata_row(self, tmp_path): + db_path = str(tmp_path / "test.db") + _setup_svd_vectors( + db_path, + [ + ( + "current_parliament", + "metadata", + "explained_variance", + [0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01], + None, + ), + ], + ) + + from analysis.explorer_data import load_scree_data + + result = load_scree_data(db_path) + assert result == pytest.approx([0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01]) + + def test_load_scree_data_ignores_other_windows(self, tmp_path): + db_path = str(tmp_path / "test.db") + _setup_svd_vectors( + db_path, + [ + ( + "current_parliament", + "metadata", + "explained_variance", + [0.45, 0.25], + None, + ), + ( + "2024", + "metadata", + "explained_variance", + [0.40, 0.30], + None, + ), + ], + ) + + from analysis.explorer_data import load_scree_data + + result = load_scree_data(db_path) + assert result == pytest.approx([0.45, 0.25]) + + +class TestComputeSvdForWindow: + def test_returns_explained_variance(self, tmp_path): + db_path = str(tmp_path / "test.db") + from database import MotionDatabase + + db = MotionDatabase(db_path) + + # Insert minimal motion data so SVD can run + conn = duckdb.connect(db_path) + for mid in range(5): + conn.execute( + "INSERT INTO motions (id, title, policy_area, voting_results) VALUES (?, ?, ?, ?)", + (mid, f"Motion {mid}", "Test", "[]"), + ) + for name in ["MP A", "MP B", "MP C"]: + conn.execute( + "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)", + (name, "Party"), + ) + votes = [ + (0, "MP A", "Voor"), + (0, "MP B", "Tegen"), + (0, "MP C", "Voor"), + (1, "MP A", "Tegen"), + (1, "MP B", "Voor"), + (1, "MP C", "Tegen"), + (2, "MP A", "Voor"), + (2, "MP B", "Voor"), + (2, "MP C", "Voor"), + (3, "MP A", "Tegen"), + (3, "MP B", "Tegen"), + (3, "MP C", "Tegen"), + (4, "MP A", "Voor"), + (4, "MP B", "Geen stem"), + (4, "MP C", "Tegen"), + ] + for mid, mp, vote in votes: + conn.execute( + "INSERT INTO mp_votes (motion_id, mp_name, vote, date) VALUES (?, ?, ?, ?)", + (mid, mp, vote, "2024-06-01"), + ) + conn.close() + + from pipeline.svd_pipeline import compute_svd_for_window + + result = compute_svd_for_window( + db_path, "test_window", "2024-01-01", "2024-12-31", k=3 + ) + assert result["k_used"] > 0 + assert "explained_variance" in result + ev = result["explained_variance"] + assert isinstance(ev, list) + assert len(ev) == result["k_used"] + assert all(isinstance(v, float) for v in ev) + assert sum(ev) > 0.99 # Should sum to ~1.0 (or >0.99 due to rounding) + + +class TestPipelineStoresScreeData: + def test_run_pipeline_includes_explained_variance_row(self, tmp_path): + db_path = str(tmp_path / "test.db") + from database import MotionDatabase + + db = MotionDatabase(db_path) + + # Insert minimal data using the actual schema + conn = duckdb.connect(db_path) + for mid in range(5): + conn.execute( + "INSERT INTO motions (id, title, policy_area, voting_results) VALUES (?, ?, ?, ?)", + (mid, f"Motion {mid}", "Test", "[]"), + ) + for name in ["MP A", "MP B", "MP C"]: + conn.execute( + "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)", + (name, "Party"), + ) + votes = [ + (0, "MP A", "Voor"), + (0, "MP B", "Tegen"), + (0, "MP C", "Voor"), + (1, "MP A", "Tegen"), + (1, "MP B", "Voor"), + (1, "MP C", "Tegen"), + (2, "MP A", "Voor"), + (2, "MP B", "Voor"), + (2, "MP C", "Voor"), + (3, "MP A", "Tegen"), + (3, "MP B", "Tegen"), + (3, "MP C", "Tegen"), + (4, "MP A", "Voor"), + (4, "MP B", "Geen stem"), + (4, "MP C", "Tegen"), + ] + for mid, mp, vote in votes: + conn.execute( + "INSERT INTO mp_votes (motion_id, mp_name, vote, date) VALUES (?, ?, ?, ?)", + (mid, mp, vote, "2024-06-01"), + ) + conn.close() + + from pipeline.svd_pipeline import compute_svd_for_window + + result = compute_svd_for_window( + db_path, "test_window", "2024-01-01", "2024-12-31", k=3 + ) + assert "explained_variance" in result + ev = result["explained_variance"] + assert isinstance(ev, list) and len(ev) > 0 + + # Verify the rows can include a metadata row + rows = result["mp_rows"] + result["motion_rows"] + metadata_rows = [r for r in rows if r[0] == "metadata" and r[1] == "explained_variance"] + assert len(metadata_rows) == 1 + assert metadata_rows[0][2] == ev + + # Verify storing works (use current_parliament so load_scree_data finds it) + db.batch_store_svd_vectors("current_parliament", rows) + + # Verify loading works + from analysis.explorer_data import load_scree_data + + loaded = load_scree_data(db_path) + assert loaded == pytest.approx(ev)