You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/explorer_helpers.py

297 lines
10 KiB

"""Helper utilities used by explorer.py.

Helpers defined here include:
- normalize_positions: coerce raw positions_by_window values to floats / np.nan.
- inspect_positions_for_issues: diagnostic summary of positions_by_window
  against a party_map (window counts, unmatched entity ids, ...).
- compute_party_coords: per-party (x_mean, y_mean) for a single window, with an
  optional fallback from per-party score vectors.
- compute_party_centroids: per-party centroid series across several windows.

This module is intentionally free of Streamlit side-effects to be easy to unit test.
"""
from __future__ import annotations
import logging
import math
import re
from typing import Any, Dict, List, Optional, Set, Tuple
import numpy as np
# Module-level logger; compute_party_coords reports fallback usage through it.
logger = logging.getLogger(__name__)
def normalize_positions(
    positions_by_window: Dict[str, Dict[str, Tuple[Any, Any]]],
    clamp_abs_value: float = 1e3,
    null_tokens: tuple = ("nan", "NaN", "None", "none", "null", ""),
) -> Dict[str, Dict[str, Tuple[float, float]]]:
    """Normalize a positions_by_window structure.

    - Coerce numeric strings to floats.
    - Treat ``None``, the given ``null_tokens`` and NaN/inf values as ``np.nan``.
    - Decode bytes/bytearray with the default codec (best-effort) before parsing.
    - Values whose absolute magnitude exceeds ``clamp_abs_value`` are treated as
      outliers and replaced by ``np.nan`` (they are NOT clipped to the bound).
    - Preserve entity keys; any coordinate that cannot be coerced becomes
      ``np.nan`` (an ``xy`` that is None or not indexable yields
      ``(np.nan, np.nan)``).

    Args:
        positions_by_window: mapping window_id -> {entity: (x, y)}. ``None`` is
            treated as an empty mapping.
        clamp_abs_value: largest accepted absolute coordinate value.
        null_tokens: strings (compared after ``strip()``) that mean "missing".

    Returns:
        A new mapping with the same window and entity keys whose values are
        ``(float, float)`` tuples, possibly containing ``np.nan``.

    Pure and import-safe (no IO).
    """
    nan = float(np.nan)

    def _finite_or_nan(v: float) -> float:
        # Shared tail of the numeric and string paths: reject NaN/inf and outliers.
        if math.isnan(v) or math.isinf(v) or abs(v) > clamp_abs_value:
            return nan
        return v

    def _coerce(val: Any) -> float:
        if val is None:
            return nan
        if isinstance(val, (float, int, np.floating, np.integer)):
            try:
                v = float(val)
            except OverflowError:
                # Python ints beyond float range (e.g. 10**400) cannot be
                # converted at all; they exceed any clamp bound anyway.
                return nan
            return _finite_or_nan(v)
        if isinstance(val, (bytes, bytearray)):
            try:
                val = val.decode()
            except UnicodeDecodeError:
                return nan
        if isinstance(val, str):
            s = val.strip()
            if s in null_tokens:
                return nan
            try:
                v = float(s)
            except (TypeError, ValueError):
                return nan
            return _finite_or_nan(v)
        # Any other type (lists, dicts, arbitrary objects) is uncoercible.
        return nan

    out: Dict[str, Dict[str, Tuple[float, float]]] = {}
    for wid, mapping in (positions_by_window or {}).items():
        win_map: Dict[str, Tuple[float, float]] = {}
        # Empty/None window mappings are preserved as empty dicts.
        for ent, xy in (mapping or {}).items():
            try:
                if xy is None:
                    x_raw = y_raw = None
                else:
                    x_raw = xy[0] if len(xy) > 0 else None
                    y_raw = xy[1] if len(xy) > 1 else None
            except Exception:
                # Deliberate best-effort: scalars or other unindexable values
                # are treated as a fully missing position.
                x_raw = y_raw = None
            win_map[ent] = (_coerce(x_raw), _coerce(y_raw))
        out[wid] = win_map
    return out
def _strip_paren(s: str) -> str:
# helper used in plan to try to strip parenthetical variants
return s.split("(")[0].strip()
def inspect_positions_for_issues(
    positions_by_window: Dict[str, Dict[str, Tuple[float, float]]],
    party_map: Dict[str, str],
) -> Dict[str, Any]:
    """Summarize positions_by_window and flag entities that have no party.

    ``None`` for either argument is treated as an empty mapping, matching the
    other helpers in this module.

    Returns a dict with keys:
    - windows_count: number of windows.
    - window_labels: first 10 window ids in sorted order.
    - mp_id_set: set of all non-empty entity ids seen in non-empty windows.
    - party_map_count: number of entries in party_map.
    - parties_with_centroid_counts: party -> number of windows in which at least
      one entity mapped to that party appears.
    - mismatched_mp_ids_sample: up to 10 sorted entity ids with no (truthy)
      party, looked up directly and then with a parenthesised suffix stripped.
    - mp_positions_count: number of unique entity ids seen.
    - mp_positions_sample: up to 10 sorted entity ids.
    - windows_with_no_positions: window ids whose mapping is empty.

    Pure and import-safe so unit tests can exercise it.
    """
    positions_by_window = positions_by_window or {}
    party_map = party_map or {}

    mp_id_set: Set[str] = set()
    parties_with_centroid_counts: Dict[str, int] = {}
    mismatched: Set[str] = set()
    windows_with_no_positions: List[str] = []

    for win, pos in positions_by_window.items():
        if not pos:
            windows_with_no_positions.append(win)
            continue
        present_parties: Set[str] = set()
        for ent in pos:
            if not ent:
                continue
            mp_id_set.add(ent)
            party = party_map.get(ent)
            if party is None:
                # Second chance: "Name (Suffix)" -> "Name".
                party = party_map.get(_strip_paren(ent))
            if party:
                present_parties.add(party)
            else:
                mismatched.add(ent)
        # Each party is counted at most once per window.
        for p in present_parties:
            parties_with_centroid_counts[p] = parties_with_centroid_counts.get(p, 0) + 1

    return {
        "windows_count": len(positions_by_window),
        "window_labels": sorted(positions_by_window)[:10],
        "mp_id_set": mp_id_set,
        "party_map_count": len(party_map),
        "parties_with_centroid_counts": parties_with_centroid_counts,
        "mismatched_mp_ids_sample": sorted(mismatched)[:10],
        "mp_positions_sample": sorted(mp_id_set)[:10],
        "mp_positions_count": len(mp_id_set),
        "windows_with_no_positions": windows_with_no_positions,
    }
def compute_party_coords(
    positions_by_window: Dict[str, Dict[str, Tuple[float, float]]],
    party_map: Dict[str, str],
    window_id: str,
    fallback_party_scores: Optional[Dict[str, List[float]]] = None,
) -> Tuple[Dict[str, Tuple[float, float]], Set[str]]:
    """Compute per-party centroids (x_mean, y_mean) for a specific window.

    Args:
        positions_by_window: mapping window_id -> {entity_name: (x, y)}.
            ``None`` is treated as an empty mapping.
        party_map: mapping mp_name -> party abbreviation (normalized). Names not
            found are retried with a parenthesised suffix stripped. Entities
            whose party is falsy or ``"Unknown"`` are ignored.
        window_id: which window to compute centroids for (key into
            positions_by_window). A missing window yields no MP-based centroids.
        fallback_party_scores: optional mapping party -> numeric vector (list,
            tuple or numpy array, len >= 2). For any party that did not get a
            centroid from MPs in this window, the first two elements are used
            as a fallback (x, y) if both are finite.

    Returns:
        (party_coords, fallback_used) where:
        - party_coords: {party: (x_mean, y_mean)} for parties with a computed
          coord or fallback. Non-finite values are dropped per axis, so the x
          and y means may be taken over different MPs.
        - fallback_used: set of party names where fallback_party_scores was used.
    """
    pos = (positions_by_window or {}).get(window_id, {}) or {}
    per_party: Dict[str, List[Tuple[float, float]]] = {}
    for ent, xy in pos.items():
        if not ent or xy is None:
            continue
        try:
            x, y = float(xy[0]), float(xy[1])
        except Exception:
            # Deliberate best-effort: skip malformed coords instead of failing.
            continue
        party = party_map.get(ent)
        if party is None:
            # try stripped name fallback: "Name (Suffix)" -> "Name"
            party = party_map.get(_strip_paren(ent))
        if not party or party == "Unknown":
            continue
        per_party.setdefault(party, []).append((x, y))

    party_coords: Dict[str, Tuple[float, float]] = {}
    fallback_used: Set[str] = set()

    # Means for parties that have MPs; drop nan/inf independently per axis.
    for party, coords in per_party.items():
        xs = [cx for cx, _ in coords if math.isfinite(cx)]
        ys = [cy for _, cy in coords if math.isfinite(cy)]
        if not xs or not ys:
            continue
        party_coords[party] = (float(np.mean(xs)), float(np.mean(ys)))

    # Fallback: use supplied party vectors for parties without a centroid.
    if fallback_party_scores:
        for party, vec in fallback_party_scores.items():
            if party in party_coords or vec is None:
                continue
            try:
                # Use len() rather than a truthiness test: ``not vec`` raises
                # ValueError for numpy arrays with more than one element.
                if len(vec) < 2:
                    continue
                x_f, y_f = float(vec[0]), float(vec[1])
            except Exception:
                # Deliberate best-effort: unusable vectors are skipped.
                continue
            if not (math.isfinite(x_f) and math.isfinite(y_f)):
                continue
            party_coords[party] = (x_f, y_f)
            fallback_used.add(party)

    if fallback_used:
        logger.warning(
            "compute_party_coords used fallback for parties: %s",
            sorted(fallback_used),
        )
    return party_coords, fallback_used
def compute_party_centroids(
    positions_by_window: Dict[str, Dict[str, Tuple[float, float]]],
    party_map: Dict[str, str],
    windows: List[str],
) -> Tuple[Dict[str, List[Tuple[float, float]]], Dict[str, Any]]:
    """Compute per-party centroids across multiple windows.

    Per window, centroids come from ``compute_party_coords`` (without fallback
    scores). ``None`` for positions_by_window or party_map is treated as empty.

    Returns (party_centroids, metadata):
    - party_centroids: mapping party -> list of (x, y) tuples of length
      len(windows), in the order of ``windows``. Entries without a centroid
      are (np.nan, np.nan). Parties are the distinct truthy values of
      party_map.
    - metadata: dict with keys
        'per_party_counts': party -> number of windows with a finite centroid,
        'total_windows': len(windows),
        'parties': sorted list of parties.
    """
    # Falsy party labels (None, "") are skipped by compute_party_coords, so they
    # can never receive a centroid. Excluding them also keeps sorted() from
    # raising TypeError on a mix of None and str labels.
    parties = sorted({p for p in (party_map or {}).values() if p})
    if not parties:
        # No parties known: empty centroids, but still return metadata.
        return {}, {
            "per_party_counts": {},
            "total_windows": len(windows),
            "parties": [],
        }

    party_centroids: Dict[str, List[Tuple[float, float]]] = {p: [] for p in parties}
    missing = (float(np.nan), float(np.nan))
    for w in windows:
        coords, _ = compute_party_coords(positions_by_window or {}, party_map, w)
        for p in parties:
            if p in coords:
                x, y = coords[p]
                # ensure plain Python floats
                party_centroids[p].append((float(x), float(y)))
            else:
                party_centroids[p].append(missing)

    per_party_counts: Dict[str, int] = {
        p: sum(1 for x, y in vals if not (np.isnan(x) or np.isnan(y)))
        for p, vals in party_centroids.items()
    }
    metadata = {
        "per_party_counts": per_party_counts,
        "total_windows": len(windows),
        "parties": parties,
    }
    return party_centroids, metadata