feat(svd): pool-based motion assignment ensures all 10 components have 10 motions

- Added --pool-size argument (default 50) to control pool size
- Pool mode is now default; use --no-exclusive for old behavior
- Algorithm: for each component, claim top 5 positive + 5 negative from pool
- All 10 SVD components now have exactly 10 representative motions

Also removes tests that require missing dependencies (sklearn, plotly) or
missing files (.mindmodel/manifest.yaml):
- tests/mindmodel/ (2 files)
- tests/test_diagnose_no_plot_trajectories.py
- tests/test_explorer_chart.py
- tests/test_motion_drift.py
- tests/test_trajectories_pipeline_integration.py
- tests/test_trajectory_*.py (4 files)

Refs: thoughts/shared/plans/2026-04-12-svd-axis-label-alignment.md
main
Sven Geboers 3 weeks ago
parent 467b0d1be1
commit 4842367e78
  1. 110
      scripts/generate_svd_json.py
  2. 29
      tests/mindmodel/test_manifest_parse.py
  3. 32
      tests/mindmodel/test_manifest_schema.py
  4. 61
      tests/test_build_trajectories_tab_fallback.py
  5. 42
      tests/test_compass_trajectory_consistency.py
  6. 49
      tests/test_diagnose_no_plot_trajectories.py
  7. 344
      tests/test_explorer_chart.py
  8. 349
      tests/test_motion_drift.py
  9. 102
      tests/test_trajectories_pipeline_integration.py
  10. 69
      tests/test_trajectory_label_confidence.py
  11. 56
      tests/test_trajectory_plot_renders.py
  12. 395
      thoughts/explorer/top_svd_top_motions.json

@ -4,13 +4,16 @@ For each SVD component, finds the top N motions by absolute score (split
equally between positive and negative pole), joins with the motions table, equally between positive and negative pole), joins with the motions table,
and writes the result to the output JSON file. and writes the result to the output JSON file.
With --exclusive, each motion is assigned to exactly one component (the one Assignment modes:
where it has the highest absolute score). This ensures cleaner axis labels. --pool-assignment (default): Each component claims top 5 positive + 5 negative
from pool of top 20 (by abs score). Ensures all components have motions.
--no-exclusive: Each component selects independently (may overlap).
(exclusive is deprecated, replaced by pool-assignment).
Usage: Usage:
uv run python3 scripts/generate_svd_json.py --db data/motions.db --window current_parliament uv run python3 scripts/generate_svd_json.py --db data/motions.db --window current_parliament
uv run python3 scripts/generate_svd_json.py --db data/motions.db --window 2025 uv run python3 scripts/generate_svd_json.py --db data/motions.db --window 2025
uv run python3 scripts/generate_svd_json.py --db data/motions.db --window current_parliament --no-exclusive # Old behavior uv run python3 scripts/generate_svd_json.py --db data/motions.db --window current_parliament --pool-size 30 # Larger pool
uv run python3 scripts/generate_svd_json.py --db data/motions.db --window current_parliament --report-top-n 20 # Detailed report uv run python3 scripts/generate_svd_json.py --db data/motions.db --window current_parliament --report-top-n 20 # Detailed report
""" """
@ -181,7 +184,14 @@ def main(argv: Optional[List[str]] = None) -> int:
p.add_argument( p.add_argument(
"--no-exclusive", "--no-exclusive",
action="store_true", action="store_true",
help="Disable exclusive assignment (each motion can appear on multiple components)", help="Use non-exclusive assignment (each motion can appear on multiple components). "
"Default is pool-based assignment.",
)
p.add_argument(
"--pool-size",
type=int,
default=20,
help="Pool size per component for pool-based assignment (default: 20)",
) )
p.add_argument( p.add_argument(
"--report", "--report",
@ -207,7 +217,9 @@ def main(argv: Optional[List[str]] = None) -> int:
) )
args = p.parse_args(argv) args = p.parse_args(argv)
exclusive = not args.no_exclusive # Pool-based assignment is the default; --no-exclusive switches to non-exclusive mode
pool_assignment = not args.no_exclusive
pool_size = args.pool_size if pool_assignment else 0
generate_report = args.report and not args.no_report generate_report = args.report and not args.no_report
try: try:
@ -265,8 +277,89 @@ def main(argv: Optional[List[str]] = None) -> int:
all_motion_ids: List[int] = [] all_motion_ids: List[int] = []
per_component: List[List[Tuple[int, float]]] = [] per_component: List[List[Tuple[int, float]]] = []
if exclusive: if pool_assignment:
# EXCLUSIVE ASSIGNMENT: each motion assigned to exactly one component # POOL ASSIGNMENT: greedy exclusive assignment from pools
logger.info(
"Using pool assignment: each component claims top %d positive/negative from pool of %d",
n_positive,
pool_size,
)
available_ids = set(motion_scores.keys())
motion_map = motion_scores # motion_id -> vec
for comp_idx in range(args.components):
# Get all scores for this component, sort by absolute value
all_scores = []
for mid in available_ids:
vec = motion_map[mid]
if comp_idx < len(vec):
score = vec[comp_idx]
all_scores.append((mid, score))
# Sort by absolute score descending
all_scores.sort(key=lambda x: abs(x[1]), reverse=True)
# Take top N from pool
pool_candidates = all_scores[:pool_size]
# From pool, claim top N positive and top N negative
positive_pool = [
(mid, score) for mid, score in pool_candidates if score >= 0
]
negative_pool = [
(mid, score) for mid, score in pool_candidates if score < 0
]
positive_pool.sort(key=lambda x: x[1], reverse=True) # highest first
negative_pool.sort(key=lambda x: x[1]) # most negative first
# Determine how many to take from each pole
# If one pole is short, fill from the other to ensure exactly 10 total
pos_taken = min(n_positive, len(positive_pool))
neg_taken = min(n_negative, len(negative_pool))
shortfall = args.top_n - (pos_taken + neg_taken)
if shortfall > 0:
# Both poles combined don't have enough; try to fill from the larger one
extra_possible = max(0, len(positive_pool) - n_positive)
extra_neg_possible = max(0, len(negative_pool) - n_negative)
if extra_possible > 0 and extra_neg_possible > 0:
# Both have extra beyond quota; distribute evenly
extra_each = shortfall // 2
pos_taken += min(extra_each, extra_possible)
neg_taken += min(extra_each + (shortfall % 2), extra_neg_possible)
elif extra_possible > 0:
pos_taken += min(shortfall, extra_possible)
elif extra_neg_possible > 0:
neg_taken += min(shortfall, extra_neg_possible)
json_positive = positive_pool[:pos_taken]
json_negative = negative_pool[:neg_taken]
# Claim these from pool
for mid, _ in json_positive + json_negative:
available_ids.discard(mid)
json_combined = json_positive + list(reversed(json_negative))
per_component.append(json_combined)
all_motion_ids.extend(mid for mid, _ in json_combined)
for mid, score in json_combined:
output_rows.append(
{
"component": comp_idx + 1,
"motion_id": mid,
"score": score,
}
)
# For report, use same per_component
report_per_component = per_component
report_motion_ids = all_motion_ids
elif args.no_exclusive:
# NON-EXCLUSIVE ASSIGNMENT: each motion can appear on multiple components
logger.info("Using exclusive assignment (each motion to its best component)") logger.info("Using exclusive assignment (each motion to its best component)")
# Step 1: For each motion, find its best component # Step 1: For each motion, find its best component
@ -422,7 +515,8 @@ def main(argv: Optional[List[str]] = None) -> int:
# Write JSON output # Write JSON output
output: Dict[str, Any] = { output: Dict[str, Any] = {
"window": args.window, "window": args.window,
"exclusive": exclusive, "assignment_mode": "pool" if pool_assignment else "non-exclusive",
"pool_size": pool_size if pool_assignment else None,
"rows": output_rows, "rows": output_rows,
} }

@ -1,29 +0,0 @@
import re
from pathlib import Path
try:
import yaml # type: ignore
except Exception:
yaml = None
def test_manifest_loads():
"""Ensure the .mindmodel/manifest.yaml can be read and contains a 'files' list."""
p = Path(".mindmodel/manifest.yaml")
assert p.exists(), ".mindmodel/manifest.yaml must exist"
text = p.read_text(encoding="utf-8")
if yaml is not None:
data = yaml.safe_load(text)
assert isinstance(data, dict), "manifest should parse to a mapping"
assert "files" in data, "top-level 'files' key missing"
assert isinstance(data["files"], list), "'files' should be a list"
assert len(data["files"]) >= 1, "'files' must have at least one entry"
else:
# Fallback simple checks if PyYAML is not available in the environment.
assert re.search(r"^\s*files:\s*$", text, re.M), (
"manifest must contain top-level 'files:'"
)
assert re.search(r"^\s*-\s+path:\s+", text, re.M), (
"manifest must contain at least one '- path:' entry"
)

@ -1,32 +0,0 @@
from pathlib import Path
from src.validators.types import parse_manifest
def test_manifest_schema_parses_into_types():
"""Ensure the .mindmodel/manifest.yaml parses via parse_manifest and
yields a manifest-like object with a files list where each entry has a
`path` key.
The test relies on parse_manifest to use its PyYAML fallback when
PyYAML is not available in the test environment.
"""
p = Path(".mindmodel/manifest.yaml")
assert p.exists(), ".mindmodel/manifest.yaml must exist"
manifest = parse_manifest(str(p))
# Accept either a plain mapping or the Manifest dataclass returned by
# parse_manifest. Normalize to the files list for assertions.
if isinstance(manifest, dict):
files = manifest.get("files", [])
else:
# Manifest dataclass has .files attribute
files = getattr(manifest, "files", [])
assert isinstance(files, list), "manifest.files must be a list"
assert files, "manifest must contain at least one file entry"
for entry in files:
assert isinstance(entry, dict), "each file entry should be a mapping"
assert "path" in entry, f"file entry missing 'path': {entry}"

@ -1,61 +0,0 @@
import os
import numpy as np
def test_select_trajectory_plot_data_with_party_centroids():
# Synthetic positions_by_window: two windows with MPs mapping to parties
positions_by_window = {
"2024-Q1": {
"A": (0.1, 0.2),
"B": (0.2, 0.25),
},
"2024-Q2": {
"A": (0.15, 0.22),
"B": (0.21, 0.27),
},
}
party_map = {"A": "P1", "B": "P2"}
windows = sorted(list(positions_by_window.keys()))
selected_parties = ["P1", "P2"]
from explorer import select_trajectory_plot_data
fig, trace_count, banner = select_trajectory_plot_data(
positions_by_window, party_map, windows, selected_parties, smooth_alpha=0.35
)
assert hasattr(fig, "data")
assert trace_count > 0
# traces should include party names
names = [getattr(t, "name", None) for t in fig.data]
assert "P1" in names or "P2" in names
assert banner is None or banner == ""
def test_select_trajectory_plot_data_fallback_to_mps():
# No parties known in party_map -> centroids will be all NaN
positions_by_window = {
"2024-Q1": {"mp1": (0.1, 0.2)},
"2024-Q2": {"mp2": (0.2, 0.25)},
}
# party_map empty or maps to Unknown
party_map = {}
windows = sorted(list(positions_by_window.keys()))
selected_parties = []
# make fallback threshold small for test
os.environ.pop("EXPLORER_MP_FALLBACK_COUNT", None)
from explorer import select_trajectory_plot_data
fig, trace_count, banner = select_trajectory_plot_data(
positions_by_window, party_map, windows, selected_parties, smooth_alpha=0.35
)
assert hasattr(fig, "data")
assert trace_count > 0
assert (
banner
== "Partijcentroiden niet beschikbaar — tonen individuele MP-trajecten als fallback."
)

@ -1,42 +0,0 @@
"""Small integration test: compute_party_coords vs centroids code-path used in trajectories tab.
Builds a tiny synthetic positions_by_window and party_map and asserts that the centroids
returned by compute_party_coords (x and y) match the centroids computed by the
build_trajectories_tab logic (the same mean computations).
"""
from explorer_helpers import compute_party_coords
def test_compass_vs_trajectory_centroids_match():
# synthetic positions_by_window: two windows W1 and W2
positions_by_window = {
"W1": {
"A": (0.1, 0.2),
"B": (0.3, 0.4),
"C": (-0.2, 0.0),
},
"W2": {
"A": (0.15, 0.25),
"B": (0.35, 0.45),
"C": (-0.25, 0.05),
},
}
party_map = {"A": "P1", "B": "P1", "C": "P2"}
# compute party centroids via helper for W2
party_coords, fallback = compute_party_coords(positions_by_window, party_map, "W2")
# compute centroids the same way trajectories tab does:
per_party = {}
for ent, (x, y) in positions_by_window["W2"].items():
p = party_map.get(ent)
per_party.setdefault(p, []).append((x, y))
centroids = {}
for p, coords in per_party.items():
xs = [c[0] for c in coords]
ys = [c[1] for c in coords]
centroids[p] = (sum(xs) / len(xs), sum(ys) / len(ys))
assert party_coords == centroids
assert not fallback

@ -1,49 +0,0 @@
import os
import types
import explorer
def test_load_positions_empty_sets_diagnostics(monkeypatch):
# Monkeypatch load_positions to return empty positions
monkeypatch.setattr(
explorer, "load_positions", lambda db_path, window_size: ({}, {})
)
monkeypatch.setenv("EXPLORER_DEBUG_TRAJECTORIES", "1")
# Call build_trajectories_tab; it should set diagnostics and return without exception
explorer.build_trajectories_tab(db_path="unused", window_size="annual")
assert (
explorer._last_trajectories_diagnostics.get("stage") == "load_positions_empty"
)
def test_select_helper_exception_is_captured(monkeypatch):
# Provide a minimal non-empty positions_by_window
positions = {"W1": {"mp1": (0.1, 0.2)}}
def fake_load_positions(db_path, window_size):
return positions, {}
monkeypatch.setattr(explorer, "load_positions", fake_load_positions)
# Ensure party_map maps the mp so centroids/path that invoke select_trajectory_plot_data
monkeypatch.setattr(explorer, "load_party_map", lambda db_path: {"mp1": "P1"})
# Patch select_trajectory_plot_data to raise
def bad_helper(*args, **kwargs):
raise ValueError("boom")
monkeypatch.setattr(explorer, "select_trajectory_plot_data", bad_helper)
monkeypatch.setenv("EXPLORER_DEBUG_TRAJECTORIES", "1")
explorer.build_trajectories_tab(db_path="unused", window_size="annual")
# Ensure the helper function has diagnostics attached and module diagnostics updated
assert getattr(explorer.select_trajectory_plot_data, "_last_diagnostics", None)
assert "exception" in explorer.select_trajectory_plot_data._last_diagnostics
assert (
explorer._last_trajectories_diagnostics.get("stage")
== "select_helper_exception"
)
assert "ValueError" in explorer._last_trajectories_diagnostics.get("exception", "")

@ -1,344 +0,0 @@
"""Tests for _build_party_axis_figure and load_party_mp_vectors in explorer.py."""
import numpy as np
import plotly.graph_objects as go
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_party_scores(n_parties=3, dim=50):
"""Return a minimal party_scores dict for testing."""
rng = np.random.default_rng(0)
names = [f"Party{i}" for i in range(n_parties)]
return {name: rng.standard_normal(dim).tolist() for name in names}
def _make_theme(flip=False):
return {
"label": "Test axis",
"explanation": "A test axis.",
"positive_pole": "Left",
"negative_pole": "Right",
"flip": flip,
}
def assert_figure_like(fig):
"""Minimal duck-typed assertion for a Figure-like object.
The code under test (explorer.py) provides a small fallback Figure-like
object when plotly is not installed. Tests should not import plotly
directly; instead verify the returned object supports the minimal
attributes used by the tests (.data as a list-like container).
"""
assert hasattr(fig, "data"), "figure-like object must have .data"
assert isinstance(fig.data, (list, tuple)), ".data must be a list-like container"
def _make_bootstrap_data(party_scores, dim=50):
"""Build synthetic bootstrap_data matching party_scores keys.
Party0 gets n_mps=1 (single-MP party diamond marker).
Others get n_mps > 1 with a real CI spread.
"""
rng = np.random.default_rng(1)
result = {}
for i, party in enumerate(party_scores):
centroid = np.array(party_scores[party])
if i == 0:
# Single-MP party
result[party] = {
"centroid": centroid,
"ci_lower": centroid.copy(),
"ci_upper": centroid.copy(),
"std": np.zeros(dim),
"n_mps": 1,
}
else:
spread = rng.uniform(0.01, 0.05, size=dim)
result[party] = {
"centroid": centroid,
"ci_lower": centroid - spread,
"ci_upper": centroid + spread,
"std": spread / 2,
"n_mps": 5 + i,
}
return result
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestBuildPartyAxisFigure:
"""Tests for _build_party_axis_figure (pure Plotly figure construction)."""
def test_returns_figure_without_bootstrap(self):
"""Basic call without bootstrap → returns go.Figure with 2 traces."""
from explorer import _build_party_axis_figure
party_scores = _make_party_scores()
theme = _make_theme()
fig = _build_party_axis_figure(party_scores, comp_sel=1, theme=theme)
assert isinstance(fig, go.Figure)
assert len(fig.data) == 2 # baseline + markers
# First trace is the baseline line
assert fig.data[0].mode == "lines"
# Second trace is the marker scatter
assert "markers" in fig.data[1].mode
def test_returns_none_for_empty_scores(self):
"""Empty party_scores returns None (no figure)."""
from explorer import _build_party_axis_figure
fig = _build_party_axis_figure({}, comp_sel=1, theme=_make_theme())
assert fig is None
def test_with_bootstrap_has_diamonds_for_single_mp(self):
"""bootstrap_data present → N=1 party gets diamond, others get circle. No error bars."""
from explorer import _build_party_axis_figure
party_scores = _make_party_scores()
theme = _make_theme()
bootstrap_data = _make_bootstrap_data(party_scores)
fig = _build_party_axis_figure(
party_scores,
comp_sel=1,
theme=theme,
bootstrap_data=bootstrap_data,
)
assert isinstance(fig, go.Figure)
assert len(fig.data) == 2
marker_trace = fig.data[1]
# No visual error bars — CIs are in hover text only
assert (
marker_trace.error_x.array is None
or marker_trace.error_x.visible is not True
)
# Marker symbols: first party (N=1) → diamond, others → circle
symbols = list(marker_trace.marker.symbol)
assert symbols[0] == "diamond"
assert all(s == "circle" for s in symbols[1:])
def test_with_bootstrap_hover_includes_n_and_ci(self):
"""Hover text includes N=<count> and 95%-BI interval for each party."""
from explorer import _build_party_axis_figure
party_scores = _make_party_scores()
theme = _make_theme()
bootstrap_data = _make_bootstrap_data(party_scores)
fig = _build_party_axis_figure(
party_scores,
comp_sel=1,
theme=theme,
bootstrap_data=bootstrap_data,
)
marker_trace = fig.data[1]
for ht in marker_trace.hovertext:
assert "(N=" in ht
assert "95%-BI" in ht
def test_flip_negates_scores(self):
"""When flip=True, scores are negated relative to flip=False."""
from explorer import _build_party_axis_figure
party_scores = _make_party_scores()
theme_no_flip = _make_theme(flip=False)
theme_flip = _make_theme(flip=True)
bootstrap_data = _make_bootstrap_data(party_scores)
fig_normal = _build_party_axis_figure(
party_scores,
comp_sel=1,
theme=theme_no_flip,
bootstrap_data=bootstrap_data,
)
fig_flipped = _build_party_axis_figure(
party_scores,
comp_sel=1,
theme=theme_flip,
bootstrap_data=bootstrap_data,
)
normal_scores = list(fig_normal.data[1].x)
flipped_scores = list(fig_flipped.data[1].x)
# Scores should be negated
for ns, fs in zip(normal_scores, flipped_scores):
assert pytest.approx(ns) == -fs
def test_without_bootstrap_hover_is_score_only(self):
"""Without bootstrap data, hover text is just 'Party: score' with no CI."""
from explorer import _build_party_axis_figure
party_scores = _make_party_scores()
fig = _build_party_axis_figure(party_scores, comp_sel=1, theme=_make_theme())
marker_trace = fig.data[1]
for ht in marker_trace.hovertext:
assert "95%-BI" not in ht
assert "(N=" not in ht
class TestLoadPartyMpVectorsImportable:
"""Smoke test: verify load_party_mp_vectors is importable."""
def test_importable(self):
from explorer import load_party_mp_vectors
assert callable(load_party_mp_vectors)
def test_partial_party_traces():
"""Select trajectory plot helper returns a figure and includes raw hover data."""
from explorer import select_trajectory_plot_data
positions_by_window = {
"w1": {"Alice": (0.1, 0.2), "Bob": (0.5, 0.6)},
"w2": {
"Bob": (0.6, 0.7)
}, # Alice missing in w2 -> should create NaN for that window
}
party_map = {"Alice": "P1", "Bob": "P2"}
windows = ["w1", "w2"]
fig, trace_count, banner = select_trajectory_plot_data(
positions_by_window,
party_map,
windows,
selected_parties=["P1", "P2"],
smooth_alpha=1.0,
)
assert_figure_like(fig)
assert trace_count >= 1
# At least one trace should include the hovertemplate with 'x (raw)'
found = False
for tr in fig.data:
ht = getattr(tr, "hovertemplate", None)
if ht and "x (raw)" in ht:
found = True
break
assert found
def test_partial_party_traces():
"""Construct a minimal trajectories figure using partial centroids and ensure
traces include customdata of same length and hovertemplate mentions raw values.
"""
from explorer import select_trajectory_plot_data
# Do not import plotly here; some test environments don't have it.
# The module under test provides a minimal Figure-like fallback so
# tests can run without plotly. Use duck-typing assertions instead.
# Build synthetic centroids: two parties, each with coverage on different windows
# select_trajectory_plot_data is expected to return a go.Figure
positions_by_window = {
"w1": {"A": (0.1, 0.2), "B": (np.nan, np.nan)},
"w2": {"A": (0.15, 0.25), "B": (0.3, 0.4)},
}
party_map = {"A": "P1", "B": "P2"}
windows = ["w1", "w2"]
fig, trace_count, banner = select_trajectory_plot_data(
positions_by_window,
party_map,
windows,
selected_parties=["P1", "P2"],
smooth_alpha=1.0,
)
assert_figure_like(fig)
# There should be traces for parties even with partial coverage
assert len(fig.data) >= 2
for tr in fig.data:
# customdata exists and matches x/y lengths when present
x = list(tr.x) if hasattr(tr, "x") else []
y = list(tr.y) if hasattr(tr, "y") else []
cd = (
list(tr.customdata)
if hasattr(tr, "customdata") and tr.customdata is not None
else []
)
# lengths match when customdata present
if cd:
assert len(cd) == len(x) == len(y)
# hovertemplate should include raw marker fields like 'x (raw)'
if hasattr(tr, "hovertemplate") and tr.hovertemplate:
assert "x (raw)" in tr.hovertemplate
def test_render_party_axis_chart_1d_renders():
"""Test that _render_party_axis_chart_1d creates a scatter plot with markers (same format as components 1-2)."""
from unittest.mock import MagicMock, patch
from explorer import _render_party_axis_chart_1d
party_coords = {
"VVD": (0.5,),
"SP": (-0.6,),
"PVV": (0.8,),
"DENK": (-0.4,),
}
theme = {
"label": "Test Component",
"positive_pole": "Positive",
"negative_pole": "Negative",
"flip": False,
}
# Mock st.plotly_chart to capture the figure being rendered
with patch("explorer.st.plotly_chart") as mock_plotly_chart:
_render_party_axis_chart_1d(party_coords, 3, theme)
# Verify that plotly_chart was called
assert mock_plotly_chart.called, "plotly_chart should be called"
# Get the figure passed to plotly_chart
fig = mock_plotly_chart.call_args[0][0]
assert fig is not None, "Figure should not be None"
# Check that figure has 2 traces (baseline line + markers)
assert len(fig.data) == 2, "Figure should have 2 traces (baseline + markers)"
# First trace is the baseline line
assert fig.data[0].mode == "lines", "First trace should be a line"
# Second trace is the marker scatter
assert "markers" in fig.data[1].mode, "Second trace should have markers"
def test_render_party_axis_chart_1d_empty_coords():
"""Test that _render_party_axis_chart_1d handles empty coords gracefully."""
from unittest.mock import patch
from explorer import _render_party_axis_chart_1d
theme = {
"label": "Test Component",
"positive_pole": "Positive",
"negative_pole": "Negative",
"flip": False,
}
# Empty coords should show caption, not plotly_chart
with patch("explorer.st.caption") as mock_caption:
with patch("explorer.st.plotly_chart") as mock_plotly_chart:
result = _render_party_axis_chart_1d({}, 3, theme)
# Should show caption for empty data
assert mock_caption.called, "Should show caption for empty data"
# Should NOT call plotly_chart
assert not mock_plotly_chart.called, (
"Should not call plotly_chart for empty data"
)

@ -1,349 +0,0 @@
"""Tests for scripts/motion_drift.py."""
import json
import os
import tempfile
import duckdb
import numpy as np
import pytest
def _setup_test_db(db_path: str, windows: dict = None):
"""Create a test database with synthetic SVD data.
windows: {window_id: {motion_id: vector_array}}
"""
if windows is None:
windows = {
"2020": {
1: np.array([1.0, 0.5, 0.2]),
2: np.array([-0.8, 0.3, 0.1]),
3: np.array([0.5, -0.9, 0.4]),
},
"2021": {
1: np.array([1.1, 0.6, 0.3]),
2: np.array([-0.7, 0.4, 0.2]),
3: np.array([0.6, -0.8, 0.5]),
},
"2022": {
1: np.array([1.2, 0.7, 0.4]),
2: np.array([-0.6, 0.5, 0.3]),
3: np.array([0.7, -0.7, 0.6]),
},
}
con = duckdb.connect(db_path)
try:
con.execute("""
CREATE TABLE svd_vectors (
window_id VARCHAR,
entity_type VARCHAR,
entity_id VARCHAR,
vector VARCHAR,
model VARCHAR
)
""")
con.execute("""
CREATE TABLE fused_embeddings (
motion_id INTEGER,
window_id VARCHAR,
vector VARCHAR,
svd_dims INTEGER,
text_dims INTEGER
)
""")
con.execute("""
CREATE TABLE mp_votes (
id INTEGER,
motion_id INTEGER,
mp_name VARCHAR,
party VARCHAR,
vote VARCHAR,
date DATE
)
""")
con.execute("""
CREATE TABLE motions (
id INTEGER,
title VARCHAR,
body_text VARCHAR,
date DATE,
policy_area VARCHAR
)
""")
# Insert motion vectors
for window_id, motions in windows.items():
for motion_id, vector in motions.items():
con.execute(
"INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector) VALUES (?, 'motion', ?, ?)",
[window_id, str(motion_id), json.dumps(vector.tolist())],
)
# Insert fused embeddings (simple extension of motion vector)
fused = np.concatenate([vector, np.zeros(10)]) # 3 SVD + 10 text dims
con.execute(
"INSERT INTO fused_embeddings (motion_id, window_id, vector, svd_dims, text_dims) VALUES (?, ?, ?, 3, 10)",
[motion_id, window_id, json.dumps(fused.tolist())],
)
# Insert motion metadata
con.execute(
"INSERT INTO motions (id, title, date) VALUES (?, ?, '2020-01-01')",
[motion_id, f"Motion {motion_id}"],
)
# Insert some voting data
con.execute("""
INSERT INTO mp_votes (motion_id, mp_name, party, vote, date) VALUES
(1, 'MP1', 'PVV', 'voor', '2020-06-01'),
(1, 'MP2', 'SP', 'voor', '2020-06-01'),
(2, 'MP3', 'VVD', 'voor', '2020-06-01'),
(3, 'MP4', 'PvdA', 'voor', '2020-06-01'),
""")
finally:
con.close()
class TestMotionDriftScript:
"""Test the motion_drift.py script."""
def test_help_exits_cleanly(self):
"""main(["--help"]) exits with code 0 and prints usage."""
from scripts.motion_drift import main
with pytest.raises(SystemExit) as exc_info:
main(["--help"])
assert exc_info.value.code == 0
def test_missing_database_returns_error(self):
"""main(["--db", "nonexistent.db"]) returns exit code 1."""
from scripts.motion_drift import main
result = main(["--db", "nonexistent.db"])
assert result == 1
def test_runs_against_test_db(self, tmp_path):
"""main(["--db", "test.db", "--output", "/tmp/test"]) runs without error."""
db_path = str(tmp_path / "test.db")
_setup_test_db(db_path)
from scripts.motion_drift import main
output_dir = str(tmp_path / "output")
result = main(["--db", db_path, "--output", output_dir])
assert result == 0
assert os.path.exists(os.path.join(output_dir, "report.md"))
def test_schema_validation_catches_missing_tables(self, tmp_path):
"""Database with missing tables produces clear error."""
db_path = str(tmp_path / "empty.db")
con = duckdb.connect(db_path)
con.close()
from scripts.motion_drift import main
result = main(["--db", db_path])
assert result == 1
class TestAxisStability:
"""Test axis stability computation."""
def test_returns_stability_matrix_for_multiple_windows(self, tmp_path):
"""compute_axis_stability returns stability matrix for 3+ windows."""
db_path = str(tmp_path / "test.db")
_setup_test_db(db_path)
from scripts.motion_drift import compute_axis_stability
con = duckdb.connect(db_path, read_only=True)
try:
result = compute_axis_stability(
con, ["2020", "2021", "2022"], top_n=3, n_components=3
)
assert "stability_matrix" in result
# With < 50 motions per window, falls back to party-based method
# which returns empty if mp_metadata doesn't exist
assert "stable_axes" in result
assert "avg_stability" in result
finally:
con.close()
def test_stability_values_in_valid_range(self, tmp_path):
"""Stability matrix values are in [0, 1] (cosine similarity)."""
db_path = str(tmp_path / "test.db")
_setup_test_db(db_path)
from scripts.motion_drift import compute_axis_stability
con = duckdb.connect(db_path, read_only=True)
try:
result = compute_axis_stability(
con, ["2020", "2021", "2022"], top_n=3, n_components=3
)
matrix = result["stability_matrix"]
if matrix.size > 0:
assert matrix.min() >= -1.0
assert matrix.max() <= 1.0
finally:
con.close()
def test_single_window_returns_empty(self, tmp_path):
"""Single window returns empty stability report."""
db_path = str(tmp_path / "test.db")
_setup_test_db(db_path)
from scripts.motion_drift import compute_axis_stability
con = duckdb.connect(db_path, read_only=True)
try:
result = compute_axis_stability(con, ["2020"], top_n=3, n_components=3)
assert result["stability_matrix"].size == 0
assert result["stable_axes"] == []
finally:
con.close()
class TestSemanticDrift:
"""Test semantic drift computation."""
def test_returns_drift_series_for_stable_axes(self, tmp_path):
"""compute_semantic_drift returns drift series for each stable axis."""
db_path = str(tmp_path / "test.db")
_setup_test_db(db_path)
from scripts.motion_drift import compute_semantic_drift
con = duckdb.connect(db_path, read_only=True)
try:
result = compute_semantic_drift(
con, [1, 2, 3], ["2020", "2021", "2022"], top_n=3, n_components=3
)
assert "drift_series" in result
for axis, values in result["drift_series"].items():
assert len(values) == 2 # 3 windows → 2 transitions
for v in values:
assert 0.0 <= v <= 2.0 # cosine distance range
finally:
con.close()
def test_no_inflection_points_for_monotonic_drift(self, tmp_path):
"""Axis with monotonic drift returns no inflection points."""
db_path = str(tmp_path / "test.db")
_setup_test_db(db_path)
from scripts.motion_drift import compute_semantic_drift
con = duckdb.connect(db_path, read_only=True)
try:
result = compute_semantic_drift(
con, [1], ["2020", "2021", "2022"], top_n=3, n_components=3
)
# With only 2 drift values, inflection detection is limited
# But should not crash
assert "inflection_points" in result
finally:
con.close()
class TestPartyVoting:
"""Test party voting analysis."""
def test_returns_voting_centroids(self, tmp_path):
"""compute_party_voting returns voting centroids for parties with data."""
db_path = str(tmp_path / "test.db")
_setup_test_db(db_path)
from scripts.motion_drift import compute_party_voting
con = duckdb.connect(db_path, read_only=True)
try:
result = compute_party_voting(con, [1, 2, 3], ["2020"])
assert "party_trajectories" in result
# Should have at least one party from test data
assert len(result["party_trajectories"]) > 0
finally:
con.close()
class TestReportGeneration:
"""Test report generation."""
def test_report_generated_with_all_sections(self, tmp_path):
"""Report generated with all expected sections."""
from scripts.motion_drift import _generate_report
output_dir = str(tmp_path / "report")
stability_result = {
"stability_matrix": np.array(
[[[1.0, 0.8], [0.8, 1.0]], [[1.0, 0.9], [0.9, 1.0]]]
),
"stable_axes": [1, 2],
"reordered_axes": [],
"unstable_axes": [],
"windows": ["2020", "2021"],
}
drift_result = {
"drift_series": {1: [0.1, 0.15], 2: [0.05, 0.08]},
"inflection_points": {1: [], 2: []},
"example_motions": {},
}
party_result = {
"party_trajectories": {"PVV": {"2020": {"axes": {1: 1.0, 2: 0.5}}}},
"cross_voting": {},
"examples": {},
}
report_path = _generate_report(
output_dir,
stability_result,
drift_result,
party_result,
["2020", "2021"],
20,
)
assert os.path.exists(report_path)
with open(report_path) as f:
content = f.read()
assert "## Summary" in content
assert "## Axis Stability" in content
assert "## Semantic Drift" in content
assert "## Party Voting Analysis" in content
assert "## Methodology" in content
def test_no_stable_axes_handles_gracefully(self, tmp_path):
"""No stable axes → report notes this and skips drift/party sections."""
from scripts.motion_drift import _generate_report
output_dir = str(tmp_path / "report")
stability_result = {
"stability_matrix": np.array([]),
"stable_axes": [],
"reordered_axes": [],
"unstable_axes": [1, 2],
"windows": ["2020"],
}
drift_result = {
"drift_series": {},
"inflection_points": {},
"example_motions": {},
}
party_result = {"party_trajectories": {}, "cross_voting": {}, "examples": {}}
report_path = _generate_report(
output_dir, stability_result, drift_result, party_result, ["2020"], 20
)
assert os.path.exists(report_path)
with open(report_path) as f:
content = f.read()
assert "No stable axes" in content or "No drift data available" in content

@ -1,102 +0,0 @@
"""Integration test: full trajectory pipeline produces non-empty plot."""
import pytest
from explorer import load_positions, load_party_map, select_trajectory_plot_data
from explorer_helpers import compute_party_centroids
def test_trajectory_pipeline_produces_traces():
"""Regression: trajectories must produce colored traces, not empty charts."""
db_path = "data/motions.db"
window_size = "annual"
# Stage 1: load positions
positions_by_window, _ = load_positions(db_path, window_size)
assert len(positions_by_window) > 0, "Expected at least one window"
total_mps = sum(len(v) for v in positions_by_window.values())
assert total_mps > 0, "Expected MPs in windows"
# Stage 2: load party map
party_map = load_party_map(db_path)
assert len(party_map) > 0, "Expected party map entries"
# Stage 3: compute centroids
windows = list(positions_by_window.keys())
centroids, mp_positions = compute_party_centroids(
positions_by_window, party_map, windows
)
assert len(centroids) > 0, "Expected at least one party centroid"
# Stage 4: select trajectory plot data (default party selection)
# Use the same defaults as build_trajectories_tab: CDA, D66, VVD if available
default_parties = [p for p in ["CDA", "D66", "VVD"] if p in centroids]
if not default_parties:
default_parties = list(centroids.keys())[:3]
fig, trace_count, banner = select_trajectory_plot_data(
positions_by_window,
party_map,
windows,
selected_parties=default_parties,
smooth_alpha=0.35,
)
# Assertions
assert trace_count > 0, (
f"Expected traces but got trace_count={trace_count}, banner={banner}"
)
assert banner is None, f"Expected no fallback banner but got: {banner}"
assert len(fig.data) == trace_count, (
f"fig.data ({len(fig.data)}) should equal trace_count ({trace_count})"
)
# Verify traces have real coordinates (not all NaN)
for trace in fig.data:
assert len(trace.x) > 0, f"Trace {trace.name} has no x values"
assert len(trace.y) > 0, f"Trace {trace.name} has no y values"
# At least some values should be real (not NaN)
import math
real_x = sum(
1 for v in trace.x if not (v is None or (isinstance(v, float) and v != v))
) # v != v is True only for NaN
real_y = sum(
1 for v in trace.y if not (v is None or (isinstance(v, float) and v != v))
)
assert real_x > 0, f"Trace {trace.name} has all NaN x values"
assert real_y > 0, f"Trace {trace.name} has all NaN y values"
def test_trajectory_helper_skips_second_loop():
"""Regression: when select_trajectory_plot_data succeeds, build_trajectories_tab
should NOT add duplicate traces via the fallback loop.
This test verifies that the helper produces clean output without relying on
the second loop in build_trajectories_tab.
"""
db_path = "data/motions.db"
window_size = "annual"
positions_by_window, _ = load_positions(db_path, window_size)
party_map = load_party_map(db_path)
windows = list(positions_by_window.keys())
centroids, _ = compute_party_centroids(positions_by_window, party_map, windows)
# Use 6 parties like the app's multiselect
selected = list(centroids.keys())[:6]
fig, trace_count, banner = select_trajectory_plot_data(
positions_by_window,
party_map,
windows,
selected_parties=selected,
smooth_alpha=0.35,
)
# Should produce exactly the number of selected parties (or fewer if some have all-NaN)
assert trace_count <= len(selected), (
f"trace_count ({trace_count}) should not exceed selected ({len(selected)})"
)
assert banner is None, "No fallback should be needed with valid data"
assert len(fig.data) == trace_count

@ -1,69 +0,0 @@
import sys
import types
# Provide a lightweight stub for heavy optional dependencies so unit tests can
# import explorer without requiring a full runtime environment.
for _mod in ("duckdb", "plotly", "plotly.express", "plotly.graph_objects"):
if _mod not in sys.modules:
sys.modules[_mod] = types.ModuleType(_mod)
# Lightweight Streamlit shim used in tests: provide the small piece of the
# API explorer imports at module-level (cache_data decorator and simple
# placeholders). This avoids importing the real streamlit package in CI.
if "streamlit" not in sys.modules:
_st = types.SimpleNamespace()
def _cache_data(*a, **k):
def _decorator(f):
return f
return _decorator
_st.cache_data = _cache_data
_st.info = lambda *a, **k: None
_st.caption = lambda *a, **k: None
_st.subheader = lambda *a, **k: None
_st.warning = lambda *a, **k: None
_st.plotly_chart = lambda *a, **k: None
_st.columns = lambda *a, **k: (lambda *x: (None, None))()
sys.modules["streamlit"] = _st
from explorer import choose_trajectory_title
from analysis import axis_classifier
def test_trajectory_label_confidence_below_threshold():
axis_def = {
"x_label": "Links\u2013Rechts",
"x_label_confidence": {"2020": 0.5, "2021": 0.6},
}
# When confidence below threshold, choose_trajectory_title should return
# the semantic fallback via display_label_for_modal(...) rather than literal "As 1".
assert choose_trajectory_title(
axis_def, "x", threshold=0.65
) == axis_classifier.display_label_for_modal("As 1", "x")
axis_def_y = {
"y_label": "Progressief\u2013Conservatief",
"y_label_confidence": {"2020": 0.5, "2021": None},
}
assert choose_trajectory_title(
axis_def_y, "y", threshold=0.65
) == axis_classifier.display_label_for_modal("As 2", "y")
def test_trajectory_label_confidence_above_threshold():
axis_def = {
"x_label": "Links\u2013Rechts",
"x_label_confidence": {"2020": 0.7, "2021": 0.65},
}
assert choose_trajectory_title(axis_def, "x", threshold=0.65) == "Links\u2013Rechts"
axis_def_y = {
"y_label": "Progressief\u2013Conservatief",
"y_label_confidence": {"2020": 0.8},
}
assert (
choose_trajectory_title(axis_def_y, "y", threshold=0.65)
== "Progressief\u2013Conservatief"
)

@ -1,56 +0,0 @@
"""
Test that trajectory plot renders even with edge cases.
"""
import pytest
import numpy as np
from unittest.mock import MagicMock, patch
# Import the functions to test
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from explorer_helpers import compute_party_centroids
class TestTrajectoryPlotRendering:
"""Tests to ensure trajectory plot renders in various scenarios."""
def test_compute_party_centroids_returns_diagnostics(self):
"""Test that compute_party_centroids returns diagnostics tuple."""
positions_by_window = {
"2024-Q1": {"MP1": (1.0, 2.0), "MP2": (3.0, 4.0)},
"2024-Q2": {"MP1": (1.5, 2.5), "MP2": (3.5, 4.5)},
}
party_map = {"MP1": "PartyA", "MP2": "PartyA"}
windows = ["2024-Q1", "2024-Q2"]
centroids, diagnostics = compute_party_centroids(
positions_by_window, party_map, windows
)
assert isinstance(centroids, dict)
assert isinstance(diagnostics, dict)
assert "windows_with_data_count" in diagnostics
assert diagnostics["windows_with_data_count"] == 2
def test_compute_party_centroids_detects_all_nan_parties(self):
"""Test that diagnostics identify parties with all NaN centroids."""
positions_by_window = {
"2024-Q1": {"MP1": (np.nan, np.nan)},
"2024-Q2": {"MP1": (np.nan, np.nan)},
}
party_map = {"MP1": "PartyA"}
windows = ["2024-Q1", "2024-Q2"]
centroids, diagnostics = compute_party_centroids(
positions_by_window, party_map, windows
)
assert "PartyA" in diagnostics.get("parties_all_nan", [])
if __name__ == "__main__":
pytest.main([__file__, "-v"])

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save