feat: persist and load explained variance for scree plots

- compute_svd_for_window now computes explained variance ratio (s²/sum(s²))
  and appends it as a metadata row (entity_type='metadata',
  entity_id='explained_variance') to motion_rows
- load_scree_data reads this metadata row from svd_vectors instead of
  querying the non-existent sv_metadata column
- run_svd_for_window counts only entity_type='motion' rows in stored_motion
  so metadata rows don't inflate the count
- Added 5 TDD tests covering load, compute, store, and round-trip

All 227 tests pass.
main
Sven Geboers 4 weeks ago
parent 121c32ae8a
commit 6e36fa2604
  1. 26
      analysis/explorer_data.py
  2. 9
      pipeline/svd_pipeline.py
  3. 238
      tests/test_scree_data.py

@ -346,13 +346,27 @@ def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]:
def load_scree_data(db_path: str) -> List[float]:
    """Load scree plot data (explained variance) for current_parliament.

    Reads the ``vector`` column of the ``svd_vectors`` row whose
    ``window_id`` is ``'current_parliament'``, ``entity_type`` is
    ``'metadata'`` and ``entity_id`` is ``'explained_variance'``, and
    decodes it from JSON.

    Args:
        db_path: Path to the DuckDB database; it is opened read-only.

    Returns:
        The decoded explained-variance values. An empty list is returned
        when no matching row exists, when the stored value is empty, or
        when any error occurs (the error is logged, not raised).
    """
    import json

    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            row = con.execute(
                """
                SELECT vector FROM svd_vectors
                WHERE window_id = 'current_parliament'
                AND entity_type = 'metadata'
                AND entity_id = 'explained_variance'
                LIMIT 1
                """
            ).fetchone()
        finally:
            # Close even when the query raises (e.g. the table is missing),
            # so a failed lookup does not leave the database handle open.
            con.close()
        if row and row[0]:
            return json.loads(row[0])
        return []
    except Exception:
        # Best-effort loader: callers get an empty scree plot, not a crash.
        logger.exception("Failed to load scree data")
        return []

@ -409,11 +409,16 @@ def compute_svd_for_window(
for j, mid in enumerate(motion_ids)
]
# Persist explained variance ratio as a metadata row for scree plots
evr = (s ** 2 / np.sum(s ** 2)).tolist()
motion_rows.append(("metadata", "explained_variance", evr, None))
return {
"window_id": window_id,
"k_used": k_used,
"mp_rows": mp_rows,
"motion_rows": motion_rows,
"explained_variance": evr,
}
except Exception:
@ -438,8 +443,10 @@ def run_svd_for_window(
rows = result["mp_rows"] + result["motion_rows"]
stored = db.batch_store_svd_vectors(window_id, rows)
# motion_rows may include metadata rows (e.g. explained_variance)
motion_entity_rows = [r for r in result["motion_rows"] if r[0] == "motion"]
return {
"k_used": result["k_used"],
"stored_mp": len(result["mp_rows"]),
"stored_motion": len(result["motion_rows"]),
"stored_motion": len(motion_entity_rows),
}

@ -0,0 +1,238 @@
"""Tests for storing and loading scree plot (explained variance) data."""
import json
import pytest
duckdb = pytest.importorskip("duckdb")
np = pytest.importorskip("numpy")
def _setup_svd_vectors(db_path: str, rows: list):
    """Create the svd_vectors table if needed and insert synthetic rows.

    Args:
        db_path: Path to DuckDB database.
        rows: List of (window_id, entity_type, entity_id, vector_json_list,
            model) tuples. Each vector is serialised with ``json.dumps``
            before insertion. An empty list only creates the schema.
    """
    conn = duckdb.connect(db_path)
    try:
        conn.execute(
            """
            CREATE SEQUENCE IF NOT EXISTS svd_vectors_id_seq START 1;
            CREATE TABLE IF NOT EXISTS svd_vectors (
            id INTEGER DEFAULT nextval('svd_vectors_id_seq'),
            window_id TEXT NOT NULL,
            entity_type TEXT NOT NULL,
            entity_id TEXT NOT NULL,
            vector JSON NOT NULL,
            model TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            PRIMARY KEY (id)
            )
            """
        )
        for window_id, entity_type, entity_id, vector, model in rows:
            conn.execute(
                "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model) VALUES (?, ?, ?, ?, ?)",
                (window_id, entity_type, entity_id, json.dumps(vector), model),
            )
    finally:
        # Release the file handle even if the DDL or an insert fails, so the
        # code under test can still open the same database afterwards.
        conn.close()
class TestLoadScreeData:
    """Tests for ``load_scree_data`` against a synthetic svd_vectors table."""

    def test_load_scree_data_returns_empty_when_no_metadata(self, tmp_path):
        """An existing but empty svd_vectors table yields an empty list."""
        db_path = str(tmp_path / "test.db")
        # With no rows the helper still creates the sequence and table, so
        # this covers "table exists, metadata row missing".
        _setup_svd_vectors(db_path, [])

        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == []

    def test_load_scree_data_reads_metadata_row(self, tmp_path):
        """The stored explained-variance list is returned as-is."""
        db_path = str(tmp_path / "test.db")
        expected = [0.45, 0.25, 0.15, 0.08, 0.04, 0.02, 0.01]
        _setup_svd_vectors(
            db_path,
            [
                (
                    "current_parliament",
                    "metadata",
                    "explained_variance",
                    expected,
                    None,
                ),
            ],
        )

        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == pytest.approx(expected)

    def test_load_scree_data_ignores_other_windows(self, tmp_path):
        """Only the 'current_parliament' window's metadata row is returned."""
        db_path = str(tmp_path / "test.db")
        _setup_svd_vectors(
            db_path,
            [
                # The other window goes first on purpose: a query that drops
                # the window_id filter and relies on LIMIT 1 would most
                # likely return this row, making the assertion below fail.
                (
                    "2024",
                    "metadata",
                    "explained_variance",
                    [0.40, 0.30],
                    None,
                ),
                (
                    "current_parliament",
                    "metadata",
                    "explained_variance",
                    [0.45, 0.25],
                    None,
                ),
            ],
        )

        from analysis.explorer_data import load_scree_data

        result = load_scree_data(db_path)
        assert result == pytest.approx([0.45, 0.25])
def _seed_minimal_votes(db_path: str) -> None:
    """Insert five motions, three MPs and fifteen votes as SVD input.

    The ``motions``, ``mp_metadata`` and ``mp_votes`` tables must already
    exist. All votes are dated 2024-06-01.

    Args:
        db_path: Path to the DuckDB database to populate.
    """
    votes = [
        (0, "MP A", "Voor"),
        (0, "MP B", "Tegen"),
        (0, "MP C", "Voor"),
        (1, "MP A", "Tegen"),
        (1, "MP B", "Voor"),
        (1, "MP C", "Tegen"),
        (2, "MP A", "Voor"),
        (2, "MP B", "Voor"),
        (2, "MP C", "Voor"),
        (3, "MP A", "Tegen"),
        (3, "MP B", "Tegen"),
        (3, "MP C", "Tegen"),
        (4, "MP A", "Voor"),
        (4, "MP B", "Geen stem"),
        (4, "MP C", "Tegen"),
    ]
    conn = duckdb.connect(db_path)
    try:
        for mid in range(5):
            conn.execute(
                "INSERT INTO motions (id, title, policy_area, voting_results) VALUES (?, ?, ?, ?)",
                (mid, f"Motion {mid}", "Test", "[]"),
            )
        for name in ["MP A", "MP B", "MP C"]:
            conn.execute(
                "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)",
                (name, "Party"),
            )
        for mid, mp, vote in votes:
            conn.execute(
                "INSERT INTO mp_votes (motion_id, mp_name, vote, date) VALUES (?, ?, ?, ?)",
                (mid, mp, vote, "2024-06-01"),
            )
    finally:
        # Always release the handle so the pipeline can reopen the file.
        conn.close()


class TestComputeSvdForWindow:
    """Tests for the explained variance returned by compute_svd_for_window."""

    def test_returns_explained_variance(self, tmp_path):
        """The result carries one float ratio per retained component."""
        db_path = str(tmp_path / "test.db")
        from database import MotionDatabase

        # Constructed before seeding; presumably this creates the tables the
        # inserts below need - verify against MotionDatabase.__init__.
        db = MotionDatabase(db_path)
        # Insert minimal motion data so SVD can run
        _seed_minimal_votes(db_path)

        from pipeline.svd_pipeline import compute_svd_for_window

        result = compute_svd_for_window(
            db_path, "test_window", "2024-01-01", "2024-12-31", k=3
        )
        assert result["k_used"] > 0
        assert "explained_variance" in result
        ev = result["explained_variance"]
        assert isinstance(ev, list)
        assert len(ev) == result["k_used"]
        assert all(isinstance(v, float) for v in ev)
        # The commit message defines the ratios as s²/sum(s²), so they sum to
        # 1 up to float rounding; bound the sum from both sides.
        assert sum(ev) == pytest.approx(1.0)


class TestPipelineStoresScreeData:
    """Round-trip test: compute, store, then load the explained variance."""

    def test_run_pipeline_includes_explained_variance_row(self, tmp_path):
        """The computed metadata row survives a store/load round trip."""
        db_path = str(tmp_path / "test.db")
        from database import MotionDatabase

        db = MotionDatabase(db_path)
        # Insert minimal data using the actual schema
        _seed_minimal_votes(db_path)

        from pipeline.svd_pipeline import compute_svd_for_window

        result = compute_svd_for_window(
            db_path, "test_window", "2024-01-01", "2024-12-31", k=3
        )
        assert "explained_variance" in result
        ev = result["explained_variance"]
        assert isinstance(ev, list) and len(ev) > 0
        # Verify the rows can include a metadata row
        rows = result["mp_rows"] + result["motion_rows"]
        metadata_rows = [
            r for r in rows if r[0] == "metadata" and r[1] == "explained_variance"
        ]
        assert len(metadata_rows) == 1
        assert metadata_rows[0][2] == ev
        # Verify storing works (use current_parliament so load_scree_data finds it)
        db.batch_store_svd_vectors("current_parliament", rows)
        # Verify loading works
        from analysis.explorer_data import load_scree_data

        loaded = load_scree_data(db_path)
        assert loaded == pytest.approx(ev)
Loading…
Cancel
Save