You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
214 lines
6.5 KiB
214 lines
6.5 KiB
"""SVD and PCA diagnostics for the political compass pipeline.
|
|
|
|
Produces a small text report and JSON summary in the outputs/ directory.
|
|
|
|
Usage:
|
|
uv run python3 scripts/svd_diagnostics.py --db data/motions.db --out outputs
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
from statistics import mean
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
# Repository root = parent of the directory containing this script (the Usage
# line in the module docstring runs it as scripts/svd_diagnostics.py). It is put
# at the front of sys.path so the `analysis` package imported inside main() can
# be resolved when the file is executed directly.
_script_path = os.path.abspath(__file__)
ROOT = os.path.dirname(os.path.dirname(_script_path))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

# Configure the root logger at INFO level (a no-op if the root logger already has
# handlers, per logging.basicConfig semantics). Note this runs at import time.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("svd_diagnostics")
|
|
|
|
|
|
def find_by_substring(names: List[str], query: str) -> List[str]:
    """Return the entries of *names* that contain *query*, ignoring case.

    Matching uses ``str.casefold`` instead of ``str.lower`` so caseless
    comparison also works for characters whose case mapping is not
    one-to-one (e.g. ``"Straße"`` matches ``"STRASSE"``).

    Args:
        names: Candidate names to search.
        query: Substring to look for.

    Returns:
        The matching names, in their original order. An empty *query*
        matches every name.
    """
    needle = query.casefold()
    return [name for name in names if needle in name.casefold()]
|
|
|
|
|
|
def _value_range(values: List[float]) -> Tuple[float, float]:
    """Return ``(min, max)`` of *values*, or ``(0.0, 0.0)`` for an empty list."""
    if not values:
        return 0.0, 0.0
    return min(values), max(values)


def _population_std(values: List[float]) -> float:
    """Return the population standard deviation (ddof=0) of *values*.

    Returns 0.0 for an empty list instead of NaN, so the report and the JSON
    summary stay well-defined (and standard JSON) when a window is empty.
    """
    from statistics import pstdev

    return pstdev(values) if values else 0.0


def _json_default(obj):
    """``json.dumps`` fallback: serialize objects that expose ``tolist()``.

    NOTE(review): window ids, explained-variance ratios and coordinates coming
    from the ``analysis`` package may be numpy arrays/scalars (not visible
    here — TODO confirm); without this hook ``json`` raises ``TypeError``.
    """
    tolist = getattr(obj, "tolist", None)
    if callable(tolist):
        return tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _add_point_rows(add, rows, width: int = 40) -> None:
    """Emit one ``name x=.. y=..`` line through *add* per ``(name, (x, y))`` row."""
    for name, (x, y) in rows:
        add(f" {name:{width}s} x={x:.4f} y={y:.4f}")


def main(argv: Optional[list] = None):
    """Compute PCA diagnostics and write a text report plus a JSON summary.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
            Options: ``--db`` (database path, default ``data/motions.db``)
            and ``--out`` (output directory, default ``outputs``).

    Returns:
        Exit status: 0 on success, 1 when the database has no SVD windows.

    Raises:
        Exception: any error raised while importing the ``analysis`` modules
            is logged with a traceback and re-raised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--db", default="data/motions.db", help="Path to the motions database.")
    parser.add_argument("--out", default="outputs", help="Directory for the report and JSON summary.")
    args = parser.parse_args(argv)

    os.makedirs(args.out, exist_ok=True)

    # Imported here so an import failure is logged with a traceback first.
    try:
        from analysis import trajectory as traj
        from analysis.political_axis import compute_2d_axes
        from analysis.visualize import _load_party_map
    except Exception as e:  # pragma: no cover - runtime
        logger.exception("Could not import analysis modules: %s", e)
        raise

    # Load windows and aligned vectors
    window_ids = traj._load_window_ids(args.db)
    if not window_ids:
        logger.error("No SVD windows found in DB %s", args.db)
        return 1
    logger.info("Found windows: %s", window_ids)

    raw_window_vecs = {
        wid: traj._load_mp_vectors_for_window(args.db, wid) for wid in window_ids
    }
    # NOTE(review): the aligned vectors are not used anywhere below; kept only
    # to preserve existing behavior — confirm whether this step can be dropped.
    aligned_window_vecs = traj._procrustes_align_windows(raw_window_vecs)

    # Compute global PCA axes (residual and non-residual) for comparison
    positions_residual, axes_residual = compute_2d_axes(
        args.db,
        window_ids=window_ids,
        method="pca",
        normalize_vectors=True,
        pca_residual=True,
    )
    positions_plain, axes_plain = compute_2d_axes(
        args.db,
        window_ids=window_ids,
        method="pca",
        normalize_vectors=True,
        pca_residual=False,
    )

    out_report: List[str] = []

    def add(line: str) -> None:
        """Append *line* to the report and echo it to the log."""
        out_report.append(line)
        logger.info(line)

    add("PCA diagnostics report")
    add(f"DB: {args.db}")
    add(f"Windows: {window_ids}")

    add("")
    evr_res = axes_residual.get("explained_variance_ratio") if axes_residual else None
    evr_plain = axes_plain.get("explained_variance_ratio") if axes_plain else None
    add(f"Residual PCA EVR: {evr_res}")
    add(f"Plain PCA EVR: {evr_plain}")

    # Pick the latest window for detailed inspection.
    latest = max(window_ids)
    add("")
    add(f"Inspecting latest window: {latest}")

    # Coerce coordinates to built-in floats so the report never shows numpy
    # scalar reprs and the centroids below are always JSON-serializable, in
    # case compute_2d_axes returns numpy values (TODO confirm).
    pos: Dict[str, Tuple[float, float]] = {
        name: (float(coord[0]), float(coord[1]))
        for name, coord in positions_residual.get(latest, {}).items()
    }
    names = list(pos)
    xs = [x for x, _ in pos.values()]
    ys = [y for _, y in pos.values()]

    add(f"Entities in latest window: {len(names)}")
    add(f"X range (left-right): {_value_range(xs)}")
    add(f"Y range (prog-cons): {_value_range(ys)}")

    x_std = _population_std(xs)
    y_std = _population_std(ys)
    add(
        f"Std dev X: {x_std:.6f}, Std dev Y: {y_std:.6f} (ratio Y/X = {y_std / (x_std + 1e-12):.3f})"
    )

    # Show the extremes on X and Y.
    n_extremes = 8
    sorted_by_x = sorted(pos.items(), key=lambda kv: kv[1][0])
    sorted_by_y = sorted(pos.items(), key=lambda kv: kv[1][1])
    extremes = (
        ("Left-most (by X):", sorted_by_x[:n_extremes]),
        ("Right-most (by X):", sorted_by_x[-n_extremes:]),
        ("Top (conservative) (by Y):", sorted_by_y[-n_extremes:]),
        ("Bottom (progressive) (by Y):", sorted_by_y[:n_extremes]),
    )
    for title, rows in extremes:
        add("")
        add(title)
        _add_point_rows(add, rows)

    # Spot-check specific MPs by case-insensitive name substring.
    add("")
    for label, query in (("Ouwehand", "ouwehand"), ("Mona", "mona")):
        matches = find_by_substring(names, query)
        add(f"Matches for '{label}': {matches}")
        for n in matches:
            x, y = pos[n]
            add(f" {n} -> x={x:.4f} y={y:.4f}")

    # Party centroids: mean position of every MP with a known party.
    party_map = _load_party_map(args.db)
    parties: Dict[str, List[Tuple[float, float]]] = {}
    for mp, coord in pos.items():
        party = party_map.get(mp)
        if party:
            parties.setdefault(party, []).append(coord)
    party_centroids: Dict[str, Tuple[float, float]] = {
        party: (mean(c[0] for c in coords), mean(c[1] for c in coords))
        for party, coords in parties.items()
    }

    add("")
    add(f"Computed {len(party_centroids)} party centroids (from mp_metadata majority)")
    add("Party centroids left→right:")
    _add_point_rows(add, sorted(party_centroids.items(), key=lambda kv: kv[1][0]), width=20)

    add("")
    add("Party centroids prog→cons:")
    _add_point_rows(add, sorted(party_centroids.items(), key=lambda kv: kv[1][1]), width=20)

    # Save report and a small JSON summary
    report_path = os.path.join(args.out, "svd_diagnostics.txt")
    summary_path = os.path.join(args.out, "svd_diagnostics.json")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(out_report) + "\n")

    summary = {
        "db": args.db,
        "windows": window_ids,
        "latest_window": latest,
        "evr_residual": evr_res,
        "evr_plain": evr_plain,
        "n_entities_latest": len(names),
        "x_std": x_std,
        "y_std": y_std,
        "party_centroids": party_centroids,
    }
    # Serialize before opening the file so a serialization error cannot leave
    # a truncated JSON file behind.
    summary_text = json.dumps(summary, indent=2, ensure_ascii=False, default=_json_default)
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(summary_text)

    logger.info("Diagnostic report written to %s and %s", report_path, summary_path)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value (0 = success, 1 = no SVD windows) as the
    # process exit status; sys.exit raises SystemExit with that code.
    sys.exit(main())
|
|
|