You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/analysis/axis_classifier.py

269 lines
9.3 KiB

"""Axis classifier: correlate per-party PCA positions against ideology reference data
to assign honest, dynamic labels to political compass axes.
Public API: classify_axes(positions_by_window, axes, db_path) -> dict
"""
import csv
import logging
import math
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
_logger = logging.getLogger(__name__)

# Per-process caches: filled on the first successful load by the CSV loaders
# and reused for the lifetime of the interpreter.
_ideology_cache: Optional[Dict[str, Dict[str, float]]] = None
_coalition_cache: Optional[Dict[str, set]] = None

# Minimum |Pearson r| between an axis and a reference dimension before that
# dimension's label is assigned to the axis.
_THRESHOLD = 0.65

# Dutch axis labels, keyed by reference dimension ("lr" = left-right,
# "co" = coalition-opposition, "pc" = progressive-conservative) plus the
# generic per-axis fallbacks.
_LABELS = dict(
    lr="Links\u2013Rechts",
    co="Coalitie\u2013Oppositie",
    pc="Progressief\u2013Conservatief",
    fallback_x="Stempatroon As 1",
    fallback_y="Stempatroon As 2",
)

# Dutch explanation templates. Every template takes ``orientation``; the
# "co" template additionally formats the correlation coefficient ``r``.
_INTERPRETATION_TEMPLATES = dict(
    lr="De {orientation} as weerspiegelt de klassieke links-rechts tegenstelling.",
    co=(
        "De {orientation} as weerspiegelt stemgedrag van coalitie- versus "
        "oppositiepartijen (r={r:.2f}). Links-rechts is minder dominant dit jaar."
    ),
    pc="De {orientation} as weerspiegelt de progressief-conservatieve tegenstelling.",
    fallback=(
        "De {orientation} as weerspiegelt een empirisch stempatroon "
        "zonder duidelijke ideologische richting."
    ),
)
def _load_ideology(csv_path: Path) -> Dict[str, Dict[str, float]]:
    """Load party ideology scores from CSV.

    The first row is a header that must contain ``left_right`` and
    ``progressive`` columns; the party name is taken from the first column.

    Returns ``{party_name: {"left_right": float, "progressive": float}}``.
    Returns ``{}`` if the file is missing, unreadable, empty or lacks the
    required columns (caller should treat empty as 'skip classification').
    Individual rows that are too short are skipped silently. Rows whose
    scores are not finite numbers are skipped with a warning, so one bad row
    does not discard the whole file.

    A successful load is cached in ``_ideology_cache`` for the rest of the
    process. NOTE: the cache is not keyed by ``csv_path``, so later calls
    with a different path return the first result. Failed loads are not
    cached, so a later call retries the file.
    """
    global _ideology_cache
    if _ideology_cache is not None:
        return _ideology_cache

    try:
        # utf-8-sig strips a leading BOM (common in spreadsheet exports);
        # newline="" is what the csv module expects for correct quoting.
        with open(csv_path, encoding="utf-8-sig", newline="") as fh:
            rows = list(csv.reader(fh))
    except FileNotFoundError:
        _logger.warning(
            "party_ideologies.csv not found at %s — axis labels will be generic",
            csv_path,
        )
        return {}
    except (OSError, ValueError, csv.Error) as exc:
        _logger.warning("Failed to load party_ideologies.csv: %s", exc)
        return {}

    header = [h.strip() for h in rows[0]] if rows else []
    try:
        lr_idx = header.index("left_right")
        pc_idx = header.index("progressive")
    except ValueError:
        _logger.warning(
            "Failed to load party_ideologies.csv: header %s lacks "
            "'left_right'/'progressive' columns",
            header,
        )
        return {}

    needed = max(lr_idx, pc_idx)
    result: Dict[str, Dict[str, float]] = {}
    for rowno, row in enumerate(rows[1:], start=2):
        parts = [p.strip() for p in row]
        if not any(parts):
            continue  # blank line
        if len(parts) <= needed:
            continue  # too few columns to hold both scores
        try:
            left_right = float(parts[lr_idx])
            progressive = float(parts[pc_idx])
        except ValueError:
            left_right = progressive = math.nan
        # A non-numeric or NaN/inf score would make every correlation that
        # includes this party meaningless, so drop just this row.
        if not (math.isfinite(left_right) and math.isfinite(progressive)):
            _logger.warning(
                "Skipping row %d of party_ideologies.csv: invalid scores in %r",
                rowno,
                row,
            )
            continue
        result[parts[0]] = {"left_right": left_right, "progressive": progressive}

    _ideology_cache = result
    return result
def _load_coalition(csv_path: Path) -> Dict[str, set]:
    """Load coalition membership from CSV.

    The first row is treated as a header and skipped. Every following row is
    read as ``window_id, party_name[, ...]``; extra columns are ignored, and
    rows with fewer than two fields or an empty window/party are skipped.

    Returns ``{window_id: set_of_party_names}``.
    Returns ``{}`` on any read error (coalition dimension will be skipped).

    A successful load (even an empty one) is cached in ``_coalition_cache``
    for the rest of the process. NOTE: the cache is not keyed by
    ``csv_path``. Failed loads are not cached.
    """
    global _coalition_cache
    if _coalition_cache is not None:
        return _coalition_cache

    try:
        # utf-8-sig tolerates a BOM; newline="" lets csv handle quoted fields.
        with open(csv_path, encoding="utf-8-sig", newline="") as fh:
            rows = list(csv.reader(fh))
    except FileNotFoundError:
        _logger.warning(
            "coalition_membership.csv not found at %s — coalition axis detection disabled",
            csv_path,
        )
        return {}
    except (OSError, ValueError, csv.Error) as exc:
        _logger.warning("Failed to load coalition_membership.csv: %s", exc)
        return {}

    result: Dict[str, set] = {}
    for row in rows[1:]:  # rows[0] is the header
        if len(row) < 2:
            continue  # blank or truncated line
        wid, party = row[0].strip(), row[1].strip()
        if not wid or not party:
            continue  # would otherwise register an empty window id / party
        result.setdefault(wid, set()).add(party)

    _coalition_cache = result
    return result
def _window_year(window_id: str) -> Optional[str]:
    """Return the year prefix of a window id.

    ``'current_parliament'`` has no year and yields ``None``. Every other id
    yields the text before its first ``-`` (the whole id if there is none),
    e.g. ``'2016' -> '2016'`` and ``'2016-Q3' -> '2016'``.
    """
    if window_id != "current_parliament":
        year, _sep, _rest = window_id.partition("-")
        return year
    return None
def _pearsonr(x: List[float], y: List[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences.

    Returns 0.0 for degenerate input: fewer than 3 points, (near-)zero
    variance in either sequence, or any non-finite value (NaN/inf). A NaN
    result would otherwise leak into callers that take ``max(|r|)``, where
    NaN makes the maximum order-dependent. The result is clipped to [-1, 1]
    to absorb floating-point round-off.
    """
    if len(x) < 3:
        return 0.0
    xa = np.asarray(x, dtype=float)
    ya = np.asarray(y, dtype=float)
    # Checking up front also avoids numpy "invalid value" RuntimeWarnings.
    if not (np.isfinite(xa).all() and np.isfinite(ya).all()):
        return 0.0
    if xa.std() < 1e-12 or ya.std() < 1e-12:
        return 0.0
    r = float(np.corrcoef(xa, ya)[0, 1])
    if not np.isfinite(r):  # e.g. overflow on extreme magnitudes
        return 0.0
    return min(1.0, max(-1.0, r))
def _assign_label(
    r_lr: float,
    r_co: float,
    r_pc: float,
    axis: str,
) -> Tuple[str, str, float]:
    """Pick the label and Dutch interpretation for a single compass axis.

    The reference dimensions are tried in fixed priority order: left-right,
    then coalition, then progressive. The first whose |r| reaches
    ``_THRESHOLD`` wins. If none does, the axis gets the generic fallback
    label for ``axis``: "x" selects the x fallback, and any other value is
    treated as the vertical axis.

    Returns ``(label, interpretation, quality)``, where ``quality`` is the
    largest |r| across all three dimensions, regardless of which label won.
    """
    quality = max(abs(r_lr), abs(r_co), abs(r_pc))
    is_x = axis == "x"
    orientation = "horizontale" if is_x else "verticale"

    for key, r in (("lr", r_lr), ("co", r_co), ("pc", r_pc)):
        if abs(r) >= _THRESHOLD:
            # Only the "co" template uses {r}; str.format ignores unused kwargs.
            text = _INTERPRETATION_TEMPLATES[key].format(orientation=orientation, r=r)
            return _LABELS[key], text, quality

    fallback = _LABELS["fallback_x" if is_x else "fallback_y"]
    text = _INTERPRETATION_TEMPLATES["fallback"].format(orientation=orientation)
    return fallback, text, quality
def classify_axes(
    positions_by_window: Dict[str, Dict[str, Tuple[float, float]]],
    axes: dict,
    db_path: str,
) -> dict:
    """Label the compass axes by correlating party positions with reference data.

    ``party_ideologies.csv`` and ``coalition_membership.csv`` are looked up in
    the directory containing ``db_path``. A window is classified only if at
    least five of its parties also appear in the ideology reference. For each
    such window, both axes are correlated against left-right,
    coalition/opposition and progressive/conservative reference vectors and
    labelled via ``_assign_label``.

    Returns a shallow copy of ``axes`` extended with:
        x_label, y_label      — most frequent label over annual windows
                                (ids without '-' other than
                                'current_parliament'); defaults to
                                Links–Rechts / Progressief–Conservatief
        x_quality, y_quality  — {window_id: float} max |r| for each window
        x_interpretation      — {window_id: str} Dutch explanation per window
        y_interpretation      — {window_id: str} Dutch explanation per window

    If no ideology reference data can be loaded, ``axes`` itself is returned
    unchanged.
    """
    reference_dir = Path(db_path).parent
    ideology = _load_ideology(reference_dir / "party_ideologies.csv")
    if not ideology:
        return axes  # no reference data: hand back the caller's dict untouched
    coalition = _load_coalition(reference_dir / "coalition_membership.csv")

    quality: Dict[str, Dict[str, float]] = {"x": {}, "y": {}}
    interpretation: Dict[str, Dict[str, str]] = {"x": {}, "y": {}}
    annual_votes: Dict[str, List[str]] = {"x": [], "y": []}

    for window_id, positions in positions_by_window.items():
        year = _window_year(window_id)
        # Annual ids look like "2016". Quarterly ids ("2016-Q3") and
        # 'current_parliament' are classified but do not vote on the global label.
        is_annual = window_id != "current_parliament" and "-" not in window_id

        # Restrict to parties that have both a position and reference scores.
        known = [party for party in positions if party in ideology]
        if len(known) < 5:
            _logger.debug(
                "Skipping axis classification for %s: only %d reference parties (need 5)",
                window_id,
                len(known),
            )
            continue

        lr_ref = [ideology[party]["left_right"] for party in known]
        pc_ref = [ideology[party]["progressive"] for party in known]

        # Coalition dummy: +1 governing, -1 opposition. Without coalition data
        # for this year the vector is all zeros; _pearsonr scores zero-variance
        # input as r = 0, so it can never pass the threshold.
        if year and coalition and year in coalition:
            governing = coalition[year]
            co_ref = [1.0 if party in governing else -1.0 for party in known]
        else:
            co_ref = [0.0] * len(known)

        for axis, column in (("x", 0), ("y", 1)):
            coords = [positions[party][column] for party in known]
            label, text, score = _assign_label(
                _pearsonr(coords, lr_ref),
                _pearsonr(coords, co_ref),
                _pearsonr(coords, pc_ref),
                axis,
            )
            quality[axis][window_id] = score
            interpretation[axis][window_id] = text
            if is_annual:
                annual_votes[axis].append(label)

    def most_common_or(votes: List[str], default: str) -> str:
        # Counter.most_common breaks ties by first occurrence in ``votes``.
        return Counter(votes).most_common(1)[0][0] if votes else default

    enriched = dict(axes)
    enriched["x_label"] = most_common_or(annual_votes["x"], "Links\u2013Rechts")
    enriched["y_label"] = most_common_or(annual_votes["y"], "Progressief\u2013Conservatief")
    enriched["x_quality"] = quality["x"]
    enriched["y_quality"] = quality["y"]
    enriched["x_interpretation"] = interpretation["x"]
    enriched["y_interpretation"] = interpretation["y"]
    return enriched