"""Compare PCA axes with and without party-level vectors present. Generates diagnostics and HTML plots (when plotly available) into outputs/. """ from __future__ import annotations import argparse import json import logging import os import sys from typing import Dict, List ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if ROOT not in sys.path: sys.path.insert(0, ROOT) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("compare_svd_exclude_parties") def main(argv: List[str] | None = None): p = argparse.ArgumentParser() p.add_argument("--db", default="data/motions.db") p.add_argument("--out", default="outputs") args = p.parse_args(argv) os.makedirs(args.out, exist_ok=True) try: from analysis import trajectory as traj from analysis.visualize import ( _load_party_map, plot_political_compass, plot_2d_trajectories, ) import numpy as np except Exception as e: logger.exception("Failed to import analysis modules: %s", e) raise window_ids = traj._load_window_ids(args.db) if not window_ids: logger.error("No SVD windows found") return 1 latest = sorted(window_ids)[-1] # load raw vectors for latest window conn = None try: # build party name set from mp_metadata import duckdb conn = duckdb.connect(args.db) rows = conn.execute( "SELECT DISTINCT party FROM mp_metadata WHERE party IS NOT NULL" ).fetchall() party_names = set(r[0] for r in rows if r[0]) finally: if conn: try: conn.close() except Exception: pass raw = traj._load_mp_vectors_for_window(args.db, latest) # group by vector JSON-like key groups: Dict[str, List[str]] = {} for ent, vec in raw.items(): key = tuple([round(float(x), 8) for x in vec.tolist()]) groups.setdefault(str(key), []).append(ent) group_list = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True) top_groups = [(len(v), v[:8]) for k, v in group_list[:20]] logger.info("Top duplicate groups (count, sample entities): %s", top_groups) # entities that are party names party_entities = [ent for ent in raw.keys() if ent in party_names] logger.info( "Found %d party-like entities in svd_vectors for %s", len(party_entities), latest, ) # Build aligned windows excluding party-level entities raw_window_vecs = { wid: traj._load_mp_vectors_for_window(args.db, wid) for wid in window_ids } # create filtered copy that removes party-level entity ids filtered_window_vecs = { wid: {ent: vec for ent, vec in d.items() if ent not in party_names} for wid, d in raw_window_vecs.items() } aligned_filtered = traj._procrustes_align_windows(filtered_window_vecs) # stack and compute PCA all_vecs = [] entity_index = [] for wid, d in aligned_filtered.items(): for ent, v in d.items(): n = np.linalg.norm(v) all_vecs.append(v / n if n > 1e-10 else v) entity_index.append((wid, ent)) if not all_vecs: logger.error("No vectors left after excluding parties — aborting") return 2 M = np.vstack(all_vecs) Mc = M - M.mean(axis=0) try: U, s, Vt = np.linalg.svd(Mc, full_matrices=False) except Exception: logger.exception("SVD failed on filtered data") return 3 sv2 = s**2 evr = sv2 / (sv2.sum() + 1e-20) logger.info("Filtered PCA EVR top2: %s", evr[:2].tolist()) comp1 = Vt[0] comp1_hat = comp1 / (np.linalg.norm(comp1) + 1e-12) comp2 = Vt[1] if Vt.shape[0] > 1 else np.zeros_like(comp1) comp2_hat = comp2 / (np.linalg.norm(comp2) + 1e-12) # project filtered entities for latest window filtered_positions = {} global_mean = M.mean(axis=0) for (wid, ent), vec in zip(entity_index, M): if wid != latest: continue v_centered = vec - global_mean x = float(np.dot(v_centered, comp1_hat)) y = float(np.dot(v_centered, comp2_hat)) filtered_positions[ent] = (x, y) # save JSON and small report out_json = os.path.join(args.out, "svd_filtered_positions.json") with open(out_json, "w", encoding="utf-8") as f: json.dump( { "latest": latest, "positions": filtered_positions, "evr": evr[:2].tolist(), }, f, indent=2, ) logger.info("Wrote filtered positions to %s", out_json) # Also generate plots if plotly available try: party_map = _load_party_map(args.db) # positions_by_window format expected by plot functions — include only latest positions_by_window = {latest: filtered_positions} pcomp_out = os.path.join(args.out, f"political_compass_filtered_{latest}.html") plot_political_compass( positions_by_window, window_id=latest, party_of=party_map, axis_def={"method": "pca", "explained_variance_ratio": evr[:2]}, output_path=pcomp_out, ) logger.info("Wrote filtered compass to %s", pcomp_out) # simple trajectory plotting for filtered set — top movers by count traj_out = os.path.join(args.out, f"trajectories_filtered_{latest}.html") # Build simple per-MP coords across windows for filtered set mp_coords = {} for wid in window_ids: for ent, coord in aligned_filtered.get(wid, {}).items(): if ent not in mp_coords: mp_coords[ent] = [] mp_coords[ent].append((wid, tuple(coord.tolist()))) # pick MPs with at least 2 windows names = [n for n, v in mp_coords.items() if len(v) >= 2] plot_2d_trajectories( { wid: { n: mp_coords[n][i][1] for n in names for i, (w, _) in enumerate(mp_coords[n]) if w == wid } for wid in window_ids }, mp_names=names[:50], output_path=traj_out, ) logger.info("Wrote filtered trajectories to %s", traj_out) except Exception: logger.exception("Plotting filtered results failed — plots skipped") # console summary print("Top duplicate groups (count, sample):") for k, v in group_list[:20]: print(len(v), v[:6]) return 0 if __name__ == "__main__": raise SystemExit(main())