"""SVD and PCA diagnostics for the political compass pipeline. Produces a small text report and JSON summary in the outputs/ directory. Usage: uv run python3 scripts/svd_diagnostics.py --db data/motions.db --out outputs """ from __future__ import annotations import argparse import json import logging import os import sys from statistics import mean from typing import Dict, List, Optional, Tuple ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if ROOT not in sys.path: sys.path.insert(0, ROOT) logger = logging.getLogger("svd_diagnostics") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") def find_by_substring(names: List[str], query: str) -> List[str]: q = query.lower() return [n for n in names if q in n.lower()] def main(argv: Optional[list] = None): p = argparse.ArgumentParser() p.add_argument("--db", default="data/motions.db") p.add_argument("--out", default="outputs") args = p.parse_args(argv) os.makedirs(args.out, exist_ok=True) try: from analysis import trajectory as traj from analysis.political_axis import compute_2d_axes from analysis.visualize import _load_party_map except Exception as e: # pragma: no cover - runtime logger.exception("Could not import analysis modules: %s", e) raise # Load windows and aligned vectors window_ids = traj._load_window_ids(args.db) if not window_ids: logger.error("No SVD windows found in DB %s", args.db) return 1 logger.info("Found windows: %s", window_ids) raw_window_vecs = { wid: traj._load_mp_vectors_for_window(args.db, wid) for wid in window_ids } aligned_window_vecs = traj._procrustes_align_windows(raw_window_vecs) # Compute global PCA axes (residual and non-residual) for comparison positions_residual, axes_residual = compute_2d_axes( args.db, window_ids=window_ids, method="pca", normalize_vectors=True, pca_residual=True, ) positions_plain, axes_plain = compute_2d_axes( args.db, window_ids=window_ids, method="pca", normalize_vectors=True, pca_residual=False, ) out_report = [] def add(line: str): out_report.append(line) logger.info(line) add("PCA diagnostics report") add(f"DB: {args.db}") add(f"Windows: {window_ids}") add("") evr_res = axes_residual.get("explained_variance_ratio") if axes_residual else None evr_plain = axes_plain.get("explained_variance_ratio") if axes_plain else None add(f"Residual PCA EVR: {evr_res}") add(f"Plain PCA EVR: {evr_plain}") # pick latest window for detailed inspection latest = sorted(window_ids)[-1] add("") add(f"Inspecting latest window: {latest}") pos = positions_residual.get(latest, {}) names = list(pos.keys()) xs = [v[0] for v in pos.values()] ys = [v[1] for v in pos.values()] def stats(arr: List[float]) -> Tuple[float, float]: if not arr: return 0.0, 0.0 mn = min(arr) mx = max(arr) return mn, mx add(f"Entities in latest window: {len(names)}") add(f"X range (left-right): {stats(xs)}") add(f"Y range (prog-cons): {stats(ys)}") # stdevs try: import numpy as _np x_std = float(_np.std(xs)) y_std = float(_np.std(ys)) except Exception: x_std = 0.0 y_std = 0.0 add( f"Std dev X: {x_std:.6f}, Std dev Y: {y_std:.6f} (ratio Y/X = {y_std / (x_std + 1e-12):.3f})" ) # show extremes on X and Y sorted_by_x = sorted(pos.items(), key=lambda kv: kv[1][0]) sorted_by_y = sorted(pos.items(), key=lambda kv: kv[1][1]) add("") add("Left-most (by X):") for name, (x, y) in sorted_by_x[:8]: add(f" {name:40s} x={x:.4f} y={y:.4f}") add("") add("Right-most (by X):") for name, (x, y) in sorted_by_x[-8:]: add(f" {name:40s} x={x:.4f} y={y:.4f}") add("") add("Top (conservative) (by Y):") for name, (x, y) in sorted_by_y[-8:]: add(f" {name:40s} x={x:.4f} y={y:.4f}") add("") add("Bottom (progressive) (by Y):") for name, (x, y) in sorted_by_y[:8]: add(f" {name:40s} x={x:.4f} y={y:.4f}") # Find specific MPs mentioned by user matches_ouwehand = find_by_substring(names, "ouwehand") matches_mona = find_by_substring(names, "mona") add("") add(f"Matches for 'Ouwehand': {matches_ouwehand}") for n in matches_ouwehand: x, y = pos.get(n) add(f" {n} -> x={x:.4f} y={y:.4f}") add(f"Matches for 'Mona': {matches_mona}") for n in matches_mona: x, y = pos.get(n) add(f" {n} -> x={x:.4f} y={y:.4f}") # Party centroids party_map = _load_party_map(args.db) parties: Dict[str, List[Tuple[float, float]]] = {} for mp, coord in pos.items(): party = party_map.get(mp) if party: parties.setdefault(party, []).append(coord) party_centroids: Dict[str, Tuple[float, float]] = {} for party, coords in parties.items(): xs_p = [c[0] for c in coords] ys_p = [c[1] for c in coords] party_centroids[party] = (mean(xs_p), mean(ys_p)) add("") add(f"Computed {len(party_centroids)} party centroids (from mp_metadata majority)") sorted_parties_by_x = sorted(party_centroids.items(), key=lambda kv: kv[1][0]) add("Party centroids left→right:") for p, (x, y) in sorted_parties_by_x: add(f" {p:20s} x={x:.4f} y={y:.4f}") sorted_parties_by_y = sorted(party_centroids.items(), key=lambda kv: kv[1][1]) add("") add("Party centroids prog→cons:") for p, (x, y) in sorted_parties_by_y: add(f" {p:20s} x={x:.4f} y={y:.4f}") # Save report and a small JSON summary report_path = os.path.join(args.out, "svd_diagnostics.txt") summary_path = os.path.join(args.out, "svd_diagnostics.json") with open(report_path, "w", encoding="utf-8") as f: f.write("\n".join(out_report)) summary = { "db": args.db, "windows": window_ids, "latest_window": latest, "evr_residual": evr_res, "evr_plain": evr_plain, "n_entities_latest": len(names), "x_std": x_std, "y_std": y_std, "party_centroids": party_centroids, } with open(summary_path, "w", encoding="utf-8") as f: json.dump(summary, f, indent=2) logger.info("Diagnostic report written to %s and %s", report_path, summary_path) return 0 if __name__ == "__main__": raise SystemExit(main())