You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/analysis/political_axis.py

125 lines
3.8 KiB

"""political_axis.py — Project MP SVD vectors onto an ideological axis.
Two modes:
1. PCA mode (default): compute the first principal component of all MP SVD
vectors for a window and project each MP onto it. The sign is arbitrary
but consistent within a window.
2. Anchor mode: define the axis as the vector from the centroid of
``left_parties`` to the centroid of ``right_parties``. Project all MPs
onto this normalised anchor axis.
Both modes return a dict mapping mp_name → scalar score for the given window.
"""
import json
import logging
from typing import Dict, List, Optional
import numpy as np
import duckdb
# Module-level logger used by the functions below to report skipped windows
# and unparseable vectors.
_logger = logging.getLogger(__name__)
def _load_mp_svd_vectors(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Load all MP SVD vectors for a window from the ``svd_vectors`` table.

    Selects rows with ``entity_type = 'mp'`` for ``window_id`` and decodes
    each ``vector`` value into a 1-D float array.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose vectors should be loaded.

    Returns:
        ``{entity_id: vector}``. Rows whose vector cannot be decoded into a
        1-D float array are logged and skipped.
    """
    # The context manager closes the connection even if the query raises;
    # a bare connect()/close() pair would leak it on error.
    with duckdb.connect(db_path) as conn:
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors "
            "WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()

    result: Dict[str, np.ndarray] = {}
    for mp_name, raw_vec in rows:
        try:
            # Vectors are stored as JSON text; an already-decoded sequence
            # (e.g. a native list column) is accepted as-is.
            if isinstance(raw_vec, (str, bytes, bytearray)):
                values = json.loads(raw_vec)
            else:
                values = raw_vec
            vec = np.array(values, dtype=float)
        except (TypeError, ValueError, OverflowError):
            # JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            _logger.warning(
                "Could not parse SVD vector for MP %s", mp_name, exc_info=True
            )
            continue
        # A scalar, null or nested value would break row-stacking downstream.
        if vec.ndim != 1:
            _logger.warning(
                "SVD vector for MP %s is not 1-D (shape %s); skipping",
                mp_name,
                vec.shape,
            )
            continue
        result[mp_name] = vec
    return result
def compute_pca_axis(db_path: str, window_id: str) -> Dict[str, float]:
    """Project MP SVD vectors onto their first principal component.

    The vectors are mean-centred, and the first right-singular vector of the
    centred matrix is used as the axis. Its sign carries no left/right
    meaning. It is normalised so that the axis component with the largest
    absolute value is positive, which makes the orientation reproducible
    instead of depending on the SVD implementation.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose MP vectors are projected.

    Returns:
        ``{mp_name: score}``, where the score is the projection of the centred
        vector. Empty if the window has fewer than 2 MPs, the vectors are not
        non-empty 1-D arrays of equal length, they contain non-finite values,
        or the SVD fails.
    """
    mp_vecs = _load_mp_svd_vectors(db_path, window_id)
    if len(mp_vecs) < 2:
        _logger.warning(
            "window %s has only %d MPs; skipping PCA axis", window_id, len(mp_vecs)
        )
        return {}

    # np.vstack raises on ragged vectors and mis-stacks non-1-D ones, and a
    # zero-length vector leaves no principal component to take.
    shapes = {vec.shape for vec in mp_vecs.values()}
    first_shape = next(iter(shapes))
    if len(shapes) != 1 or len(first_shape) != 1 or first_shape[0] == 0:
        _logger.warning(
            "window %s: MP SVD vectors are not non-empty 1-D vectors of equal "
            "length (shapes %s); skipping PCA axis",
            window_id,
            sorted(shapes),
        )
        return {}

    names = list(mp_vecs)
    mat = np.vstack([mp_vecs[name] for name in names])  # (n_mps, k)
    if not np.all(np.isfinite(mat)):
        _logger.warning(
            "window %s: non-finite values in MP SVD vectors; skipping PCA axis",
            window_id,
        )
        return {}

    # Centre so the first singular vector is the first principal component.
    mat_centred = mat - mat.mean(axis=0)
    try:
        _, _, vt = np.linalg.svd(mat_centred, full_matrices=False)
    except np.linalg.LinAlgError:
        _logger.exception("SVD failed in compute_pca_axis for window %s", window_id)
        return {}
    axis = vt[0]  # (k,); exists because n_mps >= 2 and k >= 1

    # Singular vectors are defined only up to sign. Make the largest-magnitude
    # component positive (the same convention as scikit-learn's svd_flip).
    if axis[np.argmax(np.abs(axis))] < 0:
        axis = -axis

    projections = mat_centred @ axis
    return {name: float(score) for name, score in zip(names, projections)}
def compute_anchor_axis(
    db_path: str,
    window_id: str,
    left_parties: List[str],
    right_parties: List[str],
) -> Dict[str, float]:
    """Project MP SVD vectors onto a left↔right anchor axis.

    The axis runs from the centroid of the MPs in ``left_parties`` to the
    centroid of the MPs in ``right_parties`` and is normalised to unit
    length. Each MP's score is the dot product of its raw (uncentred) SVD
    vector with that axis. Larger scores therefore lie further toward the
    right-anchor centroid. The zero point is the origin of the SVD space, not
    the midpoint between the anchors, so a positive score does not by itself
    mean "right of centre".

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose MP vectors are projected.
        left_parties: Party labels defining the left anchor. Any iterable of
            strings is accepted.
        right_parties: Party labels defining the right anchor.

    Returns:
        ``{mp_name: score}`` for every MP with a vector in the window. Empty
        if there are no vectors, the vectors are not 1-D arrays of equal
        length, either anchor has no MPs with vectors, or the anchor axis has
        near-zero or non-finite norm.
    """
    mp_vecs = _load_mp_svd_vectors(db_path, window_id)
    if not mp_vecs:
        return {}

    # Every MP is projected onto the same axis, so all vectors must be
    # 1-D arrays of one common length (np.mean / np.dot raise otherwise).
    shapes = {vec.shape for vec in mp_vecs.values()}
    if len(shapes) != 1 or len(next(iter(shapes))) != 1:
        _logger.warning(
            "window %s: MP SVD vectors are not 1-D vectors of equal length "
            "(shapes %s); skipping anchor axis",
            window_id,
            sorted(shapes),
        )
        return {}

    left_set = set(left_parties)
    right_set = set(right_parties)
    overlap = left_set & right_set
    if overlap:
        # MPs of such parties pull both centroids, which shrinks the axis.
        _logger.warning(
            "window %s: parties %s appear in both left and right anchors",
            window_id,
            sorted(overlap),
        )

    # NOTE(review): mp_metadata is read without a window filter, so an MP
    # with several rows gets whichever party the query returns last. If
    # affiliations can change over time, restrict this to the window.
    # Confirm against the mp_metadata schema.
    with duckdb.connect(db_path) as conn:
        rows = conn.execute("SELECT mp_name, party FROM mp_metadata").fetchall()
    party_of = dict(rows)

    left_vecs = [
        mp_vecs[mp]
        for mp, party in party_of.items()
        if party in left_set and mp in mp_vecs
    ]
    right_vecs = [
        mp_vecs[mp]
        for mp, party in party_of.items()
        if party in right_set and mp in mp_vecs
    ]
    if not left_vecs or not right_vecs:
        _logger.warning(
            "window %s: insufficient anchor parties (left=%d, right=%d)",
            window_id,
            len(left_vecs),
            len(right_vecs),
        )
        return {}

    left_centroid = np.mean(left_vecs, axis=0)
    right_centroid = np.mean(right_vecs, axis=0)
    axis = right_centroid - left_centroid
    norm = np.linalg.norm(axis)
    # NaN fails every comparison, so it must be caught before the < test.
    if not np.isfinite(norm):
        _logger.warning("Anchor axis has non-finite norm for window %s", window_id)
        return {}
    if norm < 1e-10:
        _logger.warning("Anchor axis has near-zero norm for window %s", window_id)
        return {}
    axis = axis / norm
    return {name: float(np.dot(vec, axis)) for name, vec in mp_vecs.items()}