You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/analysis/explorer_data.py

568 lines
18 KiB

"""Data loading functions for the parliamentary explorer.
This module contains all data loading functions extracted from explorer.py.
It is intentionally free of Streamlit side-effects to be easy to unit test.
"""
from __future__ import annotations
import logging
from typing import Dict, List, Set, Tuple
try:
import duckdb
except (
Exception
): # pragma: no cover - allow lightweight import without duckdb installed
duckdb = None # type: ignore
import numpy as np
import pandas as pd
from analysis.config import CURRENT_PARLIAMENT_PARTIES, _PARTY_NORMALIZE
# Public API of this module: the names re-exported by a star-import.
# NOTE(review): ``load_scree_data`` is defined in this module but is not
# listed here — confirm whether that omission is intentional.
__all__ = [
"get_available_windows",
"get_uniform_dim_windows",
"load_party_map",
"load_active_mps",
"load_mp_vectors_by_window",
"load_mp_vectors_by_party",
"load_mp_vectors_by_party_for_window",
"load_party_axis_scores",
"load_party_axis_scores_for_window",
"load_party_scores_all_windows",
"load_party_scores_all_windows_aligned",
"load_party_mp_vectors",
"build_window_party_scores",
"load_motions_df",
"query_similar",
"compute_party_axis_scores",
]
# Module-level logger; the loaders below use it to record query failures.
logger = logging.getLogger(__name__)
# Query used by get_available_windows: every distinct window_id stored in
# svd_vectors, in ascending order.
_WINDOW_SQL = """
SELECT DISTINCT window_id FROM svd_vectors ORDER BY window_id
"""
# Query used by get_uniform_dim_windows. It works in three stages:
#   vec_dims          - the JSON array length of each 'mp' vector, per row
#   window_dim_counts - how many rows each (window_id, dim) pair has
#   dominant          - for each window, the dim with the highest row count
#                       (ties are broken in favour of the larger dim)
# The final SELECT keeps only windows whose dominant dim is at least 25 and
# is backed by at least 10 rows.
_UNIFORM_DIM_SQL = """
WITH vec_dims AS (
SELECT window_id, json_array_length(vector) AS dim
FROM svd_vectors
WHERE entity_type = 'mp'
),
window_dim_counts AS (
SELECT window_id, dim, COUNT(*) AS cnt
FROM vec_dims
GROUP BY window_id, dim
),
dominant AS (
SELECT DISTINCT ON (window_id) window_id, dim, cnt
FROM window_dim_counts
ORDER BY window_id, cnt DESC, dim DESC
)
SELECT window_id
FROM dominant
WHERE dim >= 25 AND cnt >= 10
ORDER BY window_id
"""
def get_available_windows(db_path: str) -> List[str]:
    """List every distinct window_id found in svd_vectors, in ascending order.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        The window ids in the order produced by the query, or an empty list
        when running the query fails (the failure is logged). An error while
        opening the database is not caught here and propagates to the caller.
    """
    connection = duckdb.connect(database=db_path, read_only=True)
    try:
        records = connection.execute(_WINDOW_SQL).fetchall()
        window_ids = [record[0] for record in records]
    except Exception:
        logger.exception("Failed to query available windows")
        window_ids = []
    finally:
        # Always release the handle, whether or not the query succeeded.
        connection.close()
    return window_ids
def get_uniform_dim_windows(db_path: str) -> List[str]:
    """List the windows whose dominant MP-vector dimension is at least 25.

    A window can hold vectors of several lengths when the pipeline ran more
    than once (e.g. 2016 has both dim=1 and dim=50 rows). The query picks the
    most common dimension per window and keeps the window only when that
    dimension is >= 25 and is backed by at least 10 rows, which avoids
    feeding degenerate inputs to PCA.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        The qualifying window ids in ascending order, or an empty list when
        running the query fails (the failure is logged). An error while
        opening the database is not caught here and propagates to the caller.
    """
    connection = duckdb.connect(database=db_path, read_only=True)
    try:
        records = connection.execute(_UNIFORM_DIM_SQL).fetchall()
        window_ids = [record[0] for record in records]
    except Exception:
        logger.exception("Failed to query uniform-dim windows")
        window_ids = []
    finally:
        # Always release the handle, whether or not the query succeeded.
        connection.close()
    return window_ids
def load_party_map(db_path: str) -> Dict[str, str]:
    """Return a {mp_name: party} mapping with party names normalised.

    Party values are passed through ``_PARTY_NORMALIZE``; names without an
    entry there are kept unchanged. Rows with an empty mp_name or party are
    skipped.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        The mapping, or an empty dict when connecting or querying fails
        (the failure is logged).
    """
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            rows = con.execute(
                "SELECT mp_name, party FROM mp_metadata WHERE party IS NOT NULL"
            ).fetchall()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        return {
            mp: _PARTY_NORMALIZE.get(party, party) for mp, party in rows if mp and party
        }
    except Exception:
        logger.exception("Failed to load party map")
        return {}
def load_active_mps(db_path: str) -> Set[str]:
    """Return the set of mp_name values that are currently seated.

    An MP is considered active if their mp_metadata row has
    ``tot_en_met IS NULL``, meaning no end date is recorded for the seat.
    Empty mp_name values are dropped.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        The set of active MP names, or an empty set when connecting or
        querying fails (the failure is logged).
    """
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            rows = con.execute(
                "SELECT mp_name FROM mp_metadata WHERE tot_en_met IS NULL"
            ).fetchall()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        return {name for (name,) in rows if name}
    except Exception:
        logger.exception("Failed to load active MPs")
        return set()
def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
    """Return each party's x/y axis scores over all windows as one flat list.

    Rows are read from party_axis_scores ordered by party and window. For
    every row whose x_axis and y_axis are both non-NULL, the pair is appended
    to that party's list, so the result has the shape
    ``{party: [x_w1, y_w1, x_w2, y_w2, ...]}``. Rows with a NULL coordinate
    contribute nothing, but the party still gets an entry (possibly empty).

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        The mapping described above, or an empty dict when connecting or
        querying fails (the failure is logged).
    """
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            rows = con.execute(
                """
SELECT party_abbrev, window_id, x_axis, y_axis
FROM party_axis_scores
ORDER BY party_abbrev, window_id
"""
            ).fetchall()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        scores: Dict[str, List[float]] = {}
        # window_id is selected only to fix the row order; it is not stored.
        for party, _window, x, y in rows:
            party_scores = scores.setdefault(party, [])
            if x is not None and y is not None:
                party_scores.extend([x, y])
        return scores
    except Exception:
        logger.exception("Failed to load party axis scores")
        return {}
def load_party_axis_scores_for_window(
    db_path: str, window: str
) -> Dict[str, List[float]]:
    """Return ``{party: [x, y]}`` for a single window.

    Reads the ``x_axis`` / ``y_axis`` columns of party_axis_scores for the
    given window_id. A NULL (or otherwise falsy) coordinate is replaced by
    0.0.

    NOTE(review): this function was previously documented as returning
    aligned scores, but it does not read the ``x_axis_aligned`` /
    ``y_axis_aligned`` columns — confirm which columns are intended.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.
        window: The window_id to filter on.

    Returns:
        The per-party coordinates, or an empty dict when connecting or
        querying fails (the failure is logged).
    """
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            rows = con.execute(
                """
SELECT party_abbrev, x_axis, y_axis
FROM party_axis_scores
WHERE window_id = ?
ORDER BY party_abbrev
""",
                [window],
            ).fetchall()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        return {party: [x or 0.0, y or 0.0] for party, x, y in rows}
    except Exception:
        logger.exception("Failed to load party axis scores for window %s", window)
        return {}
def load_party_scores_all_windows(db_path: str) -> Dict[str, List[List[float]]]:
    """Return party scores across all windows from the non-aligned columns.

    Rows are read from party_axis_scores ordered by party and window. Every
    row contributes one ``[x, y]`` pair to its party's list; when either
    coordinate is NULL the pair ``[0.0, 0.0]`` is stored instead, so a party's
    list has one entry per row returned for it.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        ``{party: [[x, y], [x, y], ...]}``, or an empty dict when connecting
        or querying fails (the failure is logged).
    """
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            rows = con.execute(
                """
SELECT party_abbrev, window_id, x_axis, y_axis
FROM party_axis_scores
ORDER BY party_abbrev, window_id
"""
            ).fetchall()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        scores: Dict[str, List[List[float]]] = {}
        # window_id is selected only to fix the row order; it is not stored.
        for party, _window, x, y in rows:
            # setdefault instead of a "previous party" sentinel: a sentinel of
            # None never initialises the list when the first row's
            # party_abbrev is NULL, and the resulting KeyError would discard
            # the scores of every party.
            party_scores = scores.setdefault(party, [])
            if x is not None and y is not None:
                party_scores.append([x, y])
            else:
                party_scores.append([0.0, 0.0])
        return scores
    except Exception:
        logger.exception("Failed to load party scores all windows")
        return {}
def load_party_scores_all_windows_aligned(
    db_path: str,
) -> Dict[str, List[List[float]]]:
    """Return party scores across all windows (Procrustes-aligned).

    Reads the ``x_axis_aligned`` / ``y_axis_aligned`` columns of
    party_axis_scores ordered by party and window. Every row contributes one
    ``[x, y]`` pair to its party's list; when either coordinate is NULL the
    pair ``[0.0, 0.0]`` is stored instead.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        ``{party: [[x, y], [x, y], ...]}``, or an empty dict when connecting
        or querying fails (the failure is logged).
    """
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            rows = con.execute(
                """
SELECT party_abbrev, window_id, x_axis_aligned, y_axis_aligned
FROM party_axis_scores
ORDER BY party_abbrev, window_id
"""
            ).fetchall()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        scores: Dict[str, List[List[float]]] = {}
        # window_id is selected only to fix the row order; it is not stored.
        for party, _window, x, y in rows:
            # setdefault instead of a "previous party" sentinel: a sentinel of
            # None never initialises the list when the first row's
            # party_abbrev is NULL, and the resulting KeyError would discard
            # the scores of every party.
            party_scores = scores.setdefault(party, [])
            if x is not None and y is not None:
                party_scores.append([x, y])
            else:
                party_scores.append([0.0, 0.0])
        return scores
    except Exception:
        logger.exception("Failed to load aligned party scores all windows")
        return {}
def build_window_party_scores(
    scores_by_party: Dict[str, List[List[float]]],
    window_idx: int,
) -> Dict[str, List[float]]:
    """Pick out one window's [x, y] pair for every party.

    Args:
        scores_by_party: Mapping of party to its per-window score pairs,
            i.e. ``{party: [[x, y], [x, y], ...]}``.
        window_idx: Zero-based position of the window to extract.

    Returns:
        ``{party: [x, y]}`` for the requested window. Parties whose list is
        too short to contain ``window_idx`` are left out, and a negative
        index yields an empty dict.
    """
    # Negative positions are rejected rather than counted from the end.
    if window_idx < 0:
        return {}
    return {
        name: per_window[window_idx]
        for name, per_window in scores_by_party.items()
        if len(per_window) > window_idx
    }
def load_party_mp_vectors(db_path: str) -> Dict[str, List[np.ndarray]]:
    """Load individual MP SVD vectors for 'current_parliament', grouped by party.

    The MP-to-party mapping comes from mp_metadata rows that satisfy
    ``van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22'``,
    ordered by ``van`` so that a later row for the same MP overrides an
    earlier one. Party names are normalised through ``_PARTY_NORMALIZE``.
    MPs without a known party are skipped.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        ``{party_name: [np.ndarray, ...]}`` with one float array per MP, or
        an empty dict when a query fails (the failure is logged). An error
        while opening the database is not caught here and propagates.
    """
    import json as _json

    con = duckdb.connect(database=db_path, read_only=True)
    try:
        meta_rows = con.execute(
            "SELECT mp_name, party FROM mp_metadata "
            "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
            "ORDER BY van ASC"
        ).fetchall()
        mp_party: Dict[str, str] = {}
        for mp_name, party in meta_rows:
            if mp_name and party:
                mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)
        rows = con.execute(
            "SELECT entity_id, vector FROM svd_vectors "
            "WHERE entity_type = 'mp' AND window_id = 'current_parliament'"
        ).fetchall()
        vectors_by_party: Dict[str, List[np.ndarray]] = {}
        for entity_id, raw_vec in rows:
            party = mp_party.get(entity_id)
            if party is None:
                continue
            # The vector may arrive as JSON text; np.array() on a raw string
            # would produce a 0-d string array instead of a numeric vector,
            # so decode it the same way the other vector loaders here do.
            if isinstance(raw_vec, str):
                vec = _json.loads(raw_vec)
            elif isinstance(raw_vec, (bytes, bytearray)):
                vec = _json.loads(raw_vec.decode())
            elif isinstance(raw_vec, list):
                vec = raw_vec
            else:
                try:
                    vec = list(raw_vec)
                except Exception:
                    # Not iterable (e.g. NULL): skip this MP.
                    continue
            fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
            vectors_by_party.setdefault(party, []).append(fvec)
        return vectors_by_party
    except Exception:
        logger.exception("Failed to load party MP vectors")
        return {}
    finally:
        con.close()
def load_scree_data(db_path: str) -> List[float]:
    """Load scree plot data (explained variance) for current_parliament.

    Reads ``sv_metadata`` from the first svd_vectors row with
    ``window_id = 'current_parliament'`` and
    ``entity_type = 'singular_values'`` and decodes it as JSON.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        The decoded JSON value, or an empty list when no such row exists,
        the value is empty, or connecting/querying/decoding fails (failures
        are logged).
    """
    import json

    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            row = con.execute(
                """
SELECT sv_metadata FROM svd_vectors
WHERE window_id = 'current_parliament' AND entity_type = 'singular_values'
LIMIT 1
"""
            ).fetchone()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        if row and row[0]:
            return json.loads(row[0])
        return []
    except Exception:
        logger.exception("Failed to load scree data")
        return []
def load_motions_df(db_path: str) -> pd.DataFrame:
    """Load the motions table as a pandas DataFrame (read-only).

    The ``date`` column is converted with ``pd.to_datetime`` (unparseable
    values become NaT) and a derived ``year`` column is added.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        The motions DataFrame, or an empty DataFrame when connecting,
        querying or post-processing fails (the failure is logged).
    """
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            df = con.execute(
                """
SELECT id, title, description, date, policy_area,
voting_results, layman_explanation,
winning_margin, controversy_score, url
FROM motions
"""
            ).fetchdf()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        df["date"] = pd.to_datetime(df["date"], errors="coerce")
        df["year"] = df["date"].dt.year
        return df
    except Exception:
        logger.exception("Failed to load motions DataFrame")
        return pd.DataFrame()
def load_mp_vectors_by_window(db_path: str, window: str) -> Dict[str, np.ndarray]:
    """Load individual MP SVD vectors for a specific window.

    Each stored vector is decoded from JSON text/bytes (or taken as-is when
    it is already a list), converted to floats with ``None`` entries mapped
    to 0.0, and keyed by its entity_id. Values that cannot be turned into a
    list are skipped.

    Args:
        db_path: Path to DuckDB database; it is opened read-only.
        window: Window ID (e.g., "2015", "current_parliament").

    Returns:
        ``{mp_name: np.ndarray}`` with one vector per MP, or an empty dict
        when connecting, querying or decoding fails (the failure is logged).
    """
    import json as _json

    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            rows = con.execute(
                """
SELECT entity_id, vector FROM svd_vectors
WHERE entity_type = 'mp' AND window_id = ?
""",
                [window],
            ).fetchall()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        mp_vecs: Dict[str, np.ndarray] = {}
        for entity_id, raw_vec in rows:
            if isinstance(raw_vec, str):
                vec = _json.loads(raw_vec)
            elif isinstance(raw_vec, (bytes, bytearray)):
                vec = _json.loads(raw_vec.decode())
            elif isinstance(raw_vec, list):
                vec = raw_vec
            else:
                try:
                    vec = list(raw_vec)
                except Exception:
                    # Not iterable (e.g. NULL): skip this MP.
                    continue
            mp_vecs[entity_id] = np.array(
                [float(v) if v is not None else 0.0 for v in vec]
            )
        return mp_vecs
    except Exception:
        logger.exception("Failed to load MP vectors for window %s", window)
        return {}
def query_similar(
    db_path: str,
    source_motion_id: int,
    vector_type: str = "fused",
    top_k: int = 10,
) -> pd.DataFrame:
    """Return top-k similar motions from similarity_cache (read-only).

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.
        source_motion_id: Motion whose cached neighbours are requested.
        vector_type: Value matched against ``similarity_cache.vector_type``.
        top_k: Maximum number of rows to return.

    Returns:
        A DataFrame with columns target_motion_id, score, window_id, title,
        date and policy_area, ordered by descending score; or an empty
        DataFrame when connecting or querying fails (the failure is logged).
    """
    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            similar_df = con.execute(
                """
SELECT sc.target_motion_id, sc.score, sc.window_id,
m.title, m.date, m.policy_area
FROM similarity_cache sc
JOIN motions m ON m.id = sc.target_motion_id
WHERE sc.source_motion_id = ?
AND sc.vector_type = ?
ORDER BY sc.score DESC
LIMIT ?
""",
                [source_motion_id, vector_type, top_k],
            ).fetchdf()
        finally:
            # Close even when the query raises so the connection is not leaked.
            con.close()
        return similar_df
    except Exception:
        logger.exception(
            "Failed to query similarity cache for motion %s", source_motion_id
        )
        return pd.DataFrame()
def load_mp_vectors_by_party(db_path: str) -> Dict[str, List[np.ndarray]]:
    """Load individual MP SVD vectors grouped by party for current_parliament.

    The MP-to-party mapping comes from mp_metadata rows that satisfy
    ``van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22'``,
    ordered by ``van`` so a later row for the same MP overrides an earlier
    one. Only parties listed in ``CURRENT_PARLIAMENT_PARTIES`` are kept.
    Vectors are decoded from JSON text/bytes (or used as-is when already a
    list) and converted to floats, with ``None`` entries mapped to 0.0.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.

    Returns:
        ``{party_name: [np.ndarray, ...]}`` with one array per MP, or an
        empty dict when connecting, querying or decoding fails (the failure
        is logged).
    """
    import json as _json

    try:
        con = duckdb.connect(database=db_path, read_only=True)
        try:
            meta_rows = con.execute(
                "SELECT mp_name, party FROM mp_metadata "
                "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
                "ORDER BY van ASC"
            ).fetchall()
            rows = con.execute(
                "SELECT entity_id, vector FROM svd_vectors "
                "WHERE entity_type='mp' AND window_id='current_parliament'"
            ).fetchall()
        finally:
            # Close even when a query raises so the connection is not leaked.
            con.close()
        mp_party: Dict[str, str] = {}
        for mp_name, party in meta_rows:
            if mp_name and party:
                mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)
        party_vecs: Dict[str, List[np.ndarray]] = {}
        for entity_id, raw_vec in rows:
            party = mp_party.get(entity_id)
            if party is None or party not in CURRENT_PARLIAMENT_PARTIES:
                continue
            if isinstance(raw_vec, str):
                vec = _json.loads(raw_vec)
            elif isinstance(raw_vec, (bytes, bytearray)):
                vec = _json.loads(raw_vec.decode())
            elif isinstance(raw_vec, list):
                vec = raw_vec
            else:
                try:
                    vec = list(raw_vec)
                except Exception:
                    # Not iterable (e.g. NULL): skip this MP.
                    continue
            fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
            party_vecs.setdefault(party, []).append(fvec)
        return party_vecs
    except Exception:
        logger.exception("Failed to load MP vectors by party")
        return {}
def load_mp_vectors_by_party_for_window(
    db_path: str, window: str
) -> Dict[str, List[np.ndarray]]:
    """Load individual MP SVD vectors grouped by party for a specific window.

    For ``"current_parliament"`` the MP-to-party mapping uses mp_metadata
    rows with ``van >= '2023-11-22' OR tot_en_met IS NULL OR
    tot_en_met >= '2023-11-22'`` and only parties listed in
    ``CURRENT_PARLIAMENT_PARTIES`` are kept. For any other window the year is
    parsed from the text before the first ``-`` in the window id (falling
    back to 2023 when that is not an integer) and the mapping uses rows
    whose seat overlaps that calendar year. In both cases rows are ordered
    by ``van`` so a later row for the same MP overrides an earlier one.

    Args:
        db_path: Path of the DuckDB database; it is opened read-only.
        window: Window ID whose MP vectors are loaded.

    Returns:
        ``{party_name: [np.ndarray, ...]}`` with one array per MP, or an
        empty dict when connecting, querying or decoding fails (the failure
        is logged).
    """
    import json as _json

    try:
        con = duckdb.connect(database=db_path, read_only=True)
        is_current = window == "current_parliament"
        try:
            if is_current:
                meta_rows = con.execute(
                    "SELECT mp_name, party FROM mp_metadata "
                    "WHERE van >= '2023-11-22' OR tot_en_met IS NULL OR tot_en_met >= '2023-11-22' "
                    "ORDER BY van ASC"
                ).fetchall()
            else:
                try:
                    year = int(window.split("-")[0])
                except ValueError:
                    # Window id does not start with a year: use the default.
                    year = 2023
                meta_rows = con.execute(
                    "SELECT mp_name, party FROM mp_metadata "
                    "WHERE van <= ? AND (tot_en_met IS NULL OR tot_en_met >= ?) "
                    "ORDER BY van ASC",
                    [f"{year}-12-31", f"{year}-01-01"],
                ).fetchall()
            rows = con.execute(
                "SELECT entity_id, vector FROM svd_vectors "
                "WHERE entity_type='mp' AND window_id=?",
                [window],
            ).fetchall()
        finally:
            # Close even when a query raises so the connection is not leaked.
            con.close()
        mp_party: Dict[str, str] = {}
        for mp_name, party in meta_rows:
            if mp_name and party:
                mp_party[mp_name] = _PARTY_NORMALIZE.get(party, party)
        party_vecs: Dict[str, List[np.ndarray]] = {}
        for entity_id, raw_vec in rows:
            party = mp_party.get(entity_id)
            if party is None:
                continue
            if is_current and party not in CURRENT_PARLIAMENT_PARTIES:
                continue
            if isinstance(raw_vec, str):
                vec = _json.loads(raw_vec)
            elif isinstance(raw_vec, (bytes, bytearray)):
                vec = _json.loads(raw_vec.decode())
            elif isinstance(raw_vec, list):
                vec = raw_vec
            else:
                try:
                    vec = list(raw_vec)
                except Exception:
                    # Not iterable (e.g. NULL): skip this MP.
                    continue
            fvec = np.array([float(v) if v is not None else 0.0 for v in vec])
            party_vecs.setdefault(party, []).append(fvec)
        return party_vecs
    except Exception:
        logger.exception("Failed to load MP vectors by party for window %s", window)
        return {}
def compute_party_axis_scores(
    party_vecs: Dict[str, List[np.ndarray]],
) -> Dict[str, List[float]]:
    """Compute per-party axis scores as the element-wise mean of MP vectors.

    Args:
        party_vecs: ``{party_name: [vector, ...]}`` with one vector per MP.

    Returns:
        ``{party_name: [float, ...]}`` holding, for each party, the mean over
        all of its MP vectors. Parties with no vectors are omitted, because
        the mean of an empty array is a scalar NaN rather than a list of
        scores. Returns an empty dict when the computation raises (the
        failure is logged).
    """
    try:
        return {
            party: np.array(vecs).mean(axis=0).tolist()
            for party, vecs in party_vecs.items()
            # len() rather than truthiness so array-like inputs are accepted.
            if len(vecs) > 0
        }
    except Exception:
        logger.exception("Failed to compute party axis scores")
        return {}