From b7129b3755eedd19e4765edc942d7c7fe135addb Mon Sep 17 00:00:00 2001 From: Sven Geboers Date: Sun, 29 Mar 2026 23:41:15 +0200 Subject: [PATCH] Extract _load_mp_vectors_by_party helper and fix cache key - Extract shared helper that both load_party_axis_scores and load_party_mp_vectors delegate to, eliminating ~40 lines of duplicated DB query + vector parsing code - Remove dead code in load_party_axis_scores that queried mp_metadata twice (first without ORDER BY, then again with ORDER BY, overwriting) - Fix _cached_bootstrap_cis parameter: remove _ prefix so Streamlit actually hashes the input dict instead of caching with no key --- explorer.py | 136 ++++++++++++++++------------------------------------ 1 file changed, 41 insertions(+), 95 deletions(-) diff --git a/explorer.py b/explorer.py index b1955de..046b341 100644 --- a/explorer.py +++ b/explorer.py @@ -397,53 +397,29 @@ def compute_party_discipline( pass -@st.cache_data(show_spinner="Partijposities op SVD-assen laden…") -def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: - """Return per-party SVD vectors, computed as mean of individual MP vectors. +def _load_mp_vectors_by_party(db_path: str) -> Dict[str, List[np.ndarray]]: + """Load individual MP SVD vectors grouped by party. - Loads individual MP rows (entity_id LIKE '%,%') from window='current_parliament', - assigns each MP their party using the dominant party from mp_votes, then - averages SVD vectors per party. - - This matches the political compass data source (also averages individual MPs), - so axis rankings are consistent between the SVD tab and the compass. + Queries mp_metadata for the mp→party mapping (latest assignment during the + current parliament), normalises party names, loads SVD vectors from the + ``current_parliament`` window, and filters to CURRENT_PARLIAMENT_PARTIES. Returns: - {party_name: [float * k]} — k = 50, mean over all MPs in that party. + {party_name: [np.ndarray(50,), ...]} — one array per MP. """ + con = duckdb.connect(database=db_path, read_only=True) try: - con = duckdb.connect(database=db_path, read_only=True) - - # Build mp → party mapping from mp_metadata (most recent party during current parliament). - # mp_metadata format: mp_name like "Van Baarle, S.R.T.", party = "GroenLinks-PvdA" - # We take the party record with the latest `van` date (most recent assignment). + # Build mp → party mapping. ORDER BY van ASC so latest assignment wins + # via last-write-wins when an MP switched party. meta_rows = con.execute( - "SELECT mp_name, party FROM mp_metadata " - "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22'" - ).fetchall() - # For MPs with multiple records (party switches), keep the one with latest van date. - # Simple approach: last-write-wins per mp_name after sorting by van ascending. - mp_party_raw: Dict[str, str] = {} - for mp_name, party in meta_rows: - if mp_name and party: - mp_party_raw[mp_name] = party # later rows (after ORDER BY van) win - - # Re-query ordered so latest van wins reliably - meta_ordered = con.execute( "SELECT mp_name, party FROM mp_metadata " "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' " "ORDER BY van ASC" ).fetchall() - mp_party_raw = {} - for mp_name, party in meta_ordered: - if mp_name and party: - mp_party_raw[mp_name] = party - - # Normalize party names to canonical abbreviations mp_party: Dict[str, str] = {} - for mp_name, party in mp_party_raw.items(): - canonical = _PARTY_NORMALIZE.get(party, party) - mp_party[mp_name] = canonical + for mp_name, party in meta_rows: + if mp_name and party: + mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party) # Individual MP vectors from current_parliament rows = con.execute( @@ -451,7 +427,7 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: "WHERE entity_type='mp' AND window_id='current_parliament'" ).fetchall() - party_vecs: Dict[str, list] = {} + party_vecs: Dict[str, List[np.ndarray]] = {} for entity_id, raw_vec in rows: party = mp_party.get(entity_id) if party is None or party not in CURRENT_PARLIAMENT_PARTIES: @@ -467,17 +443,10 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: vec = list(raw_vec) except Exception: continue - fvec = [float(v) if v is not None else 0.0 for v in vec] + fvec = np.array([float(v) if v is not None else 0.0 for v in vec]) party_vecs.setdefault(party, []).append(fvec) - # Average vectors per party - result: Dict[str, List[float]] = {} - for party, vecs in party_vecs.items(): - result[party] = np.array(vecs).mean(axis=0).tolist() - return result - except Exception: - logger.exception("Failed to load party axis scores") - return {} + return party_vecs finally: try: con.close() @@ -485,75 +454,52 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: pass -@st.cache_data(show_spinner="Partij-MP vectoren laden…") -def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]: - """Return per-party lists of individual MP SVD vectors. +@st.cache_data(show_spinner="Partijposities op SVD-assen laden…") +def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: + """Return per-party SVD vectors, computed as mean of individual MP vectors. - Same MP→party mapping as load_party_axis_scores(), but returns the raw - per-MP vectors instead of averaging them. Suitable for bootstrap CI - computation. + Loads individual MP rows from window='current_parliament', assigns each MP + their party, then averages SVD vectors per party. Returns: - {party_name: [np.ndarray(50,), ...]} — one array per MP. + {party_name: [float * k]} — k = 50, mean over all MPs in that party. """ try: - con = duckdb.connect(database=db_path, read_only=True) + party_vecs = _load_mp_vectors_by_party(db_path) + return { + party: np.array(vecs).mean(axis=0).tolist() + for party, vecs in party_vecs.items() + } + except Exception: + logger.exception("Failed to load party axis scores") + return {} - # Build mp → party mapping (same logic as load_party_axis_scores) - meta_ordered = con.execute( - "SELECT mp_name, party FROM mp_metadata " - "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' " - "ORDER BY van ASC" - ).fetchall() - mp_party: Dict[str, str] = {} - for mp_name, party in meta_ordered: - if mp_name and party: - mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party) - # Individual MP vectors from current_parliament - rows = con.execute( - "SELECT entity_id, vector FROM svd_vectors " - "WHERE entity_type='mp' AND window_id='current_parliament'" - ).fetchall() +@st.cache_data(show_spinner="Partij-MP vectoren laden…") +def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]: + """Return per-party lists of individual MP SVD vectors. - party_vecs: Dict[str, List[np.ndarray]] = {} - for entity_id, raw_vec in rows: - party = mp_party.get(entity_id) - if party is None or party not in CURRENT_PARLIAMENT_PARTIES: - continue - if isinstance(raw_vec, str): - vec = json.loads(raw_vec) - elif isinstance(raw_vec, (bytes, bytearray)): - vec = json.loads(raw_vec.decode()) - elif isinstance(raw_vec, list): - vec = raw_vec - else: - try: - vec = list(raw_vec) - except Exception: - continue - fvec = np.array([float(v) if v is not None else 0.0 for v in vec]) - party_vecs.setdefault(party, []).append(fvec) + Same MP→party mapping as load_party_axis_scores(), suitable for bootstrap + CI computation. - return party_vecs + Returns: + {party_name: [np.ndarray(50,), ...]} — one array per MP. + """ + try: + return _load_mp_vectors_by_party(db_path) except Exception: logger.exception("Failed to load party MP vectors") return {} - finally: - try: - con.close() - except Exception: - pass @st.cache_data(show_spinner="Bootstrap CI berekenen…") def _cached_bootstrap_cis( - _party_mp_vectors: Dict[str, List[np.ndarray]], + party_mp_vectors: Dict[str, List[np.ndarray]], ) -> Dict[str, Dict]: """Thin caching wrapper around compute_party_bootstrap_cis.""" from analysis.political_axis import compute_party_bootstrap_cis - return compute_party_bootstrap_cis(_party_mp_vectors) + return compute_party_bootstrap_cis(party_mp_vectors) @st.cache_data(show_spinner="Scree-plot laden…")