"""Axis classifier: correlate per-party PCA positions against ideology reference data
to assign honest, dynamic labels to political compass axes.

Public API: classify_axes(positions_by_window, axes, db_path) -> dict
"""

import csv
import logging
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

_logger = logging.getLogger(__name__)

# Module-level caches: each reference CSV is parsed once and then reused.
# Setting a cache back to None forces the next call to reload from disk.
_ideology_cache: Optional[Dict[str, Dict[str, float]]] = None
_coalition_cache: Optional[Dict[str, set]] = None
# Path each cache was populated from, so that a request for a different data
# directory is never answered with another directory's reference data.
_ideology_cache_path: Optional[Path] = None
_coalition_cache_path: Optional[Path] = None

# Correlation threshold above which we consider an axis "explained" by a dimension.
_THRESHOLD = 0.65

_LABELS = {
    "lr": "Links\u2013Rechts",
    "co": "Coalitie\u2013Oppositie",
    "pc": "Progressief\u2013Conservatief",
    "fallback_x": "Stempatroon As 1",
    "fallback_y": "Stempatroon As 2",
}

_INTERPRETATION_TEMPLATES = {
    "lr": "De {orientation} as weerspiegelt de klassieke links-rechts tegenstelling.",
    "co": (
        "De {orientation} as weerspiegelt stemgedrag van coalitie- versus "
        "oppositiepartijen (r={r:.2f}). Links-rechts is minder dominant dit jaar."
    ),
    "pc": "De {orientation} as weerspiegelt de progressief-conservatieve tegenstelling.",
    "fallback": (
        "De {orientation} as weerspiegelt een empirisch stempatroon "
        "zonder duidelijke ideologische richting."
    ),
}


def _load_ideology(csv_path: Path) -> Dict[str, Dict[str, float]]:
    """Load party ideology scores from CSV.

    The file must have a header row containing ``left_right`` and
    ``progressive`` columns; the first column of every data row is the party
    name. Blank rows and rows too short to hold both score columns are skipped.

    Returns {party_name: {"left_right": float, "progressive": float}}.
    Returns {} on any error (caller should treat empty as 'skip classification').
    A successful load is cached per ``csv_path``; failures are not cached, so a
    file that appears later is picked up on the next call.
    """
    global _ideology_cache, _ideology_cache_path
    if _ideology_cache is not None and _ideology_cache_path == csv_path:
        return _ideology_cache
    result: Dict[str, Dict[str, float]] = {}
    try:
        with open(csv_path, encoding="utf-8", newline="") as fh:
            reader = csv.reader(fh)
            header = [cell.strip() for cell in next(reader, [])]
            # list.index raises ValueError when a required column is missing.
            lr_idx = header.index("left_right")
            pc_idx = header.index("progressive")
            last_needed = max(lr_idx, pc_idx)
            for row in reader:
                cells = [cell.strip() for cell in row]
                if not any(cells) or len(cells) <= last_needed:
                    continue
                result[cells[0]] = {
                    "left_right": float(cells[lr_idx]),
                    "progressive": float(cells[pc_idx]),
                }
    except FileNotFoundError:
        _logger.warning(
            "party_ideologies.csv not found at %s — axis labels will be generic",
            csv_path,
        )
        return {}
    except Exception as exc:
        # Deliberately broad: reference data is optional, so any read or parse
        # problem degrades to "no classification" instead of failing the caller.
        _logger.warning("Failed to load party_ideologies.csv: %s", exc)
        return {}
    _ideology_cache = result
    _ideology_cache_path = csv_path
    return result


def _load_coalition(csv_path: Path) -> Dict[str, set]:
    """Load coalition membership from CSV.

    The first row is treated as a header and skipped; each further row holds a
    window id in its first column and a party name in its second. Rows with
    fewer than two columns, or with an empty window id or party, are skipped.

    Returns {window_id: set_of_party_names}.
    Returns {} on any error (coalition dimension will be skipped).
    A successful load is cached per ``csv_path``; failures are not cached.
    """
    global _coalition_cache, _coalition_cache_path
    if _coalition_cache is not None and _coalition_cache_path == csv_path:
        return _coalition_cache
    result: Dict[str, set] = {}
    try:
        with open(csv_path, encoding="utf-8", newline="") as fh:
            reader = csv.reader(fh)
            next(reader, None)  # skip the header row
            for row in reader:
                cells = [cell.strip() for cell in row]
                if len(cells) < 2 or not cells[0] or not cells[1]:
                    continue
                result.setdefault(cells[0], set()).add(cells[1])
    except FileNotFoundError:
        _logger.warning(
            "coalition_membership.csv not found at %s — coalition axis detection disabled",
            csv_path,
        )
        return {}
    except Exception as exc:
        # Deliberately broad: coalition data is optional (see _load_ideology).
        _logger.warning("Failed to load coalition_membership.csv: %s", exc)
        return {}
    _coalition_cache = result
    _coalition_cache_path = csv_path
    return result


def _window_year(window_id: str) -> Optional[str]:
    """Extract year string from window_id.

    Returns None for 'current_parliament'.
    '2016' → '2016', '2016-Q3' → '2016'.
    """
    if window_id == "current_parliament":
        return None
    return window_id.split("-")[0]


def _pearsonr(x: List[float], y: List[float]) -> float:
    """Return Pearson r of ``x`` and ``y``, or 0.0 for degenerate input.

    Degenerate means: fewer than 3 points, any NaN/inf value, or (near-)zero
    variance in either series. The result is always a finite float in [-1, 1],
    so callers can safely compare it against a threshold or take ``max``.
    """
    if len(x) < 3:
        return 0.0
    xa = np.asarray(x, dtype=float)
    ya = np.asarray(y, dtype=float)
    # A single NaN would otherwise propagate into the labels' quality scores.
    if not (np.isfinite(xa).all() and np.isfinite(ya).all()):
        return 0.0
    if xa.std() < 1e-12 or ya.std() < 1e-12:
        return 0.0
    r = float(np.corrcoef(xa, ya)[0, 1])
    if not np.isfinite(r):
        return 0.0
    # Guard against floating-point overshoot just outside [-1, 1].
    return max(-1.0, min(1.0, r))


def _assign_label(
    r_lr: float,
    r_co: float,
    r_pc: float,
    axis: str,
) -> Tuple[str, str, float]:
    """Assign label, interpretation and quality score for one axis.

    Priority: left-right > coalition > progressive > fallback. The first
    dimension whose |r| reaches ``_THRESHOLD`` wins. ``axis`` is "x" for the
    horizontal axis; any other value is treated as the vertical axis.

    Returns (label, interpretation_string, quality_score), where the quality
    score is the largest |r| over all three dimensions.
    """
    is_x = axis == "x"
    orientation = "horizontale" if is_x else "verticale"
    quality = max(abs(r_lr), abs(r_co), abs(r_pc))

    # Ordered by priority; str.format ignores the unused ``r`` keyword for
    # templates that do not reference it.
    for key, r in (("lr", r_lr), ("co", r_co), ("pc", r_pc)):
        if abs(r) >= _THRESHOLD:
            interpretation = _INTERPRETATION_TEMPLATES[key].format(
                orientation=orientation, r=r
            )
            return _LABELS[key], interpretation, quality

    return (
        _LABELS["fallback_x"] if is_x else _LABELS["fallback_y"],
        _INTERPRETATION_TEMPLATES["fallback"].format(orientation=orientation),
        quality,
    )


def _classify_axis(
    coords: List[float],
    ref_lr: List[float],
    ref_co: List[float],
    ref_pc: List[float],
    axis: str,
) -> Tuple[str, str, float]:
    """Correlate one axis' coordinates with each reference dimension and label it."""
    return _assign_label(
        _pearsonr(coords, ref_lr),
        _pearsonr(coords, ref_co),
        _pearsonr(coords, ref_pc),
        axis,
    )


def _modal_label(labels: List[str], fallback: str) -> str:
    """Return the most common label (first seen wins ties), or ``fallback`` if empty."""
    if not labels:
        return fallback
    return Counter(labels).most_common(1)[0][0]


def classify_axes(
    positions_by_window: Dict[str, Dict[str, Tuple[float, float]]],
    axes: dict,
    db_path: str,
) -> dict:
    """Classify compass axes by correlating per-party positions against ideology reference data.

    Reference data is read from ``party_ideologies.csv`` and
    ``coalition_membership.csv`` in the directory that contains ``db_path``.

    Enriches a copy of ``axes`` with:
      x_label, y_label   — global label (modal across annual windows)
      x_quality, y_quality — {window_id: float} max |r| for each window
      x_interpretation   — {window_id: str} Dutch explanation per window
      y_interpretation   — {window_id: str} Dutch explanation per window

    Windows with fewer than 5 parties present in the ideology reference are
    skipped and get no per-window entries. The input ``axes`` dict is never
    mutated.

    Returns the original ``axes`` dict unchanged if reference data is unavailable.
    """
    data_dir = Path(db_path).parent
    ideology = _load_ideology(data_dir / "party_ideologies.csv")
    if not ideology:
        return axes  # no reference data — preserve existing behaviour

    coalition = _load_coalition(data_dir / "coalition_membership.csv")

    x_quality: Dict[str, float] = {}
    y_quality: Dict[str, float] = {}
    x_interpretation: Dict[str, str] = {}
    y_interpretation: Dict[str, str] = {}
    annual_x_labels: List[str] = []
    annual_y_labels: List[str] = []

    for wid, pos_dict in positions_by_window.items():
        year = _window_year(wid)
        # Annual ids look like "2016"; quarterly ids ("2016-Q3") contain a dash
        # and "current_parliament" has no year at all.
        is_annual = wid != "current_parliament" and "-" not in wid

        # Only use parties present in both the positions and the ideology reference.
        parties = [p for p in pos_dict if p in ideology]
        if len(parties) < 5:
            _logger.debug(
                "Skipping axis classification for %s: only %d reference parties (need 5)",
                wid,
                len(parties),
            )
            continue

        party_x = [pos_dict[p][0] for p in parties]
        party_y = [pos_dict[p][1] for p in parties]
        ref_lr = [ideology[p]["left_right"] for p in parties]
        ref_pc = [ideology[p]["progressive"] for p in parties]

        # Coalition dummy: +1 if in government that year, -1 otherwise.
        # current_parliament and windows with no coalition data use a neutral
        # (constant) vector, which correlates as 0.0 and never wins.
        if year and coalition and year in coalition:
            gov_set = coalition[year]
            ref_co = [1.0 if p in gov_set else -1.0 for p in parties]
        else:
            ref_co = [0.0] * len(parties)

        x_lbl, x_int, x_q = _classify_axis(party_x, ref_lr, ref_co, ref_pc, "x")
        y_lbl, y_int, y_q = _classify_axis(party_y, ref_lr, ref_co, ref_pc, "y")

        x_quality[wid] = x_q
        y_quality[wid] = y_q
        x_interpretation[wid] = x_int
        y_interpretation[wid] = y_int

        # Only annual windows vote on the global label (not quarterly, not current_parliament).
        if is_annual:
            annual_x_labels.append(x_lbl)
            annual_y_labels.append(y_lbl)

    enriched = dict(axes)
    enriched["x_label"] = _modal_label(annual_x_labels, _LABELS["lr"])
    enriched["y_label"] = _modal_label(annual_y_labels, _LABELS["pc"])
    enriched["x_quality"] = x_quality
    enriched["y_quality"] = y_quality
    enriched["x_interpretation"] = x_interpretation
    enriched["y_interpretation"] = y_interpretation
    return enriched
# ---------------------------------------------------------------------------
# Tests for analysis.axis_classifier
# ---------------------------------------------------------------------------

import importlib

# Ideology reference rows shared by the classifier tests below.
_IDEOLOGY_CSV = (
    "party,left_right,progressive\n"
    "VVD,0.65,0.10\n"
    "PvdA,-0.70,0.75\n"
    "SP,-0.90,0.50\n"
    "PVV,0.90,-0.50\n"
    "D66,-0.10,0.85\n"
    "CDA,0.25,-0.45\n"
)


def _fresh_classifier(monkeypatch):
    """Return the axis_classifier module with both reference-data caches reset."""
    import analysis.axis_classifier as classifier

    for cache_name in ("_ideology_cache", "_coalition_cache"):
        monkeypatch.setattr(classifier, cache_name, None)
    return classifier


def test_axis_label_left_right(tmp_path, monkeypatch):
    """X coordinates equal to the left_right scores must yield the left-right label."""
    classifier = _fresh_classifier(monkeypatch)

    (tmp_path / "party_ideologies.csv").write_text(_IDEOLOGY_CSV)
    # Header only: no coalition information for any window.
    (tmp_path / "coalition_membership.csv").write_text("window_id,party\n")

    # Each party's x value is its left_right score, so the correlation is perfect.
    window_positions = {
        "VVD": (0.65, 0.10),
        "PvdA": (-0.70, 0.20),
        "SP": (-0.90, 0.30),
        "PVV": (0.90, -0.10),
        "D66": (-0.10, 0.40),
        "CDA": (0.25, -0.20),
    }
    axes = {"x_axis": None, "y_axis": None, "method": "pca"}
    db_file = str(tmp_path / "motions.db")

    result = classifier.classify_axes({"2022": window_positions}, axes, db_file)

    assert result["x_label"] == "Links\u2013Rechts"
    assert result["x_quality"]["2022"] >= 0.65


def test_axis_label_coalition_dominant(tmp_path, monkeypatch):
    """X coordinates that split coalition from opposition must yield the coalition label."""
    classifier = _fresh_classifier(monkeypatch)

    (tmp_path / "party_ideologies.csv").write_text(_IDEOLOGY_CSV)
    # 2016: Rutte II coalition = VVD + PvdA
    (tmp_path / "coalition_membership.csv").write_text(
        "window_id,party\n2016,VVD\n2016,PvdA\n"
    )

    # Coalition parties (VVD + PvdA) sit near x = +1 and the opposition near
    # x = -1. VVD is right-wing and PvdA left-wing, so the left_right
    # correlation is low while the coalition correlation is high.
    window_positions = {
        "VVD": (0.95, 0.10),
        "PvdA": (0.90, 0.20),
        "SP": (-0.85, 0.30),
        "PVV": (-0.95, -0.10),
        "D66": (-0.80, 0.40),
        "CDA": (-0.75, -0.20),
    }
    axes = {"x_axis": None, "y_axis": None, "method": "pca"}
    db_file = str(tmp_path / "motions.db")

    result = classifier.classify_axes({"2016": window_positions}, axes, db_file)

    assert result["x_label"] == "Coalitie\u2013Oppositie"
    assert "coalitie" in result["x_interpretation"]["2016"].lower()


def test_axis_classifier_missing_csv(tmp_path, monkeypatch):
    """Without party_ideologies.csv the axes dict comes back untouched and nothing raises."""
    classifier = _fresh_classifier(monkeypatch)

    # tmp_path exists but contains neither reference CSV.
    axes = {"x_axis": None, "y_axis": None, "method": "pca"}
    positions_by_window = {"2022": {"VVD": (1.0, 0.5), "PvdA": (-1.0, 0.3)}}
    db_file = str(tmp_path / "motions.db")

    result = classifier.classify_axes(positions_by_window, axes, db_file)

    # The very same object is returned, with no label keys added.
    assert result is axes
    assert "x_label" not in result