"""Inspect PCA axes and per-MP projections for diagnostics. Usage: uv run python3 scripts/inspect_axis.py --db data/motions.db --out outputs """ from __future__ import annotations import argparse import json import logging import os import sys from typing import Dict, List ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if ROOT not in sys.path: sys.path.insert(0, ROOT) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("inspect_axis") def main(argv: List[str] | None = None): p = argparse.ArgumentParser() p.add_argument("--db", default="data/motions.db") p.add_argument("--out", default="outputs") p.add_argument("--method", choices=["pca", "anchor"], default="pca") p.add_argument("--pca-residual", action="store_true") p.add_argument("--normalize", action="store_true", default=True) args = p.parse_args(argv) os.makedirs(args.out, exist_ok=True) try: from analysis.political_axis import compute_2d_axes from analysis.visualize import _load_party_map except Exception as e: logger.exception("Failed to import analysis modules: %s", e) raise positions_by_window, axes = compute_2d_axes( args.db, method=args.method, pca_residual=args.pca_residual, normalize_vectors=args.normalize, ) if not positions_by_window: logger.error("No positions produced") return 2 latest = sorted(positions_by_window.keys())[-1] pos = positions_by_window[latest] names = list(pos.keys()) coords = list(pos.values()) xs = [c[0] for c in coords] ys = [c[1] for c in coords] import numpy as _np x_std = float(_np.std(xs)) y_std = float(_np.std(ys)) x_min, x_max = min(xs), max(xs) y_min, y_max = min(ys), max(ys) party_map = _load_party_map(args.db) # load mp_votes counts try: import duckdb conn = duckdb.connect(args.db) rows = conn.execute( "SELECT mp_name, COUNT(*) FROM mp_votes GROUP BY mp_name" ).fetchall() conn.close() vote_counts = {r[0]: int(r[1]) for r in rows} except Exception: vote_counts = {} # extremes sorted_by_x = sorted(pos.items(), key=lambda kv: kv[1][0]) sorted_by_y = sorted(pos.items(), key=lambda kv: kv[1][1]) def info_for(name: str): party = party_map.get(name) count = vote_counts.get(name, None) x, y = pos.get(name, (None, None)) return {"name": name, "party": party, "count": count, "x": x, "y": y} report = { "db": args.db, "latest_window": latest, "n_entities": len(names), "x_std": x_std, "y_std": y_std, "x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max, "evr": axes.get("explained_variance_ratio") if axes else None, "top_left_by_x": [info_for(n) for n, _ in sorted_by_x[:10]], "top_right_by_x": [info_for(n) for n, _ in sorted_by_x[-10:]], "top_by_y": [info_for(n) for n, _ in sorted_by_y[-10:]], "bottom_by_y": [info_for(n) for n, _ in sorted_by_y[:10]], } # count how many are near-center along x within small fraction of std threshold = 0.2 * x_std if x_std > 0 else 0.01 near_center = [n for n, (x, y) in pos.items() if abs(x) < threshold] report["near_center_count"] = len(near_center) report["near_center_sample"] = near_center[:40] # check duplicate coordinate pairs coord_pairs = [(_np.round(c[0], 6), _np.round(c[1], 6)) for c in coords] unique_coords = set(coord_pairs) report["n_unique_coords"] = len(unique_coords) report["n_total_entities"] = len(names) # look up particular MPs for q in ("Ouwehand", "Keijzer", "Mona"): found = [n for n in names if q.lower() in n.lower()] report[f"matches_{q}"] = [info_for(n) for n in found] out_json = os.path.join(args.out, "inspect_axis.json") with open(out_json, "w", encoding="utf-8") as f: json.dump(report, f, indent=2) logger.info("Wrote inspection to %s", out_json) print(json.dumps(report, indent=2)) return 0 if __name__ == "__main__": raise SystemExit(main())