You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.2 KiB
137 lines
4.2 KiB
"""Inspect PCA axes and per-MP projections for diagnostics.
|
|
|
|
Usage:
|
|
uv run python3 scripts/inspect_axis.py --db data/motions.db --out outputs
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
from typing import Dict, List
|
|
|
|
# Repository root: the parent of the directory that contains this script.
# It is put at the front of sys.path so the project's ``analysis`` package
# can be imported when the script is launched directly.
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if not any(entry == ROOT for entry in sys.path):
    sys.path.insert(0, ROOT)

# Timestamped INFO-level logging for the whole script.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("inspect_axis")
|
|
|
|
|
|
def main(argv: List[str] | None = None) -> int:
    """Compute 2D axes for the database and write a diagnostic JSON report.

    The report covers the latest window returned by ``compute_2d_axes``:
    spread statistics for both coordinates, the ten entries at each extreme
    of both axes, the entries close to zero on the x axis, the number of
    distinct (rounded) coordinate pairs, and look-ups for a few hard-coded
    name fragments.  The report is written to ``<out>/inspect_axis.json``
    and echoed to stdout.

    Args:
        argv: Command-line arguments (without the program name).  ``None``
            makes argparse read ``sys.argv``.

    Returns:
        ``0`` on success, ``2`` when no positions are available.

    Raises:
        Exception: re-raised after logging when the ``analysis`` modules
            cannot be imported.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--db", default="data/motions.db")
    parser.add_argument("--out", default="outputs")
    parser.add_argument("--method", choices=["pca", "anchor"], default="pca")
    parser.add_argument("--pca-residual", action="store_true")
    # ``--normalize`` stays accepted for backward compatibility; because the
    # default is already True it was previously impossible to disable
    # normalization, hence the explicit ``--no-normalize`` counterpart.
    parser.add_argument(
        "--normalize", dest="normalize", action="store_true", default=True
    )
    parser.add_argument("--no-normalize", dest="normalize", action="store_false")
    args = parser.parse_args(argv)

    os.makedirs(args.out, exist_ok=True)

    # Project imports are deferred so an import failure is logged with a
    # traceback before propagating.
    try:
        from analysis.political_axis import compute_2d_axes
        from analysis.visualize import _load_party_map
    except Exception as e:
        logger.exception("Failed to import analysis modules: %s", e)
        raise

    positions_by_window, axes = compute_2d_axes(
        args.db,
        method=args.method,
        pca_residual=args.pca_residual,
        normalize_vectors=args.normalize,
    )

    if not positions_by_window:
        logger.error("No positions produced")
        return 2

    latest = max(positions_by_window.keys())
    pos = positions_by_window[latest]

    # An empty window would otherwise crash in min()/max() below.
    if not pos:
        logger.error("No positions in latest window %s", latest)
        return 2

    names = list(pos.keys())
    coords = list(pos.values())
    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]

    import numpy as _np

    def _json_default(obj):
        """Convert NumPy scalars/arrays to plain Python values for json."""
        if isinstance(obj, _np.generic):
            return obj.item()
        if isinstance(obj, _np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    x_std = float(_np.std(xs))
    y_std = float(_np.std(ys))
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)

    party_map = _load_party_map(args.db)

    # Load per-MP vote counts.  This is best-effort: the report is still
    # useful without them, so failures are logged rather than raised.
    vote_counts: Dict[str, int] = {}
    try:
        import duckdb

        conn = duckdb.connect(args.db)
        try:
            rows = conn.execute(
                "SELECT mp_name, COUNT(*) FROM mp_votes GROUP BY mp_name"
            ).fetchall()
        finally:
            # Always release the connection, even when the query fails.
            conn.close()
        vote_counts = {r[0]: int(r[1]) for r in rows}
    except Exception as exc:
        logger.warning(
            "Could not load mp_votes counts (%s); continuing without them", exc
        )

    # Entries ordered along each axis, used to pick the extremes.
    sorted_by_x = sorted(pos.items(), key=lambda kv: kv[1][0])
    sorted_by_y = sorted(pos.items(), key=lambda kv: kv[1][1])

    def info_for(name: str):
        """Return the report record (party, vote count, coordinates) for *name*."""
        party = party_map.get(name)
        count = vote_counts.get(name, None)
        x, y = pos.get(name, (None, None))
        return {"name": name, "party": party, "count": count, "x": x, "y": y}

    report = {
        "db": args.db,
        "latest_window": latest,
        "n_entities": len(names),
        "x_std": x_std,
        "y_std": y_std,
        "x_min": x_min,
        "x_max": x_max,
        "y_min": y_min,
        "y_max": y_max,
        "evr": axes.get("explained_variance_ratio") if axes else None,
        "top_left_by_x": [info_for(n) for n, _ in sorted_by_x[:10]],
        "top_right_by_x": [info_for(n) for n, _ in sorted_by_x[-10:]],
        "top_by_y": [info_for(n) for n, _ in sorted_by_y[-10:]],
        "bottom_by_y": [info_for(n) for n, _ in sorted_by_y[:10]],
    }

    # Count how many entries lie near zero along x, within a small fraction
    # of the standard deviation (fixed fallback when there is no spread).
    threshold = 0.2 * x_std if x_std > 0 else 0.01
    near_center = [n for n, (x, _) in pos.items() if abs(x) < threshold]
    report["near_center_count"] = len(near_center)
    report["near_center_sample"] = near_center[:40]

    # Check for duplicate coordinate pairs (after rounding to 6 decimals).
    unique_coords = {(_np.round(c[0], 6), _np.round(c[1], 6)) for c in coords}
    report["n_unique_coords"] = len(unique_coords)
    report["n_total_entities"] = len(names)

    # Look up particular MPs by case-insensitive substring match.
    for q in ("Ouwehand", "Keijzer", "Mona"):
        found = [n for n in names if q.lower() in n.lower()]
        report[f"matches_{q}"] = [info_for(n) for n in found]

    # Serialize once; the default hook keeps NumPy values from raising
    # TypeError inside json.
    payload = json.dumps(report, indent=2, default=_json_default)

    out_json = os.path.join(args.out, "inspect_axis.json")
    with open(out_json, "w", encoding="utf-8") as f:
        f.write(payload)

    logger.info("Wrote inspection to %s", out_json)
    print(payload)
    return 0
|
|
|
|
|
|
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|
|