You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
3.7 KiB
123 lines
3.7 KiB
"""trajectory.py — Compute MP political drift across aligned time windows.

For each MP that appears in multiple windows, computes:
- The aligned SVD vector per window
- The Euclidean distance between consecutive windows (drift)
- Total cumulative drift

Returns a dict keyed by mp_name containing per-window positions and drift scores.
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Optional
|
|
|
|
import numpy as np
|
|
import duckdb
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _load_window_ids(db_path: str) -> List[str]:
    """Return all distinct MP window IDs from ``svd_vectors``.

    Args:
        db_path: Database path handed to ``duckdb.connect``.

    Returns:
        The distinct ``window_id`` values of rows whose ``entity_type`` is
        ``'mp'``, in the order produced by the query's ``ORDER BY window_id``.
    """
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT DISTINCT window_id FROM svd_vectors WHERE entity_type = 'mp' ORDER BY window_id"
        ).fetchall()
    finally:
        # Close even when the query raises so the connection is never leaked.
        conn.close()
    return [row[0] for row in rows]
|
|
|
|
|
|
def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Load the vector of every MP stored for one window.

    Args:
        db_path: Database path handed to ``duckdb.connect``.
        window_id: Value matched against the ``window_id`` column.

    Returns:
        A dict mapping each row's ``entity_id`` to its ``vector`` column,
        decoded with ``json.loads`` and converted to a float ``np.ndarray``.
        Rows whose vector cannot be decoded are logged and left out.
    """
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()
    finally:
        # Close even when the query raises so the connection is never leaked.
        conn.close()

    result: Dict[str, np.ndarray] = {}
    for mp_name, vec_json in rows:
        try:
            vector = np.array(json.loads(vec_json), dtype=float)
        except Exception:
            # Deliberately best-effort: one malformed row must not abort the
            # whole window, so it is reported and skipped.
            _logger.warning(
                "Could not parse vector for MP %s window %s", mp_name, window_id
            )
        else:
            result[mp_name] = vector
    return result
|
|
|
|
|
|
def compute_trajectories(
    db_path: str,
    window_ids: Optional[List[str]] = None,
) -> Dict[str, Dict]:
    """Compute per-MP trajectories across windows.

    Args:
        db_path: Database path forwarded to the loader helpers.
        window_ids: Windows to process, in the order in which drift is
            measured. When ``None``, they come from ``_load_window_ids``.

    Returns:
        {
            mp_name: {
                "windows": [window_id, ...],
                "vectors": [[...], ...],   # one vector per window
                "drift": [float, ...],     # consecutive Euclidean distances
                "total_drift": float,
            }
        }
        Only MPs present in at least 2 windows are included. "Consecutive"
        refers to successive windows in which the MP has a vector; windows
        where the MP is absent are simply skipped over. MPs whose vectors do
        not all share one shape are logged and left out. An empty dict is
        returned when fewer than 2 windows are available.
    """
    if window_ids is None:
        window_ids = _load_window_ids(db_path)

    if len(window_ids) < 2:
        _logger.info("Fewer than 2 windows — no trajectories to compute")
        return {}

    # Collect per-window vectors for each MP, preserving window order.
    mp_data: Dict[str, Dict] = {}
    for wid in window_ids:
        for mp_name, vec in _load_mp_vectors_for_window(db_path, wid).items():
            entry = mp_data.setdefault(mp_name, {"windows": [], "vectors": []})
            entry["windows"].append(wid)
            entry["vectors"].append(vec)

    # Compute drift for MPs with >= 2 windows.
    result: Dict[str, Dict] = {}
    for mp_name, data in mp_data.items():
        vectors = data["vectors"]
        if len(vectors) < 2:
            continue

        # Subtracting vectors of different shapes either raises (aborting
        # every MP) or silently broadcasts into a meaningless distance, so
        # such MPs are reported and skipped instead.
        shapes = {vec.shape for vec in vectors}
        if len(shapes) > 1:
            _logger.warning(
                "Skipping MP %s: inconsistent vector shapes across windows: %s",
                mp_name,
                sorted(shapes),
            )
            continue

        drifts = [
            float(np.linalg.norm(later - earlier))
            for earlier, later in zip(vectors, vectors[1:])
        ]
        result[mp_name] = {
            "windows": data["windows"],
            "vectors": [vec.tolist() for vec in vectors],
            "drift": drifts,
            "total_drift": float(sum(drifts)),
        }

    _logger.info(
        "Trajectories computed for %d MPs across %d windows",
        len(result),
        len(window_ids),
    )
    return result
|
|
|
|
|
|
def top_drifters(trajectories: Dict[str, Dict], n: int = 10) -> List[Dict]:
    """Return the top-n MPs by total drift, sorted descending.

    Args:
        trajectories: Mapping of MP name to a dict holding at least the
            ``"total_drift"`` and ``"windows"`` keys.
        n: Maximum number of entries to return. Zero or a negative value
            yields an empty list.

    Returns:
        Up to ``n`` entries, each of the form
        {"mp_name": ..., "total_drift": ..., "windows": [...]}, ordered from
        largest to smallest total drift. MPs with equal drift keep the order
        they have in ``trajectories``.
    """
    # A negative n would otherwise slice entries off the *end* of the ranking
    # (``ranked[:-1]`` drops only the last one) instead of returning nothing.
    if n <= 0:
        return []

    ranked = sorted(
        trajectories.items(),
        key=lambda item: item[1]["total_drift"],
        reverse=True,
    )
    return [
        {
            "mp_name": mp_name,
            "total_drift": data["total_drift"],
            "windows": data["windows"],
        }
        for mp_name, data in ranked[:n]
    ]
|
|
|