You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/scripts/compare_svd_exclude_parties.py

204 lines
6.6 KiB

"""Compare PCA axes with and without party-level vectors present.
Generates diagnostics and HTML plots (when plotly available) into outputs/.
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import sys
from typing import Dict, List
# Make the repository root importable when this script is run directly.
_script_path = os.path.abspath(__file__)
ROOT = os.path.dirname(os.path.dirname(_script_path))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
)
logger = logging.getLogger("compare_svd_exclude_parties")
def main(argv: List[str] | None = None):
    """Compare PCA axes with and without party-level vectors present.

    Loads per-window SVD vectors from the DuckDB database, drops entities
    whose id matches a party name from ``mp_metadata``, Procrustes-aligns
    the remaining windows, re-fits PCA on the filtered vectors, and writes
    diagnostics (JSON, plus plotly HTML when available) into the output
    directory.

    Parameters
    ----------
    argv:
        Command-line arguments (``--db``, ``--out``). ``None`` means use
        ``sys.argv``.

    Returns
    -------
    int
        0 on success, 1 if no SVD windows exist, 2 if filtering removed
        every vector, 3 if the SVD itself fails.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--db", default="data/motions.db")
    p.add_argument("--out", default="outputs")
    args = p.parse_args(argv)
    os.makedirs(args.out, exist_ok=True)

    try:
        from analysis import trajectory as traj
        from analysis.visualize import (
            _load_party_map,
            plot_political_compass,
            plot_2d_trajectories,
        )
        import numpy as np
    except Exception as e:
        logger.exception("Failed to import analysis modules: %s", e)
        raise

    window_ids = traj._load_window_ids(args.db)
    if not window_ids:
        logger.error("No SVD windows found")
        return 1
    latest = sorted(window_ids)[-1]

    # Build the set of party names from mp_metadata so party-level
    # entities can be excluded from the vector space below.
    conn = None
    try:
        import duckdb

        conn = duckdb.connect(args.db)
        rows = conn.execute(
            "SELECT DISTINCT party FROM mp_metadata WHERE party IS NOT NULL"
        ).fetchall()
        party_names = {r[0] for r in rows if r[0]}
    finally:
        if conn:
            try:
                conn.close()
            except Exception:
                pass  # best-effort close; nothing useful to do on failure

    # Diagnostic: group entities that share a (near-)identical raw vector.
    raw = traj._load_mp_vectors_for_window(args.db, latest)
    groups: Dict[str, List[str]] = {}
    for ent, vec in raw.items():
        # Round to 8 decimals so float noise does not split duplicates.
        key = tuple(round(float(x), 8) for x in vec.tolist())
        groups.setdefault(str(key), []).append(ent)
    group_list = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)
    top_groups = [(len(v), v[:8]) for k, v in group_list[:20]]
    logger.info("Top duplicate groups (count, sample entities): %s", top_groups)

    # Entities whose id is literally a party name.
    party_entities = [ent for ent in raw if ent in party_names]
    logger.info(
        "Found %d party-like entities in svd_vectors for %s",
        len(party_entities),
        latest,
    )

    # Re-align all windows after removing party-level entity ids.
    raw_window_vecs = {
        wid: traj._load_mp_vectors_for_window(args.db, wid) for wid in window_ids
    }
    filtered_window_vecs = {
        wid: {ent: vec for ent, vec in d.items() if ent not in party_names}
        for wid, d in raw_window_vecs.items()
    }
    aligned_filtered = traj._procrustes_align_windows(filtered_window_vecs)

    # Stack unit-normalised vectors and fit PCA via SVD of the centred matrix.
    all_vecs = []
    entity_index = []
    for wid, d in aligned_filtered.items():
        for ent, v in d.items():
            n = np.linalg.norm(v)
            # Leave (near-)zero vectors unnormalised to avoid division blow-up.
            all_vecs.append(v / n if n > 1e-10 else v)
            entity_index.append((wid, ent))
    if not all_vecs:
        logger.error("No vectors left after excluding parties — aborting")
        return 2
    M = np.vstack(all_vecs)
    Mc = M - M.mean(axis=0)
    try:
        U, s, Vt = np.linalg.svd(Mc, full_matrices=False)
    except Exception:
        logger.exception("SVD failed on filtered data")
        return 3
    sv2 = s**2
    evr = sv2 / (sv2.sum() + 1e-20)  # explained-variance ratio per component
    logger.info("Filtered PCA EVR top2: %s", evr[:2].tolist())
    comp1 = Vt[0]
    comp1_hat = comp1 / (np.linalg.norm(comp1) + 1e-12)
    comp2 = Vt[1] if Vt.shape[0] > 1 else np.zeros_like(comp1)
    comp2_hat = comp2 / (np.linalg.norm(comp2) + 1e-12)

    # Project the latest window's entities onto the top-2 components.
    filtered_positions = {}
    global_mean = M.mean(axis=0)
    for (wid, ent), vec in zip(entity_index, M):
        if wid != latest:
            continue
        v_centered = vec - global_mean
        x = float(np.dot(v_centered, comp1_hat))
        y = float(np.dot(v_centered, comp2_hat))
        filtered_positions[ent] = (x, y)

    # Persist positions + EVR for downstream inspection.
    out_json = os.path.join(args.out, "svd_filtered_positions.json")
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(
            {
                "latest": latest,
                "positions": filtered_positions,
                "evr": evr[:2].tolist(),
            },
            f,
            indent=2,
        )
    logger.info("Wrote filtered positions to %s", out_json)

    # Plots are best-effort: any failure (e.g. plotly missing) is logged
    # and skipped so the JSON diagnostics above are still produced.
    try:
        party_map = _load_party_map(args.db)
        # positions_by_window format expected by plot functions — only latest.
        positions_by_window = {latest: filtered_positions}
        pcomp_out = os.path.join(args.out, f"political_compass_filtered_{latest}.html")
        plot_political_compass(
            positions_by_window,
            window_id=latest,
            party_of=party_map,
            axis_def={"method": "pca", "explained_variance_ratio": evr[:2]},
            output_path=pcomp_out,
        )
        logger.info("Wrote filtered compass to %s", pcomp_out)

        traj_out = os.path.join(args.out, f"trajectories_filtered_{latest}.html")
        # Per-MP coordinates across windows for the filtered set.
        mp_coords = {}
        for wid in window_ids:
            for ent, coord in aligned_filtered.get(wid, {}).items():
                mp_coords.setdefault(ent, []).append((wid, tuple(coord.tolist())))
        # Only MPs present in at least 2 windows can have a trajectory.
        names = [n for n, v in mp_coords.items() if len(v) >= 2]
        name_set = set(names)
        plot_2d_trajectories(
            # Direct per-window lookup — same values as re-scanning
            # mp_coords per window, without the quadratic rebuild.
            {
                wid: {
                    ent: tuple(coord.tolist())
                    for ent, coord in aligned_filtered.get(wid, {}).items()
                    if ent in name_set
                }
                for wid in window_ids
            },
            mp_names=names[:50],
            output_path=traj_out,
        )
        logger.info("Wrote filtered trajectories to %s", traj_out)
    except Exception:
        logger.exception("Plotting filtered results failed — plots skipped")

    # Console summary of the duplicate-vector diagnostic.
    print("Top duplicate groups (count, sample):")
    for k, v in group_list[:20]:
        print(len(v), v[:6])
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer status code to the shell.
    sys.exit(main())