You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
563 lines
18 KiB
563 lines
18 KiB
"""Data loading functions for the parliamentary explorer.
|
|
|
|
This module contains all data loading functions extracted from explorer.py.
|
|
It is intentionally free of Streamlit side-effects to be easy to unit test.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Dict, List, Set, Tuple
|
|
|
|
import duckdb
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from analysis.config import CURRENT_PARLIAMENT_PARTIES, _PARTY_NORMALIZE
|
|
|
|
# Public API of this module. NOTE: "load_scree_data" was previously defined
# below but omitted here, making it invisible to `from ... import *`.
__all__ = [
    "get_available_windows",
    "get_uniform_dim_windows",
    "load_party_map",
    "load_active_mps",
    "load_mp_vectors_by_window",
    "load_mp_vectors_by_party",
    "load_mp_vectors_by_party_for_window",
    "load_party_axis_scores",
    "load_party_axis_scores_for_window",
    "load_party_scores_all_windows",
    "load_party_scores_all_windows_aligned",
    "load_party_mp_vectors",
    "load_scree_data",
    "build_window_party_scores",
    "load_motions_df",
    "query_similar",
    "compute_party_axis_scores",
]
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_WINDOW_SQL = """
|
|
SELECT DISTINCT window_id FROM svd_vectors ORDER BY window_id
|
|
"""
|
|
|
|
_UNIFORM_DIM_SQL = """
|
|
WITH vec_dims AS (
|
|
SELECT window_id, json_array_length(vector) AS dim
|
|
FROM svd_vectors
|
|
WHERE entity_type = 'mp'
|
|
),
|
|
window_dim_counts AS (
|
|
SELECT window_id, dim, COUNT(*) AS cnt
|
|
FROM vec_dims
|
|
GROUP BY window_id, dim
|
|
),
|
|
dominant AS (
|
|
SELECT DISTINCT ON (window_id) window_id, dim, cnt
|
|
FROM window_dim_counts
|
|
ORDER BY window_id, cnt DESC, dim DESC
|
|
)
|
|
SELECT window_id
|
|
FROM dominant
|
|
WHERE dim >= 25 AND cnt >= 10
|
|
ORDER BY window_id
|
|
"""
|
|
|
|
|
|
def get_available_windows(db_path: str) -> List[str]:
    """Return the sorted list of distinct window_ids found in svd_vectors.

    Returns an empty list (and logs the traceback) when the query fails.
    """
    con = duckdb.connect(database=db_path, read_only=True)
    try:
        fetched = con.execute(_WINDOW_SQL).fetchall()
    except Exception:
        logger.exception("Failed to query available windows")
        return []
    finally:
        con.close()
    return [record[0] for record in fetched]
|
|
|
|
|
|
def get_uniform_dim_windows(db_path: str) -> List[str]:
    """Return only windows whose dominant MP-vector dimension is >= 25.

    Some windows contain a mix of vector lengths due to multiple pipeline
    runs (e.g. 2016 has both dim=1 and dim=50 rows). The most common
    dimension per window is determined and only windows whose dominant
    dim >= 25 are kept. Windows with fewer than 10 dim-25+ entities are
    also excluded to avoid degenerate PCA inputs.
    """
    con = duckdb.connect(database=db_path, read_only=True)
    try:
        fetched = con.execute(_UNIFORM_DIM_SQL).fetchall()
    except Exception:
        logger.exception("Failed to query uniform-dim windows")
        return []
    finally:
        con.close()
    return [record[0] for record in fetched]
|
|
|
|
|
|
def load_party_map(db_path: str) -> Dict[str, str]:
    """Return {mp_name: party} mapping, with party names normalised to abbreviations.

    Rows with a missing name or party are skipped. On any database error an
    empty mapping is returned (best-effort, with the traceback logged).
    """
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        rows = con.execute(
            "SELECT mp_name, party FROM mp_metadata WHERE party IS NOT NULL"
        ).fetchall()
        return {
            mp: _PARTY_NORMALIZE.get(party, party)
            for mp, party in rows
            if mp and party
        }
    except Exception:
        logger.exception("Failed to load party map")
        return {}
    finally:
        # BUG FIX: previously the connection leaked when execute() raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_active_mps(db_path: str) -> Set[str]:
    """Return the set of mp_name values that are currently seated in parliament.

    An MP is considered active if their mp_metadata row has tot_en_met IS NULL,
    meaning they have no recorded end date for their current seat. Returns an
    empty set on any database error (logged).
    """
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        rows = con.execute(
            "SELECT mp_name FROM mp_metadata WHERE tot_en_met IS NULL"
        ).fetchall()
        return {r[0] for r in rows if r[0]}
    except Exception:
        logger.exception("Failed to load active MPs")
        return set()
    finally:
        # BUG FIX: previously the connection leaked when execute() raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
    """Return party scores for all windows (non-aligned).

    Returns:
        {party_abbrev: [x1, y1, x2, y2, ...]} — a flat list of (x, y) axis
        scores in window order. Windows where either coordinate is NULL are
        skipped, so lists may be shorter than the total window count.
    """
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        rows = con.execute(
            """
            SELECT party_abbrev, window_id, x_axis, y_axis
            FROM party_axis_scores
            ORDER BY party_abbrev, window_id
            """
        ).fetchall()

        scores: Dict[str, List[float]] = {}
        for party, _window, x, y in rows:
            bucket = scores.setdefault(party, [])
            if x is not None and y is not None:
                bucket.extend([x, y])
        return scores
    except Exception:
        logger.exception("Failed to load party axis scores")
        return {}
    finally:
        # BUG FIX: previously the connection leaked when a query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_party_axis_scores_for_window(
    db_path: str, window: str
) -> Dict[str, List[float]]:
    """Return party scores for a specific window (aligned).

    Args:
        db_path: Path to the DuckDB database.
        window: Window ID to filter on.

    Returns:
        {party_abbrev: [x, y]} with NULL coordinates coerced to 0.0;
        empty dict on any database error (logged).
    """
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        rows = con.execute(
            """
            SELECT party_abbrev, x_axis, y_axis
            FROM party_axis_scores
            WHERE window_id = ?
            ORDER BY party_abbrev
            """,
            [window],
        ).fetchall()
        return {party: [x or 0.0, y or 0.0] for party, x, y in rows}
    except Exception:
        logger.exception("Failed to load party axis scores for window %s", window)
        return {}
    finally:
        # BUG FIX: previously the connection leaked when the query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_party_scores_all_windows(db_path: str) -> Dict[str, List[List[float]]]:
    """Return party scores across all windows (non-aligned).

    Returns:
        {party_abbrev: [[x, y], ...]} — one [x, y] pair per window in
        window order; windows with a NULL coordinate yield [0.0, 0.0].
    """
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        rows = con.execute(
            """
            SELECT party_abbrev, window_id, x_axis, y_axis
            FROM party_axis_scores
            ORDER BY party_abbrev, window_id
            """
        ).fetchall()

        scores: Dict[str, List[List[float]]] = {}
        for party, _window, x, y in rows:
            pair = [x, y] if x is not None and y is not None else [0.0, 0.0]
            scores.setdefault(party, []).append(pair)
        return scores
    except Exception:
        logger.exception("Failed to load party scores all windows")
        return {}
    finally:
        # BUG FIX: previously the connection leaked when a query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_party_scores_all_windows_aligned(
    db_path: str,
) -> Dict[str, List[List[float]]]:
    """Return party scores across all windows (Procrustes-aligned).

    Returns:
        {party_abbrev: [[x, y], ...]} — one aligned [x, y] pair per window
        in window order; windows with a NULL coordinate yield [0.0, 0.0].
    """
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        rows = con.execute(
            """
            SELECT party_abbrev, window_id, x_axis_aligned, y_axis_aligned
            FROM party_axis_scores
            ORDER BY party_abbrev, window_id
            """
        ).fetchall()

        scores: Dict[str, List[List[float]]] = {}
        for party, _window, x, y in rows:
            pair = [x, y] if x is not None and y is not None else [0.0, 0.0]
            scores.setdefault(party, []).append(pair)
        return scores
    except Exception:
        logger.exception("Failed to load aligned party scores all windows")
        return {}
    finally:
        # BUG FIX: previously the connection leaked when a query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def build_window_party_scores(
    scores_by_party: Dict[str, List[List[float]]],
    window_idx: int,
) -> Dict[str, List[float]]:
    """Pick out a single window's scores as {party: [x, y]} for compute_flip_direction.

    Args:
        scores_by_party: Mapping produced by
            load_party_scores_all_windows_aligned —
            {party: [[x, y], [x, y], ...]} per window.
        window_idx: Zero-based index of the window to extract.

    Returns:
        {party: [x, y]} for that window. Parties whose score list is too
        short are omitted; a negative index yields an empty dict.
    """
    if window_idx < 0:
        return {}
    return {
        party: per_window[window_idx]
        for party, per_window in scores_by_party.items()
        if window_idx < len(per_window)
    }
|
|
|
|
|
|
def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]:
    """Load individual MP SVD vectors grouped by party.

    MP-to-party resolution uses mp_metadata rows overlapping the current
    parliament (seated on or after 2023-11-22, or with no end date).

    Returns:
        {party_name: [np.ndarray, ...]} — one array per MP; empty dict on error.
    """
    import json as _json

    con = duckdb.connect(database=db_path, read_only=True)
    try:
        meta_rows = con.execute(
            "SELECT mp_name, party FROM mp_metadata "
            "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
            "ORDER BY van ASC"
        ).fetchall()
        mp_party: Dict[str, str] = {}
        for mp_name, party in meta_rows:
            if mp_name and party:
                mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)

        rows = con.execute(
            "SELECT entity_id, vector FROM svd_vectors "
            "WHERE entity_type = 'mp' AND window_id = 'current_parliament'"
        ).fetchall()

        vectors_by_party: Dict[str, List[np.ndarray]] = {}
        for entity_id, raw_vec in rows:
            if entity_id not in mp_party:
                continue
            # BUG FIX: the vector column may come back as JSON text/bytes;
            # np.array() on a raw string would produce a useless 0-d string
            # array. Coerce defensively, like load_mp_vectors_by_window does.
            if isinstance(raw_vec, str):
                vec = _json.loads(raw_vec)
            elif isinstance(raw_vec, (bytes, bytearray)):
                vec = _json.loads(raw_vec.decode())
            elif isinstance(raw_vec, list):
                vec = raw_vec
            else:
                try:
                    vec = list(raw_vec)
                except Exception:
                    continue
            party = mp_party[entity_id]
            vectors_by_party.setdefault(party, []).append(np.array(vec))

        return vectors_by_party
    except Exception:
        logger.exception("Failed to load party MP vectors")
        return {}
    finally:
        con.close()
|
|
|
|
|
|
def load_scree_data(db_path: str) -> List[float]:
    """Load scree plot data (explained variance) for current_parliament.

    Returns:
        The list parsed from the singular_values row's sv_metadata JSON,
        or [] when the row is absent or any error occurs (logged).
    """
    import json as _json

    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        row = con.execute(
            """
            SELECT sv_metadata FROM svd_vectors
            WHERE window_id = 'current_parliament' AND entity_type = 'singular_values'
            LIMIT 1
            """
        ).fetchone()
        if row and row[0]:
            return _json.loads(row[0])
        return []
    except Exception:
        logger.exception("Failed to load scree data")
        return []
    finally:
        # BUG FIX: previously the connection leaked when the query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_motions_df(db_path: str) -> pd.DataFrame:
    """Load the full motions table as a pandas DataFrame (read-only).

    Adds two derived columns: "date" parsed to datetime (unparseable values
    become NaT) and "year" extracted from it. Returns an empty DataFrame on
    any error (logged).
    """
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        df = con.execute(
            """
            SELECT id, title, description, date, policy_area,
                   voting_results, layman_explanation,
                   winning_margin, controversy_score, url
            FROM motions
            """
        ).fetchdf()
        df["date"] = pd.to_datetime(df["date"], errors="coerce")
        df["year"] = df["date"].dt.year
        return df
    except Exception:
        logger.exception("Failed to load motions DataFrame")
        return pd.DataFrame()
    finally:
        # BUG FIX: previously the connection leaked when the query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_mp_vectors_by_window(db_path: str, window: str) -> Dict[str, np.ndarray]:
    """Load individual MP SVD vectors for a specific window.

    Args:
        db_path: Path to DuckDB database.
        window: Window ID (e.g., "2015", "current_parliament").

    Returns:
        {mp_name: np.ndarray} — one float vector per MP. Rows whose vector
        cannot be coerced to a sequence are skipped; NULL elements become 0.0.
    """
    import json as _json

    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        rows = con.execute(
            """
            SELECT entity_id, vector FROM svd_vectors
            WHERE entity_type = 'mp' AND window_id = ?
            """,
            [window],
        ).fetchall()

        mp_vecs: Dict[str, np.ndarray] = {}
        for entity_id, raw_vec in rows:
            # Vectors may be stored as JSON text, bytes, or a native list,
            # depending on how the pipeline wrote them — coerce defensively.
            if isinstance(raw_vec, str):
                vec = _json.loads(raw_vec)
            elif isinstance(raw_vec, (bytes, bytearray)):
                vec = _json.loads(raw_vec.decode())
            elif isinstance(raw_vec, list):
                vec = raw_vec
            else:
                try:
                    vec = list(raw_vec)
                except Exception:
                    continue
            mp_vecs[entity_id] = np.array(
                [float(v) if v is not None else 0.0 for v in vec]
            )
        return mp_vecs
    except Exception:
        logger.exception("Failed to load MP vectors for window %s", window)
        return {}
    finally:
        # BUG FIX: previously the connection leaked when a query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def query_similar(
    db_path: str,
    source_motion_id: int,
    vector_type: str = "fused",
    top_k: int = 10,
) -> pd.DataFrame:
    """Return top-k similar motions from similarity_cache (read-only).

    Args:
        db_path: Path to DuckDB database.
        source_motion_id: Motion to find neighbours for.
        vector_type: Which similarity space to query (default "fused").
        top_k: Maximum number of rows returned.

    Returns:
        DataFrame ordered by descending score; empty DataFrame on error.
    """
    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        return con.execute(
            """
            SELECT sc.target_motion_id, sc.score, sc.window_id,
                   m.title, m.date, m.policy_area
            FROM similarity_cache sc
            JOIN motions m ON m.id = sc.target_motion_id
            WHERE sc.source_motion_id = ?
              AND sc.vector_type = ?
            ORDER BY sc.score DESC
            LIMIT ?
            """,
            [source_motion_id, vector_type, top_k],
        ).fetchdf()
    except Exception:
        logger.exception(
            "Failed to query similarity cache for motion %s", source_motion_id
        )
        return pd.DataFrame()
    finally:
        # BUG FIX: previously the connection leaked when the query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_mp_vectors_by_party(db_path: str) -> Dict[str, List[np.ndarray]]:
    """Load individual MP SVD vectors grouped by party for current_parliament.

    Only MPs whose normalised party is in CURRENT_PARLIAMENT_PARTIES are
    included. MP-to-party resolution uses mp_metadata rows overlapping the
    current parliament (seated on or after 2023-11-22, or still seated).

    Returns:
        {party_name: [np.ndarray, ...]} — one array per MP; empty dict on error.
    """
    import json as _json

    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        meta_rows = con.execute(
            "SELECT mp_name, party FROM mp_metadata "
            "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
            "ORDER BY van ASC"
        ).fetchall()
        mp_party: Dict[str, str] = {}
        for mp_name, party in meta_rows:
            if mp_name and party:
                mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)

        rows = con.execute(
            "SELECT entity_id, vector FROM svd_vectors "
            "WHERE entity_type='mp' AND window_id='current_parliament'"
        ).fetchall()

        party_vecs: Dict[str, List[np.ndarray]] = {}
        for entity_id, raw_vec in rows:
            party = mp_party.get(entity_id)
            if party is None or party not in CURRENT_PARLIAMENT_PARTIES:
                continue
            # Vectors may arrive as JSON text, bytes, or a native list.
            if isinstance(raw_vec, str):
                vec = _json.loads(raw_vec)
            elif isinstance(raw_vec, (bytes, bytearray)):
                vec = _json.loads(raw_vec.decode())
            elif isinstance(raw_vec, list):
                vec = raw_vec
            else:
                try:
                    vec = list(raw_vec)
                except Exception:
                    continue
            fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
            party_vecs.setdefault(party, []).append(fvec)
        return party_vecs
    except Exception:
        logger.exception("Failed to load MP vectors by party")
        return {}
    finally:
        # BUG FIX: previously the connection leaked when a query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def load_mp_vectors_by_party_for_window(
    db_path: str, window: str
) -> Dict[str, List[np.ndarray]]:
    """Load individual MP SVD vectors grouped by party for a specific window.

    For historical windows, uses the MP-to-party mapping from that time
    period (rows whose [van, tot_en_met] range overlaps the window's year).

    Args:
        db_path: Path to DuckDB database.
        window: Window ID ("current_parliament" or a year-prefixed id).

    Returns:
        {party_name: [np.ndarray, ...]} — one array per MP; empty dict on error.
    """
    import json as _json

    con = None
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        is_current = window == "current_parliament"

        if is_current:
            meta_rows = con.execute(
                "SELECT mp_name, party FROM mp_metadata "
                "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
                "ORDER BY van ASC"
            ).fetchall()
        else:
            # Window ids are expected to start with a year; fall back to
            # 2023 when the prefix is not numeric.
            try:
                year = int(window.split("-")[0])
            except ValueError:
                year = 2023
            meta_rows = con.execute(
                "SELECT mp_name, party FROM mp_metadata "
                "WHERE van <= ? AND (tot_en_met IS NULL OR tot_en_met >= ?) "
                "ORDER BY van ASC",
                [f"{year}-12-31", f"{year}-01-01"],
            ).fetchall()

        mp_party: Dict[str, str] = {}
        for mp_name, party in meta_rows:
            if mp_name and party:
                mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)

        rows = con.execute(
            "SELECT entity_id, vector FROM svd_vectors "
            "WHERE entity_type='mp' AND window_id=?",
            [window],
        ).fetchall()

        party_vecs: Dict[str, List[np.ndarray]] = {}
        for entity_id, raw_vec in rows:
            party = mp_party.get(entity_id)
            if party is None:
                continue
            # For the current parliament only, restrict to seated parties.
            if is_current and party not in CURRENT_PARLIAMENT_PARTIES:
                continue
            # Vectors may arrive as JSON text, bytes, or a native list.
            if isinstance(raw_vec, str):
                vec = _json.loads(raw_vec)
            elif isinstance(raw_vec, (bytes, bytearray)):
                vec = _json.loads(raw_vec.decode())
            elif isinstance(raw_vec, list):
                vec = raw_vec
            else:
                try:
                    vec = list(raw_vec)
                except Exception:
                    continue
            fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
            party_vecs.setdefault(party, []).append(fvec)
        return party_vecs
    except Exception:
        logger.exception("Failed to load MP vectors by party for window %s", window)
        return {}
    finally:
        # BUG FIX: previously the connection leaked when a query raised.
        if con is not None:
            con.close()
|
|
|
|
|
|
def compute_party_axis_scores(
    party_vecs: Dict[str, List[np.ndarray]],
) -> Dict[str, List[float]]:
    """Average each party's MP vectors into a single axis-score vector.

    Args:
        party_vecs: {party_name: [np.ndarray, ...]} — MP vectors per party.

    Returns:
        {party_name: [float, ...]} — the element-wise mean over a party's
        MPs; empty dict if anything goes wrong (e.g. ragged vector lengths).
    """
    try:
        averaged: Dict[str, List[float]] = {}
        for party, vecs in party_vecs.items():
            averaged[party] = np.asarray(vecs).mean(axis=0).tolist()
        return averaged
    except Exception:
        logger.exception("Failed to compute party axis scores")
        return {}
|
|
|