Extract _load_mp_vectors_by_party helper and fix cache key

- Extract shared helper that both load_party_axis_scores and
  load_party_mp_vectors delegate to, eliminating ~40 lines of
  duplicated DB query + vector parsing code
- Remove dead code in load_party_axis_scores that queried mp_metadata
  twice (first without ORDER BY, then again with ORDER BY; the second
  result replaced the first, so the unordered query was wasted work)
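- Fix _cached_bootstrap_cis parameter: remove the leading underscore so
  Streamlit hashes the input dict into the cache key; parameters whose
  names start with _ are excluded from hashing, so the result was
  previously cached with no key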
main
Sven Geboers 1 month ago
parent 3938eecc53
commit b7129b3755
  1. 136
      explorer.py

@ -397,53 +397,29 @@ def compute_party_discipline(
pass
@st.cache_data(show_spinner="Partijposities op SVD-assen laden…")
def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
"""Return per-party SVD vectors, computed as mean of individual MP vectors.
def _load_mp_vectors_by_party(db_path: str) -> Dict[str, List[np.ndarray]]:
"""Load individual MP SVD vectors grouped by party.
Loads individual MP rows (entity_id LIKE '%,%') from window='current_parliament',
assigns each MP their party using the dominant party from mp_votes, then
averages SVD vectors per party.
This matches the political compass data source (also averages individual MPs),
so axis rankings are consistent between the SVD tab and the compass.
Queries mp_metadata for the mp → party mapping (latest assignment during the
current parliament), normalises party names, loads SVD vectors from the
``current_parliament`` window, and filters to CURRENT_PARLIAMENT_PARTIES.
Returns:
{party_name: [float * k]} k = 50, mean over all MPs in that party.
{party_name: [np.ndarray(50,), ...]} one array per MP.
"""
con = duckdb.connect(database=db_path, read_only=True)
try:
con = duckdb.connect(database=db_path, read_only=True)
# Build mp → party mapping from mp_metadata (most recent party during current parliament).
# mp_metadata format: mp_name like "Van Baarle, S.R.T.", party = "GroenLinks-PvdA"
# We take the party record with the latest `van` date (most recent assignment).
# Build mp → party mapping. ORDER BY van ASC so latest assignment wins
# via last-write-wins when an MP switched party.
meta_rows = con.execute(
"SELECT mp_name, party FROM mp_metadata "
"WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22'"
).fetchall()
# For MPs with multiple records (party switches), keep the one with latest van date.
# Simple approach: last-write-wins per mp_name after sorting by van ascending.
mp_party_raw: Dict[str, str] = {}
for mp_name, party in meta_rows:
if mp_name and party:
mp_party_raw[mp_name] = party # later rows (after ORDER BY van) win
# Re-query ordered so latest van wins reliably
meta_ordered = con.execute(
"SELECT mp_name, party FROM mp_metadata "
"WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
"ORDER BY van ASC"
).fetchall()
mp_party_raw = {}
for mp_name, party in meta_ordered:
if mp_name and party:
mp_party_raw[mp_name] = party
# Normalize party names to canonical abbreviations
mp_party: Dict[str, str] = {}
for mp_name, party in mp_party_raw.items():
canonical = _PARTY_NORMALIZE.get(party, party)
mp_party[mp_name] = canonical
for mp_name, party in meta_rows:
if mp_name and party:
mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)
# Individual MP vectors from current_parliament
rows = con.execute(
@ -451,7 +427,7 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
"WHERE entity_type='mp' AND window_id='current_parliament'"
).fetchall()
party_vecs: Dict[str, list] = {}
party_vecs: Dict[str, List[np.ndarray]] = {}
for entity_id, raw_vec in rows:
party = mp_party.get(entity_id)
if party is None or party not in CURRENT_PARLIAMENT_PARTIES:
@ -467,17 +443,10 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
vec = list(raw_vec)
except Exception:
continue
fvec = [float(v) if v is not None else 0.0 for v in vec]
fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
party_vecs.setdefault(party, []).append(fvec)
# Average vectors per party
result: Dict[str, List[float]] = {}
for party, vecs in party_vecs.items():
result[party] = np.array(vecs).mean(axis=0).tolist()
return result
except Exception:
logger.exception("Failed to load party axis scores")
return {}
return party_vecs
finally:
try:
con.close()
@ -485,75 +454,52 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
pass
@st.cache_data(show_spinner="Partij-MP vectoren laden…")
def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]:
"""Return per-party lists of individual MP SVD vectors.
@st.cache_data(show_spinner="Partijposities op SVD-assen laden…")
def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
"""Return per-party SVD vectors, computed as mean of individual MP vectors.
Same MP → party mapping as load_party_axis_scores(), but returns the raw
per-MP vectors instead of averaging them. Suitable for bootstrap CI
computation.
Loads individual MP rows from window='current_parliament', assigns each MP
their party, then averages SVD vectors per party.
Returns:
{party_name: [np.ndarray(50,), ...]} one array per MP.
{party_name: [float * k]} k = 50, mean over all MPs in that party.
"""
try:
con = duckdb.connect(database=db_path, read_only=True)
party_vecs = _load_mp_vectors_by_party(db_path)
return {
party: np.array(vecs).mean(axis=0).tolist()
for party, vecs in party_vecs.items()
}
except Exception:
logger.exception("Failed to load party axis scores")
return {}
# Build mp → party mapping (same logic as load_party_axis_scores)
meta_ordered = con.execute(
"SELECT mp_name, party FROM mp_metadata "
"WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
"ORDER BY van ASC"
).fetchall()
mp_party: Dict[str, str] = {}
for mp_name, party in meta_ordered:
if mp_name and party:
mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)
# Individual MP vectors from current_parliament
rows = con.execute(
"SELECT entity_id, vector FROM svd_vectors "
"WHERE entity_type='mp' AND window_id='current_parliament'"
).fetchall()
@st.cache_data(show_spinner="Partij-MP vectoren laden…")
def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]:
"""Return per-party lists of individual MP SVD vectors.
party_vecs: Dict[str, List[np.ndarray]] = {}
for entity_id, raw_vec in rows:
party = mp_party.get(entity_id)
if party is None or party not in CURRENT_PARLIAMENT_PARTIES:
continue
if isinstance(raw_vec, str):
vec = json.loads(raw_vec)
elif isinstance(raw_vec, (bytes, bytearray)):
vec = json.loads(raw_vec.decode())
elif isinstance(raw_vec, list):
vec = raw_vec
else:
try:
vec = list(raw_vec)
except Exception:
continue
fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
party_vecs.setdefault(party, []).append(fvec)
Same MP → party mapping as load_party_axis_scores(), suitable for bootstrap
CI computation.
return party_vecs
Returns:
{party_name: [np.ndarray(50,), ...]} one array per MP.
"""
try:
return _load_mp_vectors_by_party(db_path)
except Exception:
logger.exception("Failed to load party MP vectors")
return {}
finally:
try:
con.close()
except Exception:
pass
@st.cache_data(show_spinner="Bootstrap CI berekenen…")
def _cached_bootstrap_cis(
_party_mp_vectors: Dict[str, List[np.ndarray]],
party_mp_vectors: Dict[str, List[np.ndarray]],
) -> Dict[str, Dict]:
"""Thin caching wrapper around compute_party_bootstrap_cis."""
from analysis.political_axis import compute_party_bootstrap_cis
return compute_party_bootstrap_cis(_party_mp_vectors)
return compute_party_bootstrap_cis(party_mp_vectors)
@st.cache_data(show_spinner="Scree-plot laden…")

Loading…
Cancel
Save