Extract _load_mp_vectors_by_party helper and fix cache key

- Extract shared helper that both load_party_axis_scores and
  load_party_mp_vectors delegate to, eliminating ~40 lines of
  duplicated DB query + vector parsing code
- Remove dead code in load_party_axis_scores that queried mp_metadata
  twice (first without ORDER BY, then again with ORDER BY; the second
  result overwrote the first, so the first query was never used)
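- Fix _cached_bootstrap_cis parameter: remove the leading underscore so
  Streamlit actually hashes the input dict as part of the cache key
  (st.cache_data excludes underscore-prefixed parameters from hashing,
  so every call previously shared one cached result regardless of input)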
main
Sven Geboers 1 month ago
parent 3938eecc53
commit b7129b3755
  1. 136
      explorer.py

@ -397,53 +397,29 @@ def compute_party_discipline(
pass pass
@st.cache_data(show_spinner="Partijposities op SVD-assen laden…") def _load_mp_vectors_by_party(db_path: str) -> Dict[str, List[np.ndarray]]:
def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: """Load individual MP SVD vectors grouped by party.
"""Return per-party SVD vectors, computed as mean of individual MP vectors.
Loads individual MP rows (entity_id LIKE '%,%') from window='current_parliament', Queries mp_metadata for the mp-to-party mapping (latest assignment during the
assigns each MP their party using the dominant party from mp_votes, then current parliament), normalises party names, loads SVD vectors from the
averages SVD vectors per party. ``current_parliament`` window, and filters to CURRENT_PARLIAMENT_PARTIES.
This matches the political compass data source (also averages individual MPs),
so axis rankings are consistent between the SVD tab and the compass.
Returns: Returns:
{party_name: [float * k]} k = 50, mean over all MPs in that party. {party_name: [np.ndarray(50,), ...]} one array per MP.
""" """
con = duckdb.connect(database=db_path, read_only=True)
try: try:
con = duckdb.connect(database=db_path, read_only=True) # Build mp → party mapping. ORDER BY van ASC so latest assignment wins
# via last-write-wins when an MP switched party.
# Build mp → party mapping from mp_metadata (most recent party during current parliament).
# mp_metadata format: mp_name like "Van Baarle, S.R.T.", party = "GroenLinks-PvdA"
# We take the party record with the latest `van` date (most recent assignment).
meta_rows = con.execute( meta_rows = con.execute(
"SELECT mp_name, party FROM mp_metadata "
"WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22'"
).fetchall()
# For MPs with multiple records (party switches), keep the one with latest van date.
# Simple approach: last-write-wins per mp_name after sorting by van ascending.
mp_party_raw: Dict[str, str] = {}
for mp_name, party in meta_rows:
if mp_name and party:
mp_party_raw[mp_name] = party # later rows (after ORDER BY van) win
# Re-query ordered so latest van wins reliably
meta_ordered = con.execute(
"SELECT mp_name, party FROM mp_metadata " "SELECT mp_name, party FROM mp_metadata "
"WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' " "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
"ORDER BY van ASC" "ORDER BY van ASC"
).fetchall() ).fetchall()
mp_party_raw = {}
for mp_name, party in meta_ordered:
if mp_name and party:
mp_party_raw[mp_name] = party
# Normalize party names to canonical abbreviations
mp_party: Dict[str, str] = {} mp_party: Dict[str, str] = {}
for mp_name, party in mp_party_raw.items(): for mp_name, party in meta_rows:
canonical = _PARTY_NORMALIZE.get(party, party) if mp_name and party:
mp_party[mp_name] = canonical mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)
# Individual MP vectors from current_parliament # Individual MP vectors from current_parliament
rows = con.execute( rows = con.execute(
@ -451,7 +427,7 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
"WHERE entity_type='mp' AND window_id='current_parliament'" "WHERE entity_type='mp' AND window_id='current_parliament'"
).fetchall() ).fetchall()
party_vecs: Dict[str, list] = {} party_vecs: Dict[str, List[np.ndarray]] = {}
for entity_id, raw_vec in rows: for entity_id, raw_vec in rows:
party = mp_party.get(entity_id) party = mp_party.get(entity_id)
if party is None or party not in CURRENT_PARLIAMENT_PARTIES: if party is None or party not in CURRENT_PARLIAMENT_PARTIES:
@ -467,17 +443,10 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
vec = list(raw_vec) vec = list(raw_vec)
except Exception: except Exception:
continue continue
fvec = [float(v) if v is not None else 0.0 for v in vec] fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
party_vecs.setdefault(party, []).append(fvec) party_vecs.setdefault(party, []).append(fvec)
# Average vectors per party return party_vecs
result: Dict[str, List[float]] = {}
for party, vecs in party_vecs.items():
result[party] = np.array(vecs).mean(axis=0).tolist()
return result
except Exception:
logger.exception("Failed to load party axis scores")
return {}
finally: finally:
try: try:
con.close() con.close()
@ -485,75 +454,52 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
pass pass
@st.cache_data(show_spinner="Partij-MP vectoren laden…") @st.cache_data(show_spinner="Partijposities op SVD-assen laden…")
def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]: def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
"""Return per-party lists of individual MP SVD vectors. """Return per-party SVD vectors, computed as mean of individual MP vectors.
Same MP-to-party mapping as load_party_axis_scores(), but returns the raw Loads individual MP rows from window='current_parliament', assigns each MP
per-MP vectors instead of averaging them. Suitable for bootstrap CI their party, then averages SVD vectors per party.
computation.
Returns: Returns:
{party_name: [np.ndarray(50,), ...]} one array per MP. {party_name: [float * k]} k = 50, mean over all MPs in that party.
""" """
try: try:
con = duckdb.connect(database=db_path, read_only=True) party_vecs = _load_mp_vectors_by_party(db_path)
return {
party: np.array(vecs).mean(axis=0).tolist()
for party, vecs in party_vecs.items()
}
except Exception:
logger.exception("Failed to load party axis scores")
return {}
# Build mp → party mapping (same logic as load_party_axis_scores)
meta_ordered = con.execute(
"SELECT mp_name, party FROM mp_metadata "
"WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
"ORDER BY van ASC"
).fetchall()
mp_party: Dict[str, str] = {}
for mp_name, party in meta_ordered:
if mp_name and party:
mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)
# Individual MP vectors from current_parliament @st.cache_data(show_spinner="Partij-MP vectoren laden…")
rows = con.execute( def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]:
"SELECT entity_id, vector FROM svd_vectors " """Return per-party lists of individual MP SVD vectors.
"WHERE entity_type='mp' AND window_id='current_parliament'"
).fetchall()
party_vecs: Dict[str, List[np.ndarray]] = {} Same MP-to-party mapping as load_party_axis_scores(), suitable for bootstrap
for entity_id, raw_vec in rows: CI computation.
party = mp_party.get(entity_id)
if party is None or party not in CURRENT_PARLIAMENT_PARTIES:
continue
if isinstance(raw_vec, str):
vec = json.loads(raw_vec)
elif isinstance(raw_vec, (bytes, bytearray)):
vec = json.loads(raw_vec.decode())
elif isinstance(raw_vec, list):
vec = raw_vec
else:
try:
vec = list(raw_vec)
except Exception:
continue
fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
party_vecs.setdefault(party, []).append(fvec)
return party_vecs Returns:
{party_name: [np.ndarray(50,), ...]} one array per MP.
"""
try:
return _load_mp_vectors_by_party(db_path)
except Exception: except Exception:
logger.exception("Failed to load party MP vectors") logger.exception("Failed to load party MP vectors")
return {} return {}
finally:
try:
con.close()
except Exception:
pass
@st.cache_data(show_spinner="Bootstrap CI berekenen…") @st.cache_data(show_spinner="Bootstrap CI berekenen…")
def _cached_bootstrap_cis( def _cached_bootstrap_cis(
_party_mp_vectors: Dict[str, List[np.ndarray]], party_mp_vectors: Dict[str, List[np.ndarray]],
) -> Dict[str, Dict]: ) -> Dict[str, Dict]:
"""Thin caching wrapper around compute_party_bootstrap_cis.""" """Thin caching wrapper around compute_party_bootstrap_cis."""
from analysis.political_axis import compute_party_bootstrap_cis from analysis.political_axis import compute_party_bootstrap_cis
return compute_party_bootstrap_cis(_party_mp_vectors) return compute_party_bootstrap_cis(party_mp_vectors)
@st.cache_data(show_spinner="Scree-plot laden…") @st.cache_data(show_spinner="Scree-plot laden…")

Loading…
Cancel
Save