"""Data loading functions for the parliamentary explorer. This module contains all data loading functions extracted from explorer.py. It is intentionally free of Streamlit side-effects to be easy to unit test. """ from __future__ import annotations import logging from typing import Dict, List, Set, Tuple try: import duckdb except ( Exception ): # pragma: no cover - allow lightweight import without duckdb installed duckdb = None # type: ignore import numpy as np import pandas as pd from analysis.config import CURRENT_PARLIAMENT_PARTIES, _PARTY_NORMALIZE __all__ = [ "get_available_windows", "get_uniform_dim_windows", "load_party_map", "load_active_mps", "load_mp_vectors_by_window", "load_mp_vectors_by_party", "load_mp_vectors_by_party_for_window", "load_party_axis_scores", "load_party_axis_scores_for_window", "load_party_scores_all_windows", "load_party_scores_all_windows_aligned", "load_party_mp_vectors", "build_window_party_scores", "load_motions_df", "query_similar", "compute_party_axis_scores", ] logger = logging.getLogger(__name__) _WINDOW_SQL = """ SELECT DISTINCT window_id FROM svd_vectors ORDER BY window_id """ _UNIFORM_DIM_SQL = """ WITH vec_dims AS ( SELECT window_id, json_array_length(vector) AS dim FROM svd_vectors WHERE entity_type = 'mp' ), window_dim_counts AS ( SELECT window_id, dim, COUNT(*) AS cnt FROM vec_dims GROUP BY window_id, dim ), dominant AS ( SELECT DISTINCT ON (window_id) window_id, dim, cnt FROM window_dim_counts ORDER BY window_id, cnt DESC, dim DESC ) SELECT window_id FROM dominant WHERE dim >= 25 AND cnt >= 10 AND window_id NOT LIKE '%-Q%' ORDER BY window_id """ def get_available_windows(db_path: str) -> List[str]: """Return sorted list of distinct window_ids from svd_vectors.""" con = duckdb.connect(database=db_path, read_only=True) try: rows = con.execute(_WINDOW_SQL).fetchall() return [r[0] for r in rows] except Exception: logger.exception("Failed to query available windows") return [] finally: con.close() def get_uniform_dim_windows(db_path: str) -> List[str]: """Return only windows whose dominant MP-vector dimension is >= 25. Some windows contain a mix of vector lengths due to multiple pipeline runs (e.g. 2016 has both dim=1 and dim=50 rows). We find the most common dimension per window and include only windows where that dominant dim >= 25. Windows with too few dim-25+ entities (< 10) are also excluded to avoid degenerate PCA inputs. """ con = duckdb.connect(database=db_path, read_only=True) try: rows = con.execute(_UNIFORM_DIM_SQL).fetchall() return [r[0] for r in rows] except Exception: logger.exception("Failed to query uniform-dim windows") return [] finally: con.close() def load_party_map(db_path: str) -> Dict[str, str]: """Return {mp_name: party} mapping, with party names normalised to abbreviations.""" try: con = duckdb.connect(database=db_path, read_only=True) rows = con.execute( "SELECT mp_name, party FROM mp_metadata WHERE party IS NOT NULL" ).fetchall() con.close() return { mp: _PARTY_NORMALIZE.get(party, party) for mp, party in rows if mp and party } except Exception: logger.exception("Failed to load party map") return {} def load_active_mps(db_path: str) -> Set[str]: """Return the set of mp_name values that are currently seated in parliament. An MP is considered active if their mp_metadata row has tot_en_met IS NULL, meaning they have no recorded end date for their current seat. """ try: con = duckdb.connect(database=db_path, read_only=True) rows = con.execute( "SELECT mp_name FROM mp_metadata WHERE tot_en_met IS NULL" ).fetchall() con.close() return {r[0] for r in rows if r[0]} except Exception: logger.exception("Failed to load active MPs") return set() def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]: """Return party scores for all windows (non-aligned). Returns dict mapping party_abbrev -> list of axis scores, one per window. """ try: con = duckdb.connect(database=db_path, read_only=True) rows = con.execute( """ SELECT party_abbrev, window_id, x_axis, y_axis FROM party_axis_scores ORDER BY party_abbrev, window_id """ ).fetchall() con.close() scores: Dict[str, List[float]] = {} for party, window, x, y in rows: if party not in scores: scores[party] = [] if x is not None and y is not None: scores[party].extend([x, y]) return scores except Exception: logger.exception("Failed to load party axis scores") return {} def load_party_axis_scores_for_window( db_path: str, window: str ) -> Dict[str, List[float]]: """Return party scores for a specific window (aligned).""" try: con = duckdb.connect(database=db_path, read_only=True) rows = con.execute( """ SELECT party_abbrev, x_axis, y_axis FROM party_axis_scores WHERE window_id = ? ORDER BY party_abbrev """, [window], ).fetchall() con.close() return {party: [x or 0.0, y or 0.0] for party, x, y in rows} except Exception: logger.exception("Failed to load party axis scores for window %s", window) return {} def load_party_scores_all_windows(db_path: str) -> Dict[str, List[List[float]]]: """Return party scores across all windows (non-aligned).""" try: con = duckdb.connect(database=db_path, read_only=True) rows = con.execute( """ SELECT party_abbrev, window_id, x_axis, y_axis FROM party_axis_scores ORDER BY party_abbrev, window_id """ ).fetchall() con.close() scores: Dict[str, List[List[float]]] = {} current_party = None for party, window, x, y in rows: if party != current_party: scores[party] = [] current_party = party if x is not None and y is not None: scores[party].append([x, y]) else: scores[party].append([0.0, 0.0]) return scores except Exception: logger.exception("Failed to load party scores all windows") return {} def load_party_scores_all_windows_aligned( db_path: str, ) -> Dict[str, List[List[float]]]: """Return party scores across all windows (Procrustes-aligned).""" try: con = duckdb.connect(database=db_path, read_only=True) rows = con.execute( """ SELECT party_abbrev, window_id, x_axis_aligned, y_axis_aligned FROM party_axis_scores ORDER BY party_abbrev, window_id """ ).fetchall() con.close() scores: Dict[str, List[List[float]]] = {} current_party = None for party, window, x, y in rows: if party != current_party: scores[party] = [] current_party = party if x is not None and y is not None: scores[party].append([x, y]) else: scores[party].append([0.0, 0.0]) return scores except Exception: logger.exception("Failed to load aligned party scores all windows") return {} def build_window_party_scores( scores_by_party: Dict[str, List[List[float]]], window_idx: int, ) -> Dict[str, List[float]]: """Extract scores for one window as {party: [x, y]} for compute_flip_direction. Args: scores_by_party: Output of load_party_scores_all_windows_aligned — {party: [[x, y], [x, y], ...]} per window. window_idx: Zero-based index of the window to extract. Returns: {party: [x, y]} for the given window. Returns empty dict if window_idx is out of range. """ if window_idx < 0: return {} result: Dict[str, List[float]] = {} for party, window_scores in scores_by_party.items(): if window_idx < len(window_scores): result[party] = window_scores[window_idx] return result def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]: """Load individual MP SVD vectors grouped by party. Returns {party_name: [np.ndarray(50,), ...]} — one array per MP. """ con = duckdb.connect(database=db_path, read_only=True) try: meta_rows = con.execute( "SELECT mp_name, party FROM mp_metadata " "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' " "ORDER BY van ASC" ).fetchall() mp_party: Dict[str, str] = {} for mp_name, party in meta_rows: if mp_name and party: mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party) rows = con.execute( "SELECT entity_id, vector FROM svd_vectors " "WHERE entity_type = 'mp' AND window_id = 'current_parliament'" ).fetchall() vectors_by_party: Dict[str, List[np.ndarray]] = {} for entity_id, vector_json in rows: if entity_id in mp_party: party = mp_party[entity_id] if party not in vectors_by_party: vectors_by_party[party] = [] vectors_by_party[party].append(np.array(vector_json)) return vectors_by_party except Exception: logger.exception("Failed to load party MP vectors") return {} finally: con.close() def load_scree_data(db_path: str) -> List[float]: """Load scree plot data (explained variance) for current_parliament.""" try: con = duckdb.connect(database=db_path, read_only=True) row = con.execute( """ SELECT sv_metadata FROM svd_vectors WHERE window_id = 'current_parliament' AND entity_type = 'singular_values' LIMIT 1 """ ).fetchone() con.close() if row and row[0]: import json return json.loads(row[0]) return [] except Exception: logger.exception("Failed to load scree data") return [] def load_motions_df(db_path: str) -> pd.DataFrame: """Load the full motions table as a pandas DataFrame (read-only).""" try: con = duckdb.connect(database=db_path, read_only=True) df = con.execute( """ SELECT id, title, description, date, policy_area, voting_results, layman_explanation, winning_margin, controversy_score, url FROM motions """ ).fetchdf() con.close() df["date"] = pd.to_datetime(df["date"], errors="coerce") df["year"] = df["date"].dt.year return df except Exception: logger.exception("Failed to load motions DataFrame") return pd.DataFrame() def load_mp_vectors_by_window(db_path: str, window: str) -> Dict[str, np.ndarray]: """Load individual MP SVD vectors for a specific window. Args: db_path: Path to DuckDB database window: Window ID (e.g., "2015", "current_parliament") Returns: {mp_name: np.ndarray(50,)} — one vector per MP """ import json as _json try: con = duckdb.connect(database=db_path, read_only=True) rows = con.execute( """ SELECT entity_id, vector FROM svd_vectors WHERE entity_type = 'mp' AND window_id = ? """, [window], ).fetchall() con.close() mp_vecs: Dict[str, np.ndarray] = {} for entity_id, raw_vec in rows: if isinstance(raw_vec, str): vec = _json.loads(raw_vec) elif isinstance(raw_vec, (bytes, bytearray)): vec = _json.loads(raw_vec.decode()) elif isinstance(raw_vec, list): vec = raw_vec else: try: vec = list(raw_vec) except Exception: continue fvec = np.array([float(v) if v is not None else 0.0 for v in vec]) mp_vecs[entity_id] = fvec return mp_vecs except Exception: logger.exception("Failed to load MP vectors for window %s", window) return {} def query_similar( db_path: str, source_motion_id: int, vector_type: str = "fused", top_k: int = 10, ) -> pd.DataFrame: """Return top-k similar motions from similarity_cache (read-only).""" try: con = duckdb.connect(database=db_path, read_only=True) rows = con.execute( """ SELECT sc.target_motion_id, sc.score, sc.window_id, m.title, m.date, m.policy_area FROM similarity_cache sc JOIN motions m ON m.id = sc.target_motion_id WHERE sc.source_motion_id = ? AND sc.vector_type = ? ORDER BY sc.score DESC LIMIT ? """, [source_motion_id, vector_type, top_k], ).fetchdf() con.close() return rows except Exception: logger.exception( "Failed to query similarity cache for motion %s", source_motion_id ) return pd.DataFrame() def load_mp_vectors_by_party(db_path: str) -> Dict[str, List[np.ndarray]]: """Load individual MP SVD vectors grouped by party for current_parliament. Returns: {party_name: [np.ndarray(50,), ...]} — one array per MP. """ import json as _json try: con = duckdb.connect(database=db_path, read_only=True) meta_rows = con.execute( "SELECT mp_name, party FROM mp_metadata " "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' " "ORDER BY van ASC" ).fetchall() mp_party: Dict[str, str] = {} for mp_name, party in meta_rows: if mp_name and party: mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party) rows = con.execute( "SELECT entity_id, vector FROM svd_vectors " "WHERE entity_type='mp' AND window_id='current_parliament'" ).fetchall() con.close() party_vecs: Dict[str, List[np.ndarray]] = {} for entity_id, raw_vec in rows: party = mp_party.get(entity_id) if party is None or party not in CURRENT_PARLIAMENT_PARTIES: continue if isinstance(raw_vec, str): vec = _json.loads(raw_vec) elif isinstance(raw_vec, (bytes, bytearray)): vec = _json.loads(raw_vec.decode()) elif isinstance(raw_vec, list): vec = raw_vec else: try: vec = list(raw_vec) except Exception: continue fvec = np.array([float(v) if v is not None else 0.0 for v in vec]) party_vecs.setdefault(party, []).append(fvec) return party_vecs except Exception: logger.exception("Failed to load MP vectors by party") return {} def load_mp_vectors_by_party_for_window( db_path: str, window: str ) -> Dict[str, List[np.ndarray]]: """Load individual MP SVD vectors grouped by party for a specific window. For historical windows, uses the MP→party mapping from that time period. Returns: {party_name: [np.ndarray(50,), ...]} — one array per MP. """ import json as _json try: con = duckdb.connect(database=db_path, read_only=True) is_current = window == "current_parliament" if is_current: meta_rows = con.execute( "SELECT mp_name, party FROM mp_metadata " "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' " "ORDER BY van ASC" ).fetchall() else: try: year = int(window.split("-")[0]) except ValueError: year = 2023 meta_rows = con.execute( "SELECT mp_name, party FROM mp_metadata " "WHERE van <= ? AND (tot_en_met IS NULL OR tot_en_met >= ?) " "ORDER BY van ASC", [f"{year}-12-31", f"{year}-01-01"], ).fetchall() mp_party: Dict[str, str] = {} for mp_name, party in meta_rows: if mp_name and party: mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party) rows = con.execute( "SELECT entity_id, vector FROM svd_vectors " "WHERE entity_type='mp' AND window_id=?", [window], ).fetchall() con.close() party_vecs: Dict[str, List[np.ndarray]] = {} for entity_id, raw_vec in rows: party = mp_party.get(entity_id) if party is None: continue if is_current and party not in CURRENT_PARLIAMENT_PARTIES: continue if isinstance(raw_vec, str): vec = _json.loads(raw_vec) elif isinstance(raw_vec, (bytes, bytearray)): vec = _json.loads(raw_vec.decode()) elif isinstance(raw_vec, list): vec = raw_vec else: try: vec = list(raw_vec) except Exception: continue fvec = np.array([float(v) if v is not None else 0.0 for v in vec]) party_vecs.setdefault(party, []).append(fvec) return party_vecs except Exception: logger.exception("Failed to load MP vectors by party for window %s", window) return {} def compute_party_axis_scores( party_vecs: Dict[str, List[np.ndarray]], ) -> Dict[str, List[float]]: """Compute per-party axis scores as mean of MP vectors. Returns: {party_name: [float * k]} — k = 50, mean over all MPs in that party. """ try: return { party: np.array(vecs).mean(axis=0).tolist() for party, vecs in party_vecs.items() } except Exception: logger.exception("Failed to compute party axis scores") return {}