You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.0 KiB
87 lines
3.0 KiB
"""
Tests that the trajectory-plot data helpers handle edge cases, such as
all-NaN MP positions and inconsistently formatted MP names.
"""
|
|
|
|
import pytest
|
|
import numpy as np
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
# Import the functions to test
|
|
import sys
|
|
|
|
sys.path.insert(0, "/home/sgeboers/Projects/stemwijzer")
|
|
|
|
from explorer_helpers import compute_party_centroids
|
|
|
|
|
|
def _normalize_mp_name(name):
    """Return *name* stripped, with a bare "," expanded to ", ".

    Falsy input (``None`` / ``""``) is returned unchanged. Names that already
    contain ", " are only stripped. This normalizer is defined locally for
    the matching test; it is not imported from ``explorer_helpers``.
    NOTE(review): if explorer_helpers exposes its own normalizer, test that
    instead so the production code path is covered.
    """
    if not name:
        return name
    name = name.strip()
    if "," in name and ", " not in name:
        name = name.replace(",", ", ")
    return name


class TestTrajectoryPlotRendering:
    """Tests to ensure trajectory plot renders in various scenarios."""

    def test_compute_party_centroids_returns_diagnostics(self):
        """Test that compute_party_centroids returns (centroids, diagnostics)."""
        positions_by_window = {
            "2024-Q1": {"MP1": (1.0, 2.0), "MP2": (3.0, 4.0)},
            "2024-Q2": {"MP1": (1.5, 2.5), "MP2": (3.5, 4.5)},
        }
        party_map = {"MP1": "PartyA", "MP2": "PartyA"}
        windows = ["2024-Q1", "2024-Q2"]

        centroids, diagnostics = compute_party_centroids(
            positions_by_window, party_map, windows
        )

        assert isinstance(centroids, dict)
        assert isinstance(diagnostics, dict)
        assert "windows_with_data_count" in diagnostics
        # Both windows contain finite positions for mapped MPs.
        assert diagnostics["windows_with_data_count"] == 2

    def test_compute_party_centroids_detects_all_nan_parties(self):
        """Test that diagnostics identify parties with all NaN centroids."""
        positions_by_window = {
            "2024-Q1": {"MP1": (np.nan, np.nan)},
            "2024-Q2": {"MP1": (np.nan, np.nan)},
        }
        party_map = {"MP1": "PartyA"}
        windows = ["2024-Q1", "2024-Q2"]

        centroids, diagnostics = compute_party_centroids(
            positions_by_window, party_map, windows
        )

        # Check the key first so a missing diagnostic fails with a clear message
        # instead of "'PartyA' not in []".
        assert "parties_all_nan" in diagnostics
        assert "PartyA" in diagnostics["parties_all_nan"]

    def test_name_normalization_improves_matching(self):
        """Test that normalized names match where the raw names do not."""
        # The positions source omits the space after the comma, while the
        # party map uses ", ", so the raw keys genuinely differ.
        positions_by_window = {
            "2024-Q1": {"Agema,M.": (1.0, 2.0)},
        }
        party_map = {"Agema, M.": "PVV"}

        raw_mp_names = {
            mp for positions in positions_by_window.values() for mp in positions
        }
        # Guard against a vacuous test: without normalization nothing matches.
        assert raw_mp_names.isdisjoint(party_map)

        normalized_party_map = {
            _normalize_mp_name(k): v for k, v in party_map.items()
        }
        normalized_mp_names = {_normalize_mp_name(mp) for mp in raw_mp_names}

        matched = {mp for mp in normalized_mp_names if mp in normalized_party_map}
        assert matched == {"Agema, M."}, "Name normalization should improve matching"
        assert normalized_party_map["Agema, M."] == "PVV"
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit code so script/CI runs fail when tests fail;
    # pytest.main returns the code rather than exiting itself.
    sys.exit(pytest.main([__file__, "-v"]))
|
|
|