""" Test that trajectory plot renders even with edge cases. """ import pytest import numpy as np from unittest.mock import MagicMock, patch # Import the functions to test import sys sys.path.insert(0, "/home/sgeboers/Projects/stemwijzer") from explorer_helpers import compute_party_centroids class TestTrajectoryPlotRendering: """Tests to ensure trajectory plot renders in various scenarios.""" def test_compute_party_centroids_returns_diagnostics(self): """Test that compute_party_centroids returns diagnostics tuple.""" positions_by_window = { "2024-Q1": {"MP1": (1.0, 2.0), "MP2": (3.0, 4.0)}, "2024-Q2": {"MP1": (1.5, 2.5), "MP2": (3.5, 4.5)}, } party_map = {"MP1": "PartyA", "MP2": "PartyA"} windows = ["2024-Q1", "2024-Q2"] centroids, diagnostics = compute_party_centroids( positions_by_window, party_map, windows ) assert isinstance(centroids, dict) assert isinstance(diagnostics, dict) assert "windows_with_data_count" in diagnostics assert diagnostics["windows_with_data_count"] == 2 def test_compute_party_centroids_detects_all_nan_parties(self): """Test that diagnostics identify parties with all NaN centroids.""" positions_by_window = { "2024-Q1": {"MP1": (np.nan, np.nan)}, "2024-Q2": {"MP1": (np.nan, np.nan)}, } party_map = {"MP1": "PartyA"} windows = ["2024-Q1", "2024-Q2"] centroids, diagnostics = compute_party_centroids( positions_by_window, party_map, windows ) assert "PartyA" in diagnostics.get("parties_all_nan", []) def test_name_normalization_improves_matching(self): """Test that normalized names improve party matching.""" # Positions with slightly different name format positions_by_window = { "2024-Q1": {"Agema, M.": (1.0, 2.0)}, } # Party map with different spacing party_map = {"Agema, M.": "PVV"} # Without normalization, this might not match # After normalization, they should match def normalize_mp_name(name): if not name: return name name = name.strip() if "," in name and ", " not in name: name = name.replace(",", ", ") return name normalized_party_map = {normalize_mp_name(k): v for k, v in party_map.items()} normalized_positions = { window: {normalize_mp_name(k): v for k, v in positions.items()} for window, positions in positions_by_window.items() } # Check matching all_mp_names = set() for positions in normalized_positions.values(): all_mp_names.update(positions.keys()) matched = sum(1 for mp in all_mp_names if mp in normalized_party_map) assert matched > 0, "Name normalization should improve matching" if __name__ == "__main__": pytest.main([__file__, "-v"])