You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
4.1 KiB
127 lines
4.1 KiB
"""political_axis.py — Project MP SVD vectors onto an ideological axis.

Two modes:

1. PCA mode (default): compute the first principal component of all MP SVD
   vectors for a window and project each MP onto it. The sign of the axis
   carries no ideological meaning, but it is consistent across all MPs
   within a window.

2. Anchor mode: define the axis as the vector from the centroid of
   ``left_parties`` to the centroid of ``right_parties``. Project all MPs
   onto this normalised anchor axis.

Both modes return a dict mapping mp_name → scalar score for the given window.
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Optional
|
|
|
|
import numpy as np
|
|
import duckdb
|
|
|
|
# Module logger used for the diagnostics (warnings, SVD failures) this module emits.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _load_mp_svd_vectors(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Load all MP SVD vectors for a window from the ``svd_vectors`` table.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose ``entity_type = 'mp'`` rows are loaded.

    Returns:
        Mapping of ``entity_id`` -> 1-D float vector. A row whose ``vector``
        value cannot be turned into a non-empty 1-D numeric array is skipped
        and a warning is logged.
    """
    # ``with`` closes the connection even if the query raises.
    with duckdb.connect(db_path) as conn:
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()

    result: Dict[str, np.ndarray] = {}
    for mp_name, raw in rows:
        try:
            # The column is expected to hold JSON text; an already-decoded
            # sequence (e.g. a DuckDB LIST column) is accepted as-is.
            values = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
            vec = np.asarray(values, dtype=float)
        except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
            _logger.warning("Could not parse SVD vector for MP %s", mp_name)
            continue
        # Downstream code stacks one (k,) vector per MP; scalars, empty or
        # nested arrays would silently corrupt that matrix.
        if vec.ndim != 1 or vec.size == 0:
            _logger.warning("Could not parse SVD vector for MP %s", mp_name)
            continue
        result[mp_name] = vec
    return result
|
|
|
|
|
|
def compute_pca_axis(db_path: str, window_id: str) -> Dict[str, float]:
    """Project MP SVD vectors onto their first principal component.

    The MP vectors are mean-centred. The first right-singular vector of the
    centred matrix is used as the axis, and each MP's centred vector is
    projected onto it.

    A singular vector is only defined up to sign. The sign is therefore fixed
    by convention: the axis component with the largest absolute value is made
    positive. MPs are processed in sorted-name order, so the result does not
    depend on the row order returned by the database.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose MP vectors are projected.

    Returns:
        ``{mp_name: score}``. Empty if fewer than 2 MPs have vectors, if the
        vectors do not share a single dimensionality, or if the SVD fails.
    """
    mp_vecs = _load_mp_svd_vectors(db_path, window_id)
    if len(mp_vecs) < 2:
        _logger.warning(
            "window %s has only %d MPs; skipping PCA axis", window_id, len(mp_vecs)
        )
        return {}

    # Sorted names make the matrix (and hence the SVD) independent of DB row order.
    names = sorted(mp_vecs)
    try:
        mat = np.vstack([mp_vecs[n] for n in names])  # (n_mps, k)
    except ValueError:
        mat = None
    if mat is None or mat.shape[0] != len(names):
        _logger.warning(
            "window %s: MP SVD vectors have inconsistent dimensions; skipping PCA axis",
            window_id,
        )
        return {}

    # Centre so the first PC captures variance, not the mean offset.
    mat_centred = mat - mat.mean(axis=0)

    # First PC via SVD.
    try:
        _, _, vt = np.linalg.svd(mat_centred, full_matrices=False)
    except np.linalg.LinAlgError:
        _logger.exception("SVD failed in compute_pca_axis for window %s", window_id)
        return {}
    axis = vt[0]  # (k,)

    # Pin the arbitrary SVD sign so repeated runs produce the same orientation.
    if axis[np.argmax(np.abs(axis))] < 0:
        axis = -axis

    projections = mat_centred @ axis
    return {name: float(score) for name, score in zip(names, projections)}
|
|
|
|
|
|
def compute_anchor_axis(
    db_path: str,
    window_id: str,
    left_parties: List[str],
    right_parties: List[str],
) -> Dict[str, float]:
    """Project MP SVD vectors onto a left↔right anchor axis.

    The axis is the unit vector pointing from the centroid of the left
    anchors to the centroid of the right anchors. Anchors on each side are:

    1. entities in the window whose ``entity_id`` is itself one of the listed
       party names (party-level actors), and
    2. individual MPs whose ``mp_metadata.party`` is one of the listed
       parties.

    Each anchor entity is counted at most once per side, even if
    ``mp_metadata`` lists it several times.

    Scores are plain dot products with the unit axis and are not centred.
    Higher scores lie further toward the right, but the sign alone does not
    tell which side of the left/right midpoint an MP falls on.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose MP vectors are projected.
        left_parties: Party names anchoring the left end of the axis.
        right_parties: Party names anchoring the right end of the axis.

    Returns:
        ``{mp_name: score}`` for every entity with a vector in the window.
        Empty if the window has no vectors, if either side has no anchors,
        or if the two centroids (nearly) coincide.
    """
    mp_vecs = _load_mp_svd_vectors(db_path, window_id)
    if not mp_vecs:
        return {}

    left_set = set(left_parties)
    right_set = set(right_parties)

    # 1. Party-level actors whose entity_id IS a party name (e.g. "GroenLinks-PvdA").
    left_names = {p for p in left_set if p in mp_vecs}
    right_names = {p for p in right_set if p in mp_vecs}

    # 2. Individual MPs via mp_metadata party affiliation. Names are collected
    # in sets so an MP with several metadata rows is not over-weighted.
    with duckdb.connect(db_path) as conn:
        rows = conn.execute("SELECT mp_name, party FROM mp_metadata").fetchall()
    for mp_name, party in rows:
        if mp_name not in mp_vecs:
            continue
        if party in left_set and mp_name not in left_set:
            left_names.add(mp_name)
        elif party in right_set and mp_name not in right_set:
            right_names.add(mp_name)

    if not left_names or not right_names:
        _logger.warning(
            "window %s: insufficient anchor parties (left=%d, right=%d)",
            window_id,
            len(left_names),
            len(right_names),
        )
        return {}

    # Sorted order keeps the floating-point centroid sums reproducible.
    left_centroid = np.mean([mp_vecs[n] for n in sorted(left_names)], axis=0)
    right_centroid = np.mean([mp_vecs[n] for n in sorted(right_names)], axis=0)
    axis = right_centroid - left_centroid
    norm = np.linalg.norm(axis)
    if norm < 1e-10:
        _logger.warning("Anchor axis has near-zero norm for window %s", window_id)
        return {}
    axis = axis / norm

    return {name: float(np.dot(vec, axis)) for name, vec in mp_vecs.items()}
|
|
|