You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/scripts/inspect_axis.py

137 lines
4.2 KiB

"""Inspect PCA axes and per-MP projections for diagnostics.
Usage:
uv run python3 scripts/inspect_axis.py --db data/motions.db --out outputs
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import sys
from typing import Dict, List
# Project root: the parent of the directory that holds this script.
ROOT = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
# Put the root first on the import path so project packages (e.g. `analysis`)
# resolve when this file is executed directly as a script.
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("inspect_axis")
# Name fragments reported by default (case-insensitive substring match).
_DEFAULT_LOOKUPS = ("Ouwehand", "Keijzer", "Mona")


def _json_default(obj):
    """Make numpy values JSON-serializable for ``json.dumps``.

    Anything exposing a callable ``tolist`` (numpy scalars and arrays) is
    converted to plain Python values; anything else raises ``TypeError``,
    just as ``json`` would without a ``default`` hook.
    """
    tolist = getattr(obj, "tolist", None)
    if callable(tolist):
        return tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _load_vote_counts(db_path: str) -> Dict[str, int]:
    """Return ``{mp_name: number of rows in mp_votes}`` read from a DuckDB file.

    Best effort: returns ``{}`` (and logs a warning) when duckdb is not
    installed, the database cannot be opened, or the query fails. The
    connection is closed even if the query raises.
    """
    try:
        import duckdb

        conn = duckdb.connect(db_path)
        try:
            rows = conn.execute(
                "SELECT mp_name, COUNT(*) FROM mp_votes GROUP BY mp_name"
            ).fetchall()
        finally:
            conn.close()
    except Exception:
        # Counts are informational only; never abort the inspection over them.
        logger.warning("Could not load mp_votes counts from %s", db_path, exc_info=True)
        return {}
    return {name: int(count) for name, count in rows}


def main(argv: List[str] | None = None):
    """Compute 2-D axes, summarize the latest window, and write a JSON report.

    The report is written to ``<out>/inspect_axis.json`` and also printed to
    stdout.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        Exit status: 0 on success, 2 when no positions are available.

    Raises:
        Exception: re-raised (after logging) if the project ``analysis``
            modules cannot be imported.
    """
    p = argparse.ArgumentParser(description="Inspect 2-D axis projections of MPs.")
    p.add_argument("--db", default="data/motions.db")
    p.add_argument("--out", default="outputs")
    p.add_argument("--method", choices=["pca", "anchor"], default="pca")
    p.add_argument("--pca-residual", action="store_true")
    # Normalization defaults to on, so `--normalize` by itself can never turn
    # it off; `--no-normalize` is the off switch. `--normalize` is kept so
    # existing invocations keep working.
    p.add_argument(
        "--normalize",
        dest="normalize",
        action="store_true",
        default=True,
        help="pass normalize_vectors=True to compute_2d_axes (default)",
    )
    p.add_argument(
        "--no-normalize",
        dest="normalize",
        action="store_false",
        help="pass normalize_vectors=False to compute_2d_axes",
    )
    p.add_argument(
        "--lookup",
        nargs="+",
        default=_DEFAULT_LOOKUPS,
        metavar="NAME",
        help="case-insensitive name fragments to report matches for",
    )
    args = p.parse_args(argv)
    os.makedirs(args.out, exist_ok=True)

    try:
        from analysis.political_axis import compute_2d_axes
        from analysis.visualize import _load_party_map
    except Exception as e:
        logger.exception("Failed to import analysis modules: %s", e)
        raise

    positions_by_window, axes = compute_2d_axes(
        args.db,
        method=args.method,
        pca_residual=args.pca_residual,
        normalize_vectors=args.normalize,
    )
    if not positions_by_window:
        logger.error("No positions produced")
        return 2

    # "Latest" is the largest window key; assumes keys sort chronologically
    # — TODO confirm against compute_2d_axes.
    latest = max(positions_by_window)
    pos = positions_by_window[latest]
    if not pos:
        # min()/max() below would raise ValueError on an empty window.
        logger.error("Latest window %s has no positions", latest)
        return 2

    import numpy as np

    names = list(pos)
    coords = list(pos.values())
    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    x_std = float(np.std(xs))
    y_std = float(np.std(ys))

    party_map = _load_party_map(args.db)
    vote_counts = _load_vote_counts(args.db)

    def info_for(name: str) -> dict:
        """Summarize one entity: party, vote-row count, and (x, y) position."""
        x, y = pos.get(name, (None, None))
        return {
            "name": name,
            "party": party_map.get(name),
            "count": vote_counts.get(name),
            "x": x,
            "y": y,
        }

    sorted_by_x = sorted(pos.items(), key=lambda kv: kv[1][0])
    sorted_by_y = sorted(pos.items(), key=lambda kv: kv[1][1])
    report = {
        "db": args.db,
        "latest_window": latest,
        "n_entities": len(names),
        "x_std": x_std,
        "y_std": y_std,
        "x_min": min(xs),
        "x_max": max(xs),
        "y_min": min(ys),
        "y_max": max(ys),
        "evr": axes.get("explained_variance_ratio") if axes else None,
        # All four lists are in ascending coordinate order.
        "top_left_by_x": [info_for(n) for n, _ in sorted_by_x[:10]],
        "top_right_by_x": [info_for(n) for n, _ in sorted_by_x[-10:]],
        "top_by_y": [info_for(n) for n, _ in sorted_by_y[-10:]],
        "bottom_by_y": [info_for(n) for n, _ in sorted_by_y[:10]],
    }

    # Near-centre along x: |x| below 20% of the x std-dev (0.01 if std is 0).
    # NOTE(review): distance is measured from 0, which is the centre only if
    # projections are mean-centred — confirm for --method anchor.
    threshold = 0.2 * x_std if x_std > 0 else 0.01
    near_center = [n for n, (x, _y) in pos.items() if abs(x) < threshold]
    report["near_center_count"] = len(near_center)
    report["near_center_sample"] = near_center[:40]

    # Distinct (x, y) pairs after rounding to 6 decimals, to detect entities
    # that share identical coordinates.
    unique_coords = {(np.round(c[0], 6), np.round(c[1], 6)) for c in coords}
    report["n_unique_coords"] = len(unique_coords)
    report["n_total_entities"] = len(names)

    for q in args.lookup:
        needle = q.lower()
        report[f"matches_{q}"] = [info_for(n) for n in names if needle in n.lower()]

    # Serialize once: a non-serializable value fails before the file is
    # touched (no truncated output), and file and stdout are identical.
    text = json.dumps(report, indent=2, default=_json_default)
    out_json = os.path.join(args.out, "inspect_axis.json")
    with open(out_json, "w", encoding="utf-8") as f:
        f.write(text)
    logger.info("Wrote inspection to %s", out_json)
    print(text)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())